Skip to content

Commit 71b3b5d

Browse files
committed
feat: Add MaxErrorResponseBodyBytes option to control HTTP error body capture in Dial
1 parent 8bf6dd2 commit 71b3b5d

File tree

2 files changed

+126
-3
lines changed

2 files changed

+126
-3
lines changed

dial.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ type DialOptions struct {
4848
// for CompressionContextTakeover.
4949
CompressionThreshold int
5050

51+
// MaxErrorResponseBodyBytes controls how many bytes of the HTTP response body
52+
// are captured and made available via resp.Body when the WebSocket handshake
53+
// fails (i.e. Dial returns a non-nil error after receiving an HTTP response).
54+
//
55+
// Semantics:
56+
// 0 => preserve current behavior and capture up to 1024 bytes (default)
57+
// >0 => capture up to that many bytes
58+
// <0 => do not capture any bytes; resp.Body will remain nil on error
59+
//
60+
// Regardless of this setting, the original HTTP response body is always closed.
61+
MaxErrorResponseBodyBytes int
62+
5163
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
5264
//
5365
// The payload contains the application data of the ping frame.
@@ -110,7 +122,8 @@ func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context
110122
// You never need to close resp.Body yourself.
111123
//
112124
// If an error occurs, the returned response may be non nil.
113-
// However, you can only read the first 1024 bytes of the body.
125+
// By default, up to the first 1024 bytes of the body are available; this limit
126+
// can be adjusted via DialOptions.MaxErrorResponseBodyBytes.
114127
//
115128
// This function requires at least Go 1.12 as it uses a new feature
116129
// in net/http to perform WebSocket handshakes.
@@ -148,8 +161,19 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
148161
resp.Body = nil
149162
defer func() {
150163
if err != nil {
151-
// We read a bit of the body for easier debugging.
152-
r := io.LimitReader(respBody, 1024)
164+
// Capture a limited portion of the response body for easier debugging,
165+
// following the limit configured by MaxErrorResponseBodyBytes.
166+
limit := opts.MaxErrorResponseBodyBytes
167+
if limit == 0 {
168+
limit = 1024
169+
}
170+
if limit < 0 {
171+
// Do not capture any body bytes; ensure original body is closed.
172+
respBody.Close()
173+
return
174+
}
175+
176+
r := io.LimitReader(respBody, int64(limit))
153177

154178
timer := time.AfterFunc(time.Second*3, func() {
155179
respBody.Close()

dial_test.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,102 @@ func TestDialViaProxy(t *testing.T) {
416416
assertEcho(t, ctx, c)
417417
assertClose(t, c)
418418
}
419+
420+
// Additional tests for error response body capture behavior.
421+
// A tracking body to verify Close is called when capture is disabled.
422+
type trackingBodyDialTest struct {
423+
io.ReadCloser
424+
closed *bool
425+
}
426+
427+
func (tb trackingBodyDialTest) Close() error {
428+
*tb.closed = true
429+
return tb.ReadCloser.Close()
430+
}
431+
432+
func TestDial_ErrorResponseBodyCapture_DefaultAndCustom(t *testing.T) {
433+
t.Parallel()
434+
435+
longBody := strings.Repeat("x", 4096)
436+
437+
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
438+
w.WriteHeader(http.StatusTeapot)
439+
io.WriteString(w, longBody)
440+
}))
441+
defer s.Close()
442+
443+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
444+
defer cancel()
445+
446+
// Default behavior (zero value options): capture up to 1024 bytes
447+
_, resp, err := websocket.Dial(ctx, s.URL, nil)
448+
assert.Error(t, err)
449+
if resp == nil {
450+
t.Fatal("expected non-nil resp")
451+
}
452+
assert.Equal(t, "StatusCode", http.StatusTeapot, resp.StatusCode)
453+
454+
b, rerr := io.ReadAll(resp.Body)
455+
assert.Success(t, rerr)
456+
if len(b) > 1024 {
457+
t.Fatalf("expected captured body length <= 1024, got %d", len(b))
458+
}
459+
if exp := longBody[:len(b)]; string(b) != exp {
460+
t.Fatalf("unexpected body prefix: expected %d bytes prefix match", len(b))
461+
}
462+
463+
// Custom limit (>0)
464+
limit := 200
465+
_, resp, err = websocket.Dial(ctx, s.URL, &websocket.DialOptions{MaxErrorResponseBodyBytes: limit})
466+
assert.Error(t, err)
467+
if resp == nil {
468+
t.Fatal("expected non-nil resp")
469+
}
470+
assert.Equal(t, "StatusCode", http.StatusTeapot, resp.StatusCode)
471+
472+
b, rerr = io.ReadAll(resp.Body)
473+
assert.Success(t, rerr)
474+
if len(b) > limit {
475+
t.Fatalf("expected captured body length <= %d, got %d", limit, len(b))
476+
}
477+
if exp := longBody[:len(b)]; string(b) != exp {
478+
t.Fatalf("unexpected body prefix: expected %d bytes prefix match", len(b))
479+
}
480+
}
481+
482+
func TestDial_ErrorResponseBodyCapture_Disabled_NoBodyWithClose(t *testing.T) {
483+
t.Parallel()
484+
485+
closed := false
486+
rt := func(r *http.Request) (*http.Response, error) {
487+
// Return a long body and a non-101 status to trigger error path.
488+
return &http.Response{
489+
StatusCode: http.StatusForbidden,
490+
Body: trackingBodyDialTest{io.NopCloser(strings.NewReader(strings.Repeat("y", 4096))), &closed},
491+
}, nil
492+
}
493+
494+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
495+
defer cancel()
496+
497+
_, resp, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
498+
HTTPClient: mockHTTPClient(rt),
499+
MaxErrorResponseBodyBytes: -1,
500+
})
501+
assert.Error(t, err)
502+
if resp == nil {
503+
t.Fatal("expected non-nil resp")
504+
}
505+
assert.Equal(t, "StatusCode", http.StatusForbidden, resp.StatusCode)
506+
if resp.Body != nil {
507+
// If any body is present, ensure it's empty.
508+
b, rerr := io.ReadAll(resp.Body)
509+
assert.Success(t, rerr)
510+
if len(b) != 0 {
511+
t.Fatalf("expected no body bytes when capture disabled, got %d", len(b))
512+
}
513+
}
514+
if !closed {
515+
t.Fatal("expected original body to be closed")
516+
}
517+
}

0 commit comments

Comments
 (0)