Skip to content

Commit ceab625

Browse files
committed
httpapi.WebsocketCloseMessage
1 parent 82ea13c commit ceab625

File tree

6 files changed

+45
-49
lines changed

6 files changed

+45
-49
lines changed

coderd/coderd.go

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
package coderd
22

33
import (
4-
"fmt"
54
"net/http"
65
"net/url"
7-
"strings"
86
"sync"
97
"time"
108

@@ -201,21 +199,3 @@ type api struct {
201199
websocketWaitMutex sync.Mutex
202200
websocketWaitGroup sync.WaitGroup
203201
}
204-
205-
const websocketCloseMaxLen = 123
206-
207-
// fmtWebsocketCloseMsg formats a websocket close message and ensures it is
208-
// truncated to the maximum allowed length.
209-
func FmtWebsocketCloseMsg(format string, vars ...any) string {
210-
msg := fmt.Sprintf(format, vars...)
211-
212-
// Cap msg length at 123 bytes. nhooyr/websocket only allows close messages
213-
// of this length.
214-
if len(msg) > websocketCloseMaxLen {
215-
// Trim the string to 123 bytes. If we accidentally cut in the middle of
216-
// a UTF-8 character, remove it from the string.
217-
return strings.ToValidUTF8(string(msg[123]), "")
218-
}
219-
220-
return msg
221-
}

coderd/coderd_test.go

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,11 @@
11
package coderd_test
22

33
import (
4-
"strings"
54
"testing"
65

7-
"github.com/stretchr/testify/assert"
86
"go.uber.org/goleak"
9-
10-
"github.com/coder/coder/coderd"
117
)
128

139
func TestMain(m *testing.M) {
1410
goleak.VerifyTestMain(m)
1511
}
16-
17-
func TestFmtWebsocketCloseMsg(t *testing.T) {
18-
t.Parallel()
19-
20-
t.Run("TruncateSingleByteCharacters", func(t *testing.T) {
21-
t.Parallel()
22-
23-
msg := strings.Repeat("d", 255)
24-
trunc := coderd.FmtWebsocketCloseMsg(msg)
25-
assert.LessOrEqual(t, len(trunc), 123)
26-
})
27-
28-
t.Run("TruncateMultiByteCharacters", func(t *testing.T) {
29-
t.Parallel()
30-
31-
msg := strings.Repeat("こんにちは", 10)
32-
trunc := coderd.FmtWebsocketCloseMsg(msg)
33-
assert.LessOrEqual(t, len(trunc), 123)
34-
})
35-
}

coderd/httpapi/httpapi.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,21 @@ func Read(rw http.ResponseWriter, r *http.Request, value interface{}) bool {
115115
}
116116
return true
117117
}
118+
119+
const websocketCloseMaxLen = 123
120+
121+
// WebsocketCloseMsg formats a websocket close message and ensures it is
122+
// truncated to the maximum allowed length.
123+
func WebsocketCloseMsg(format string, vars ...any) string {
124+
msg := fmt.Sprintf(format, vars...)
125+
126+
// Cap msg length at 123 bytes. nhooyr/websocket only allows close messages
127+
// of this length.
128+
if len(msg) > websocketCloseMaxLen {
129+
// Trim the string to 123 bytes. If we accidentally cut in the middle of
130+
// a UTF-8 character, remove it from the string.
131+
return strings.ToValidUTF8(string(msg[123]), "")
132+
}
133+
134+
return msg
135+
}

coderd/httpapi/httpapi_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ import (
55
"encoding/json"
66
"net/http"
77
"net/http/httptest"
8+
"strings"
89
"testing"
910

11+
"github.com/stretchr/testify/assert"
1012
"github.com/stretchr/testify/require"
1113

1214
"github.com/coder/coder/coderd/httpapi"
@@ -142,3 +144,23 @@ func TestReadUsername(t *testing.T) {
142144
})
143145
}
144146
}
147+
148+
func WebsocketCloseMsg(t *testing.T) {
149+
t.Parallel()
150+
151+
t.Run("TruncateSingleByteCharacters", func(t *testing.T) {
152+
t.Parallel()
153+
154+
msg := strings.Repeat("d", 255)
155+
trunc := httpapi.WebsocketCloseMsg(msg)
156+
assert.LessOrEqual(t, len(trunc), 123)
157+
})
158+
159+
t.Run("TruncateMultiByteCharacters", func(t *testing.T) {
160+
t.Parallel()
161+
162+
msg := strings.Repeat("こんにちは", 10)
163+
trunc := httpapi.WebsocketCloseMsg(msg)
164+
assert.LessOrEqual(t, len(trunc), 123)
165+
})
166+
}

coderd/provisionerdaemons.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func (api *api) provisionerDaemonsListen(rw http.ResponseWriter, r *http.Request
5656
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform},
5757
})
5858
if err != nil {
59-
_ = conn.Close(websocket.StatusInternalError, FmtWebsocketCloseMsg("insert provisioner daemon: %s", err))
59+
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseMsg("insert provisioner daemon: %s", err))
6060
return
6161
}
6262

@@ -67,7 +67,7 @@ func (api *api) provisionerDaemonsListen(rw http.ResponseWriter, r *http.Request
6767
config.LogOutput = io.Discard
6868
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
6969
if err != nil {
70-
_ = conn.Close(websocket.StatusInternalError, FmtWebsocketCloseMsg("multiplex server: %s", err))
70+
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseMsg("multiplex server: %s", err))
7171
return
7272
}
7373
mux := drpcmux.New()
@@ -80,13 +80,13 @@ func (api *api) provisionerDaemonsListen(rw http.ResponseWriter, r *http.Request
8080
Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
8181
})
8282
if err != nil {
83-
_ = conn.Close(websocket.StatusInternalError, FmtWebsocketCloseMsg("drpc register provisioner daemon: %s", err))
83+
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseMsg("drpc register provisioner daemon: %s", err))
8484
return
8585
}
8686
server := drpcserver.New(mux)
8787
err = server.Serve(r.Context(), session)
8888
if err != nil {
89-
_ = conn.Close(websocket.StatusInternalError, FmtWebsocketCloseMsg("serve: %s", err))
89+
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseMsg("serve: %s", err))
9090
return
9191
}
9292
_ = conn.Close(websocket.StatusGoingAway, "")

coderd/workspaceresources.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ func (api *api) workspaceResourceDial(rw http.ResponseWriter, r *http.Request) {
108108
Pubsub: api.Pubsub,
109109
})
110110
if err != nil {
111-
_ = conn.Close(websocket.StatusInternalError, FmtWebsocketCloseMsg("serve: %s", err))
111+
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseMsg("serve: %s", err))
112112
return
113113
}
114114
}

0 commit comments

Comments
 (0)