diff --git a/mcp/sse.go b/mcp/sse.go index 7f644918..51ebc712 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -363,6 +363,13 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Connection, error) { return nil, fmt.Errorf("first event is %q, want %q", evt.Name, "endpoint") } raw := string(evt.Data) + // If the server sends an absolute path (starting with "/"), convert it + // to a relative path to preserve the base URL's path prefix. This is + // necessary when the server is behind a reverse proxy with path-based + // routing, where the server may not know about the proxy's path prefix. + if len(raw) > 0 && raw[0] == '/' { + raw = raw[1:] + } return parsedURL.Parse(raw) }() if err != nil { diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 25435ff3..25f63233 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -131,3 +131,60 @@ type roundTripperFunc func(*http.Request) (*http.Response, error) func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) } + +// TestSSEClientPreservesPathPrefix verifies that the SSE client preserves the +// path prefix from the endpoint URL when resolving the message endpoint. +// This is important when the server is behind a reverse proxy with path-based +// routing. See https://github.com/modelcontextprotocol/go-sdk/issues/687. +func TestSSEClientPreservesPathPrefix(t *testing.T) { + ctx := context.Background() + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "greet"}, sayHi) + + sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }, nil) + + // Create a server with a path prefix to simulate a reverse proxy setup. + mux := http.NewServeMux() + const pathPrefix = "/api/mcp/backend" + mux.Handle(pathPrefix+"/", http.StripPrefix(pathPrefix, sseHandler)) + httpServer := httptest.NewServer(mux) + defer httpServer.Close() + + // Track the paths that the client POSTs to. + var postedPaths []string + customClient := &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Method == http.MethodPost { + postedPaths = append(postedPaths, req.URL.Path) + } + return http.DefaultTransport.RoundTrip(req) + }), + } + + clientTransport := &SSEClientTransport{ + Endpoint: httpServer.URL + pathPrefix + "/", + HTTPClient: customClient, + } + + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, clientTransport, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Ping to trigger a POST request. + if err := cs.Ping(ctx, nil); err != nil { + t.Fatal(err) + } + + // Verify that the POST request was sent to a path that preserves the prefix. + if len(postedPaths) == 0 { + t.Fatal("expected at least one POST request") + } + for _, path := range postedPaths { + if len(path) < len(pathPrefix) || path[:len(pathPrefix)] != pathPrefix { + t.Errorf("POST path %q does not preserve prefix %q", path, pathPrefix) + } + } +}