From 85a43405dd9f6fc97b40f344e5bb55a66733b454 Mon Sep 17 00:00:00 2001 From: Alex Bumbacea Date: Fri, 19 Dec 2025 21:27:21 +0200 Subject: [PATCH] mcp: improve http transports error handling and remove the fixed SSE line size limit --- mcp/event.go | 55 +++++++++++++++++++++++++++--------------- mcp/event_test.go | 2 +- mcp/sse.go | 31 +++++++++++++++++------- mcp/streamable.go | 8 ++++-- mcp/streamable_test.go | 4 +-- 5 files changed, 66 insertions(+), 34 deletions(-) diff --git a/mcp/event.go b/mcp/event.go index 5c322c4a..4f0211d1 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -66,10 +66,8 @@ func writeEvent(w io.Writer, evt Event) (int, error) { // // TODO(rfindley): consider a different API here that makes failure modes more // apparent. -func scanEvents(r io.Reader) iter.Seq2[Event, error] { - scanner := bufio.NewScanner(r) - const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size - scanner.Buffer(nil, maxTokenSize) +func scanEvents(ctx context.Context, r io.Reader) iter.Seq2[Event, error] { + scanner := bufio.NewReader(r) // TODO: investigate proper behavior when events are out of order, or have // non-standard names. 
@@ -100,15 +98,40 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { dataBuf = nil } } - for scanner.Scan() { - line := scanner.Bytes() + emitEvent := func() bool { + flushData() + if evt.Empty() { + return true + } + if !yield(evt, nil) { + return false + } + evt = Event{} + return true + } + for { + line, err := scanner.ReadBytes('\n') + if err != nil { + if errors.Is(err, io.EOF) { + // Handle EOF below + } else if ctx.Err() != nil { + yield(Event{}, fmt.Errorf("context done: %w", ctx.Err())) + return + } else { + yield(Event{}, fmt.Errorf("error reading event: %w", err)) + return + } + } + line = bytes.TrimRight(line, "\r\n") + isEOF := errors.Is(err, io.EOF) + if len(line) == 0 { - flushData() - // \n\n is the record delimiter - if !evt.Empty() && !yield(evt, nil) { + if !emitEvent() { + return + } + if isEOF { return } - evt = Event{} continue } before, after, found := bytes.Cut(line, []byte{':'}) @@ -136,19 +159,11 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { dataBuf.Write(data) } } - } - if err := scanner.Err(); err != nil { - if errors.Is(err, bufio.ErrTooLong) { - err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize) - } - if !yield(Event{}, err) { + if isEOF { + emitEvent() return } } - flushData() - if !evt.Empty() { - yield(evt, nil) - } } } diff --git a/mcp/event_test.go b/mcp/event_test.go index dacf30e8..53a53358 100644 --- a/mcp/event_test.go +++ b/mcp/event_test.go @@ -61,7 +61,7 @@ func TestScanEvents(t *testing.T) { r := strings.NewReader(tt.input) var got []Event var err error - for e, err2 := range scanEvents(r) { + for e, err2 := range scanEvents(t.Context(), r) { if err2 != nil { err = err2 break diff --git a/mcp/sse.go b/mcp/sse.go index a668c6d0..3a99a06d 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -361,7 +361,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { msgEndpoint, err := func() (*url.URL, error) { var evt Event - for evt, err = range scanEvents(resp.Body) { + 
for evt, err = range scanEvents(ctx, resp.Body) { break } if err != nil { @@ -382,7 +382,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { s := &sseClientConn{ client: httpClient, msgEndpoint: msgEndpoint, - incoming: make(chan []byte, 100), + incoming: make(chan sseMessage, 100), body: resp.Body, done: make(chan struct{}), } @@ -390,12 +390,16 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { go func() { defer s.Close() // close the transport when the GET exits - for evt, err := range scanEvents(resp.Body) { + for evt, err := range scanEvents(ctx, resp.Body) { if err != nil { + select { + case s.incoming <- sseMessage{err: err}: + case <-s.done: + } return } select { - case s.incoming <- evt.Data: + case s.incoming <- sseMessage{data: evt.Data}: case <-s.done: return } @@ -405,15 +409,21 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { return s, nil } +// sseMessage represents a message or error from the SSE stream. +type sseMessage struct { + data []byte + err error +} + // An sseClientConn is a logical jsonrpc2 connection that implements the client // half of the SSE protocol: // - Writes are POSTS to the session endpoint. // - Reads are SSE 'message' events, and pushes them onto a buffered channel. // - Close terminates the GET request. 
type sseClientConn struct { - client *http.Client // HTTP client to use for requests - msgEndpoint *url.URL // session endpoint for POSTs - incoming chan []byte // queue of incoming messages + client *http.Client // HTTP client to use for requests + msgEndpoint *url.URL // session endpoint for POSTs + incoming chan sseMessage // queue of incoming messages or errors mu sync.Mutex body io.ReadCloser // body of the hanging GET @@ -438,12 +448,15 @@ func (c *sseClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { case <-c.done: return nil, io.EOF - case data := <-c.incoming: + case m := <-c.incoming: + if m.err != nil { + return nil, m.err + } // TODO(rfindley): do we really need to check this? We receive from c.done above. if c.isDone() { return nil, io.EOF } - msg, err := jsonrpc2.DecodeMessage(data) + msg, err := jsonrpc2.DecodeMessage(m.data) if err != nil { return nil, err } diff --git a/mcp/streamable.go b/mcp/streamable.go index b4b2fa31..89c4233c 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1854,12 +1854,16 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary io.Copy(io.Discard, resp.Body) resp.Body.Close() }() - for evt, err := range scanEvents(resp.Body) { + for evt, err := range scanEvents(ctx, resp.Body) { if err != nil { if ctx.Err() != nil { return "", 0, true // don't reconnect: client cancelled } - break + + // Network errors during reading should trigger reconnection, not permanent failure. + // Return from processStream so handleSSE can attempt to reconnect. 
+ c.logger.Debug(fmt.Sprintf("%s: stream read error (will attempt reconnect): %v", requestSummary, err)) + return lastEventID, reconnectDelay, false } if evt.ID != "" { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index b1c3f074..b4d28092 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1425,7 +1425,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, var respBody []byte if strings.HasPrefix(contentType, "text/event-stream") { r := readerInto{resp.Body, new(bytes.Buffer)} - for evt, err := range scanEvents(r) { + for evt, err := range scanEvents(ctx, r) { if err != nil { return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading events: %v", err) } @@ -2143,7 +2143,7 @@ data: {"jsonrpc":"2.0","method":"test2","params":{}} var events []Event // Scan all events - for evt, err := range scanEvents(reader) { + for evt, err := range scanEvents(t.Context(), reader) { if err != nil { if err != io.EOF { t.Fatalf("scanEvents error: %v", err)