diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs index e3226b57d..9424f6252 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs @@ -1,4 +1,5 @@ -using ModelContextProtocol.Server; +using ModelContextProtocol.Core; +using ModelContextProtocol.Server; using System.Diagnostics; using System.Security.Claims; @@ -16,7 +17,7 @@ internal sealed class StreamableHttpSession( private readonly object _stateLock = new(); private int _getRequestStarted; - private readonly CancellationTokenSource _disposeCts = new(); + private CancellationTokenSource _disposeCts = new(); public string Id => sessionId; public StreamableHttpServerTransport Transport => transport; @@ -124,7 +125,8 @@ public async ValueTask DisposeAsync() { sessionManager.DecrementIdleSessionCount(); } - _disposeCts.Dispose(); + + CanceledTokenSource.Defuse(ref _disposeCts); } } diff --git a/src/ModelContextProtocol.Core/AIContentExtensions.cs b/src/ModelContextProtocol.Core/AIContentExtensions.cs index b1ba32bf4..55f058129 100644 --- a/src/ModelContextProtocol.Core/AIContentExtensions.cs +++ b/src/ModelContextProtocol.Core/AIContentExtensions.cs @@ -1,9 +1,8 @@ using Microsoft.Extensions.AI; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; -#if !NET +using System.Buffers.Text; using System.Runtime.InteropServices; -#endif using System.Text.Json; using System.Text.Json.Nodes; @@ -263,10 +262,12 @@ public static IList ToPromptMessages(this ChatMessage chatMessage AIContent? ac = content switch { TextContentBlock textContent => new TextContent(textContent.Text), + + Utf8TextContentBlock utf8TextContent => new TextContent(utf8TextContent.Text), - ImageContentBlock imageContent => new DataContent(Convert.FromBase64String(imageContent.Data), imageContent.MimeType), + ImageContentBlock imageContent => new DataContent(imageContent.DecodedData, imageContent.MimeType), - AudioContentBlock audioContent => new DataContent(Convert.FromBase64String(audioContent.Data), audioContent.MimeType), + AudioContentBlock audioContent => new DataContent(audioContent.DecodedData, audioContent.MimeType), EmbeddedResourceBlock resourceContent => resourceContent.Resource.ToAIContent(), @@ -275,7 +276,9 @@ public static IList ToPromptMessages(this ChatMessage chatMessage ToolResultContentBlock toolResult => new FunctionResultContent( toolResult.ToolUseId, - toolResult.Content.Count == 1 ? toolResult.Content[0].ToAIContent() : toolResult.Content.Select(c => c.ToAIContent()).OfType().ToList()) + toolResult.StructuredContent is JsonElement structured ? structured : + toolResult.Content.Count == 1 ? toolResult.Content[0].ToAIContent() : + toolResult.Content.Select(c => c.ToAIContent()).OfType().ToList()) { Exception = toolResult.IsError is true ? new() : null, }, @@ -307,7 +310,7 @@ public static AIContent ToAIContent(this ResourceContents content) AIContent ac = content switch { - BlobResourceContents blobResource => new DataContent(Convert.FromBase64String(blobResource.Blob), blobResource.MimeType ?? "application/octet-stream"), + BlobResourceContents blobResource => new DataContent(blobResource.DecodedData, blobResource.MimeType ?? "application/octet-stream"), TextResourceContents textResource => new TextContent(textResource.Text), _ => throw new NotSupportedException($"Resource type '{content.GetType().Name}' is not supported.") }; @@ -380,13 +383,17 @@ public static ContentBlock ToContentBlock(this AIContent content) DataContent dataContent when dataContent.HasTopLevelMediaType("image") => new ImageContentBlock { - Data = dataContent.Base64Data.ToString(), + Data = MemoryMarshal.TryGetArray(dataContent.Base64Data, out ArraySegment segment) + ? new string(segment.Array!, segment.Offset, segment.Count) + : new string(dataContent.Base64Data.ToArray()), MimeType = dataContent.MediaType, }, DataContent dataContent when dataContent.HasTopLevelMediaType("audio") => new AudioContentBlock { - Data = dataContent.Base64Data.ToString(), + Data = MemoryMarshal.TryGetArray(dataContent.Base64Data, out ArraySegment segment) + ? new string(segment.Array!, segment.Offset, segment.Count) + : new string(dataContent.Base64Data.ToArray()), MimeType = dataContent.MediaType, }, @@ -394,7 +401,7 @@ public static ContentBlock ToContentBlock(this AIContent content) { Resource = new BlobResourceContents { - Blob = dataContent.Base64Data.ToString(), + DecodedData = dataContent.Data, MimeType = dataContent.MediaType, Uri = string.Empty, } @@ -414,21 +421,51 @@ public static ContentBlock ToContentBlock(this AIContent content) Content = resultContent.Result is AIContent c ? [c.ToContentBlock()] : resultContent.Result is IEnumerable ec ? [.. ec.Select(c => c.ToContentBlock())] : - [new TextContentBlock { Text = JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions.GetTypeInfo()) }], - StructuredContent = resultContent.Result is JsonElement je ? je : null, + [new TextContentBlock { Text = "" }], + StructuredContent = + resultContent.Result is JsonElement je ? je : + resultContent.Result is null ? null : + JsonSerializer.SerializeToElement(resultContent.Result, McpJsonUtilities.DefaultOptions.GetTypeInfo()), }, - _ => new TextContentBlock + _ => CreateJsonResourceContentBlock(content) + }; + + static ContentBlock CreateJsonResourceContentBlock(AIContent content) + { + byte[] jsonUtf8 = JsonSerializer.SerializeToUtf8Bytes(content, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object))); + +#if NET + int maxLength = Base64.GetMaxEncodedToUtf8Length(jsonUtf8.Length); +#else + int maxLength = ((jsonUtf8.Length + 2) / 3) * 4; +#endif + + byte[] base64 = new byte[maxLength]; + if (Base64.EncodeToUtf8(jsonUtf8, base64, out _, out int bytesWritten) != System.Buffers.OperationStatus.Done) { - Text = JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object))), + throw new InvalidOperationException("Failed to base64-encode JSON payload."); } - }; + + ReadOnlyMemory blob = base64.AsMemory(0, bytesWritten); + + return new EmbeddedResourceBlock + { + Resource = new BlobResourceContents + { + Uri = string.Empty, + MimeType = "application/json", + BlobUtf8 = blob, + }, + }; + } contentBlock.Meta = content.AdditionalProperties?.ToJsonObject(); return contentBlock; } + private sealed class ToolAIFunctionDeclaration(Tool tool) : AIFunctionDeclaration { public override string Name => tool.Name; diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index 75126556b..50d2d167d 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -6,6 +6,7 @@ using System.Buffers.Text; #endif using System.Diagnostics.CodeAnalysis; +using ModelContextProtocol.Internal; using System.Net.Http.Headers; using System.Security.Cryptography; using System.Text; @@ -581,8 +582,9 @@ private async Task PerformDynamicClientRegistrationAsync( Scope = GetScopeParameter(protectedResourceMetadata), }; - var requestJson = JsonSerializer.Serialize(registrationRequest, McpJsonUtilities.JsonContext.Default.DynamicClientRegistrationRequest); - using var requestContent = new StringContent(requestJson, Encoding.UTF8, "application/json"); + using var requestContent = new JsonTypeInfoHttpContent( + registrationRequest, + McpJsonUtilities.JsonContext.Default.DynamicClientRegistrationRequest); using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.RegistrationEndpoint) { diff --git a/src/ModelContextProtocol.Core/CanceledTokenSource.cs b/src/ModelContextProtocol.Core/CanceledTokenSource.cs new file mode 100644 index 000000000..a898fa586 --- /dev/null +++ b/src/ModelContextProtocol.Core/CanceledTokenSource.cs @@ -0,0 +1,44 @@ +using System.Diagnostics.CodeAnalysis; + +namespace ModelContextProtocol.Core; + +/// +/// A that is already canceled. +/// Disposal is a no-op. +/// +public sealed class CanceledTokenSource : CancellationTokenSource +{ + /// + /// Gets a singleton instance of a canceled token source. + /// + public static readonly CanceledTokenSource Instance = new(); + + private CanceledTokenSource() + => Cancel(); + + /// + protected override void Dispose(bool disposing) + { + // No-op + } + + /// + /// Defuses the given by optionally canceling it + /// and replacing it with the singleton canceled instance. + /// The original token source is left for garbage collection and finalization provided + /// there are no other references to it outstanding if is false. + /// + /// The token source to pseudo-dispose. May be null. + /// Whether to cancel the token source before pseudo-disposing it. + /// Whether to call Dispose on the token source. + [SuppressMessage("Design", "CA1062:Validate arguments of public methods")] + public static void Defuse(ref CancellationTokenSource cts, bool cancel = true, bool dispose = false) + { + // don't null check; allow replacing null, allow throw on attempt to call Cancel + var orig = cts; + if (cancel) orig.Cancel(); + Interlocked.Exchange(ref cts, Instance); + // presume the GC will finalize and dispose the original CTS as needed + if (dispose) orig.Dispose(); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/McpClientTool.cs b/src/ModelContextProtocol.Core/Client/McpClientTool.cs index f4dc060d9..e42d34c39 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientTool.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientTool.cs @@ -142,8 +142,27 @@ result.StructuredContent is null && case 1 when result.Content[0].ToAIContent() is { } aiContent: return aiContent; - case > 1 when result.Content.Select(c => c.ToAIContent()).ToArray() is { } aiContents && aiContents.All(static c => c is not null): - return aiContents; + case > 1: + AIContent[] aiContents = new AIContent[result.Content.Count]; + bool allConverted = true; + + for (int i = 0; i < aiContents.Length; i++) + { + if (result.Content[i].ToAIContent() is not { } c) + { + allConverted = false; + break; + } + + aiContents[i] = c; + } + + if (allConverted) + { + return aiContents; + } + + break; } } diff --git a/src/ModelContextProtocol.Core/Client/McpHttpClient.cs b/src/ModelContextProtocol.Core/Client/McpHttpClient.cs index 77ca78fb4..b41c5b144 100644 --- a/src/ModelContextProtocol.Core/Client/McpHttpClient.cs +++ b/src/ModelContextProtocol.Core/Client/McpHttpClient.cs @@ -32,10 +32,9 @@ internal virtual async Task SendAsync(HttpRequestMessage re #if NET return JsonContent.Create(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); #else - return new StringContent( - JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), - Encoding.UTF8, - "application/json" + return new ModelContextProtocol.Internal.JsonTypeInfoHttpContent( + message, + McpJsonUtilities.JsonContext.Default.JsonRpcMessage ); #endif } diff --git a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs index 60950dfa5..21104d69d 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs @@ -1,5 +1,6 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Core; using ModelContextProtocol.Protocol; using System.Diagnostics; using System.Net.Http.Headers; @@ -18,7 +19,7 @@ internal sealed partial class SseClientSessionTransport : TransportBase private readonly HttpClientTransportOptions _options; private readonly Uri _sseEndpoint; private Uri? _messageEndpoint; - private readonly CancellationTokenSource _connectionCts; + private CancellationTokenSource _connectionCts; private Task? _receiveTask; private readonly ILogger _logger; private readonly TaskCompletionSource _connectionEstablished; @@ -114,7 +115,7 @@ private async Task CloseAsync() } finally { - _connectionCts.Dispose(); + CanceledTokenSource.Defuse(ref _connectionCts, dispose: true); } } finally diff --git a/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs index a9c228d43..ecca368ef 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs @@ -7,7 +7,7 @@ namespace ModelContextProtocol.Client; /// Provides the client side of a stdio-based session transport. internal sealed class StdioClientSessionTransport( StdioClientTransportOptions options, Process process, string endpointName, Queue stderrRollingLog, ILoggerFactory? loggerFactory) : - StreamClientSessionTransport(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, encoding: null, endpointName, loggerFactory) + StreamClientSessionTransport(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, endpointName, loggerFactory) { private readonly StdioClientTransportOptions _options = options; private readonly Process _process = process; diff --git a/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs index c896bd433..fcf596705 100644 --- a/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs @@ -1,51 +1,44 @@ using Microsoft.Extensions.Logging; +using ModelContextProtocol.Core; using ModelContextProtocol.Protocol; +using System.Buffers; using System.Text; using System.Text.Json; namespace ModelContextProtocol.Client; -/// Provides the client side of a stream-based session transport. +/// Provides the client side of a stream-based session transport using raw streams. internal class StreamClientSessionTransport : TransportBase { - internal static UTF8Encoding NoBomUtf8Encoding { get; } = new(encoderShouldEmitUTF8Identifier: false); + private static readonly byte[] NewlineUtf8 = [(byte)'\n']; + + private readonly Stream _serverInput; + private readonly Stream _serverOutput; - private readonly TextReader _serverOutput; - private readonly TextWriter _serverInput; private readonly SemaphoreSlim _sendLock = new(1, 1); - private CancellationTokenSource? _shutdownCts = new(); + + // Intentionally not disposed; once this transport instance is collectable, CTS finalization will clean up. + private readonly CancellationTokenSource _shutdownCts = new(); + private Task? _readTask; - /// - /// Initializes a new instance of the class. - /// - /// - /// The text writer connected to the server's input stream. - /// Messages written to this writer will be sent to the server. - /// - /// - /// The text reader connected to the server's output stream. - /// Messages read from this reader will be received from the server. - /// - /// - /// A name that identifies this transport endpoint in logs. - /// - /// - /// Optional factory for creating loggers. If null, a NullLogger is used. - /// - /// - /// This constructor starts a background task to read messages from the server output stream. - /// The transport will be marked as connected once initialized. - /// - public StreamClientSessionTransport( - TextWriter serverInput, TextReader serverOutput, string endpointName, ILoggerFactory? loggerFactory) + internal static UTF8Encoding NoBomUtf8Encoding { get; } = new(encoderShouldEmitUTF8Identifier: false); + + public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, string endpointName, ILoggerFactory? loggerFactory) : base(endpointName, loggerFactory) { - _serverOutput = serverOutput; + Throw.IfNull(serverInput); + Throw.IfNull(serverOutput); + _serverInput = serverInput; + _serverOutput = serverOutput; SetConnected(); + StartReadLoop(); + } + private void StartReadLoop() + { // Start reading messages in the background. We use the rarer pattern of new Task + Start // in order to ensure that the body of the task will always see _readTask initialized. // It is then able to reliably null it out on completion. @@ -53,48 +46,11 @@ public StreamClientSessionTransport( thisRef => ((StreamClientSessionTransport)thisRef!).ReadMessagesAsync(_shutdownCts.Token), this, TaskCreationOptions.DenyChildAttach); + _readTask = readTask.Unwrap(); readTask.Start(); } - /// - /// Initializes a new instance of the class. - /// - /// - /// The server's input stream. Messages written to this stream will be sent to the server. - /// - /// - /// The server's output stream. Messages read from this stream will be received from the server. - /// - /// - /// The encoding used for reading and writing messages from the input and output streams. Defaults to UTF-8 without BOM if null. - /// - /// - /// A name that identifies this transport endpoint in logs. - /// - /// - /// Optional factory for creating loggers. If null, a NullLogger is used. - /// - /// - /// This constructor starts a background task to read messages from the server output stream. - /// The transport will be marked as connected once initialized. - /// - public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, Encoding? encoding, string endpointName, ILoggerFactory? loggerFactory) - : this( - new StreamWriter(serverInput, encoding ?? NoBomUtf8Encoding), -#if NET - new StreamReader(serverOutput, encoding ?? NoBomUtf8Encoding), -#else - new CancellableStreamReader(serverOutput, encoding ?? NoBomUtf8Encoding), -#endif - endpointName, - loggerFactory) - { - Throw.IfNull(serverInput); - Throw.IfNull(serverOutput); - } - - /// public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { string id = "(no id)"; @@ -103,13 +59,16 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation id = messageWithId.Id.ToString(); } - var json = JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); - using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); try { - // Write the message followed by a newline using our UTF-8 writer - await _serverInput.WriteLineAsync(json).ConfigureAwait(false); + await JsonSerializer.SerializeAsync( + _serverInput, + message, + McpJsonUtilities.JsonContext.Default.JsonRpcMessage, + cancellationToken).ConfigureAwait(false); + + await _serverInput.WriteAsync(NewlineUtf8, cancellationToken).ConfigureAwait(false); await _serverInput.FlushAsync(cancellationToken).ConfigureAwait(false); } catch (Exception ex) @@ -119,7 +78,6 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation } } - /// public override ValueTask DisposeAsync() => CleanupAsync(cancellationToken: CancellationToken.None); @@ -129,60 +87,139 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken) try { LogTransportEnteringReadMessagesLoop(Name); + await ReadMessagesFromStreamAsync(_serverOutput, cancellationToken).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + LogTransportReadMessagesCancelled(Name); + } + catch (Exception ex) + { + error = ex; + LogTransportReadMessagesFailed(Name, ex); + } + finally + { + _readTask = null; + await CleanupAsync(error, cancellationToken).ConfigureAwait(false); + } + } + + private async Task ReadMessagesFromStreamAsync(Stream stream, CancellationToken cancellationToken) + { + byte[] buffer = ArrayPool.Shared.Rent(16 * 1024); + try + { + using var lineStream = new MemoryStream(); while (true) { - if (await _serverOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) is not string line) + int bytesRead = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); + if (bytesRead == 0) { LogTransportEndOfStream(Name); break; } - if (string.IsNullOrWhiteSpace(line)) + int offset = 0; + while (offset < bytesRead) { - continue; + int newlineIndex = Array.IndexOf(buffer, (byte)'\n', offset, bytesRead - offset); + if (newlineIndex < 0) + { +#pragma warning disable CA1849 // WriteAsync on MemoryStream is not necessary + lineStream.Write(buffer, offset, bytesRead - offset); +#pragma warning restore CA1849 + break; + } + + int partLength = newlineIndex - offset; + if (partLength > 0) + { +#pragma warning disable CA1849 // WriteAsync on MemoryStream is not necessary + lineStream.Write(buffer, offset, partLength); +#pragma warning restore CA1849 + } + + offset = newlineIndex + 1; + + if (!lineStream.TryGetBuffer(out ArraySegment segment)) + { + throw new InvalidOperationException("Expected MemoryStream to expose its buffer."); + } + + // IMPORTANT: `lineBytes` is a slice over `lineStream`'s underlying buffer. + // This is intentionally copy-free. + // + // Safety / lifetime: + // - `lineStream` stays alive for the duration of this read loop. + // - `lineStream.SetLength(0)` does not clear, overwrite, or reallocate the underlying array. + // - We do not write to `lineStream` again until after `await ProcessMessageAsync(...)` completes, + // so the bytes referenced by `lineBytes` are not mutated while they're being parsed. + // + // Do not store `lineBytes` (or `segment.Array`) anywhere or queue it for later processing. + ReadOnlyMemory lineBytes = new(segment.Array!, segment.Offset, (int)lineStream.Length); + + if (!lineBytes.IsEmpty && lineBytes.Span[^1] == (byte)'\r') + { + lineBytes = lineBytes[..^1]; + } + + // Reset for buffering the next line. This only updates the length; it does not clear the buffer. + lineStream.SetLength(0); + + if (McpTextUtilities.IsWhiteSpace(lineBytes.Span)) + { + continue; + } + + // Keep the await inline to ensure no subsequent writes to `lineStream` occur until the message + // parsing/dispatch is complete (otherwise the underlying buffer could be overwritten). + await ProcessMessageAsync(lineBytes, cancellationToken).ConfigureAwait(false); } - - LogTransportReceivedMessageSensitive(Name, line); - - await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false); } } - catch (OperationCanceledException) + finally { - LogTransportReadMessagesCancelled(Name); + ArrayPool.Shared.Return(buffer); } - catch (Exception ex) + } + + private async Task ProcessMessageAsync(ReadOnlyMemory lineBytes, CancellationToken cancellationToken) + { + // `lineBytes` may be backed by a reusable buffer owned by the read loop. + // This method must not let the buffer escape (e.g., store/capture `span` across awaits). + ReadOnlySpan span = lineBytes.Span; + + string? lineForLogs = null; + if (Logger.IsEnabled(LogLevel.Trace)) { - error = ex; - LogTransportReadMessagesFailed(Name, ex); +lineForLogs = McpTextUtilities.GetStringFromUtf8(span); } - finally + + if (lineForLogs is not null) { - _readTask = null; - await CleanupAsync(error, cancellationToken).ConfigureAwait(false); + LogTransportReceivedMessageSensitive(Name, lineForLogs); } - } - private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken) - { try { - var message = (JsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))); + var reader = new Utf8JsonReader(span); + var message = (JsonRpcMessage?)JsonSerializer.Deserialize(ref reader, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))); if (message != null) { await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); } - else + else if (lineForLogs is not null) { - LogTransportMessageParseUnexpectedTypeSensitive(Name, line); + LogTransportMessageParseUnexpectedTypeSensitive(Name, lineForLogs); } } catch (JsonException ex) { - if (Logger.IsEnabled(LogLevel.Trace)) + if (Logger.IsEnabled(LogLevel.Trace) && lineForLogs is not null) { - LogTransportMessageParseFailedSensitive(Name, line, ex); + LogTransportMessageParseFailedSensitive(Name, lineForLogs, ex); } else { @@ -195,11 +232,7 @@ protected virtual async ValueTask CleanupAsync(Exception? error = null, Cancella { LogTransportShuttingDown(Name); - if (Interlocked.Exchange(ref _shutdownCts, null) is { } shutdownCts) - { - await shutdownCts.CancelAsync().ConfigureAwait(false); - shutdownCts.Dispose(); - } + await _shutdownCts.CancelAsync().ConfigureAwait(false); if (Interlocked.Exchange(ref _readTask, null) is Task readTask) { diff --git a/src/ModelContextProtocol.Core/Client/StreamClientTransport.cs b/src/ModelContextProtocol.Core/Client/StreamClientTransport.cs index deca7e6ef..c2429cdec 100644 --- a/src/ModelContextProtocol.Core/Client/StreamClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamClientTransport.cs @@ -50,7 +50,6 @@ public Task ConnectAsync(CancellationToken cancellationToken = defau return Task.FromResult(new StreamClientSessionTransport( _serverInput, _serverOutput, - encoding: null, "Client (stream)", _loggerFactory)); } diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index 534249038..4618334f9 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -1,5 +1,6 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Core; using System.Net.Http.Headers; using System.Net.ServerSentEvents; using System.Text.Json; @@ -18,7 +19,7 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa private readonly McpHttpClient _httpClient; private readonly HttpClientTransportOptions _options; - private readonly CancellationTokenSource _connectionCts; + private CancellationTokenSource _connectionCts; private readonly ILogger _logger; private string? _negotiatedProtocolVersion; @@ -172,7 +173,7 @@ public override async ValueTask DisposeAsync() } finally { - _connectionCts.Dispose(); + CanceledTokenSource.Defuse(ref _connectionCts, dispose: true); } } finally diff --git a/src/ModelContextProtocol.Core/Client/TextStreamClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/TextStreamClientSessionTransport.cs new file mode 100644 index 000000000..726803faf --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/TextStreamClientSessionTransport.cs @@ -0,0 +1,335 @@ +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Core; +using ModelContextProtocol.Protocol; +using System.Buffers; +using System.Runtime.InteropServices; +using System.Text; +using System.Text.Json; + +namespace ModelContextProtocol.Client; + +/// Provides the client side of a text-based session transport. +internal class TextStreamClientSessionTransport : TransportBase +{ + internal static UTF8Encoding NoBomUtf8Encoding { get; } = new(encoderShouldEmitUTF8Identifier: false); + + private static readonly byte[] NewlineUtf8 = [(byte)'\n']; + + private readonly TextWriter _serverInput; + private readonly TextReader? _serverOutput; + private readonly Stream? _serverOutputStream; + + private readonly SemaphoreSlim _sendLock = new(1, 1); + + // Intentionally not disposed; once this transport instance is collectable, CTS finalization will clean up. + private readonly CancellationTokenSource _shutdownCts = new(); + + private Task? _readTask; + + public TextStreamClientSessionTransport( + TextWriter serverInput, TextReader serverOutput, string endpointName, ILoggerFactory? loggerFactory) + : base(endpointName, loggerFactory) + { + Throw.IfNull(serverInput); + Throw.IfNull(serverOutput); + + _serverInput = serverInput; + + if (serverOutput is StreamReader sr && sr.CurrentEncoding.CodePage == Encoding.UTF8.CodePage) + { + _serverOutput = null; + _serverOutputStream = sr.BaseStream; + } + else + { + _serverOutput = serverOutput; + _serverOutputStream = null; + } + + SetConnected(); + StartReadLoop(); + } + + public TextStreamClientSessionTransport(Stream serverInput, Stream serverOutput, Encoding? encoding, string endpointName, ILoggerFactory? loggerFactory) + : base(endpointName, loggerFactory) + { + Throw.IfNull(serverInput); + Throw.IfNull(serverOutput); + + _serverInput = new StreamWriter(serverInput, encoding ?? NoBomUtf8Encoding); + _serverOutput = null; + _serverOutputStream = serverOutput; + + SetConnected(); + StartReadLoop(); + } + + private void StartReadLoop() + { + // Start reading messages in the background. We use the rarer pattern of new Task + Start + // in order to ensure that the body of the task will always see _readTask initialized. + // It is then able to reliably null it out on completion. + var readTask = new Task( + thisRef => ((TextStreamClientSessionTransport)thisRef!).ReadMessagesAsync(_shutdownCts.Token), + this, + TaskCreationOptions.DenyChildAttach); + _readTask = readTask.Unwrap(); + readTask.Start(); + } + + /// + public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + { + string id = "(no id)"; + if (message is JsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); + try + { + // Prefer writing UTF-8 directly to avoid staging JSON in UTF-16. + if (_serverInput is StreamWriter sw) + { + using var jsonWriter = new Utf8JsonWriter(sw.BaseStream); + JsonSerializer.Serialize(jsonWriter, message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); + await jsonWriter.FlushAsync(cancellationToken).ConfigureAwait(false); + + await sw.BaseStream.WriteAsync(NewlineUtf8, cancellationToken).ConfigureAwait(false); + await sw.BaseStream.FlushAsync(cancellationToken).ConfigureAwait(false); + return; + } + + // Fallback for arbitrary TextWriter instances: avoid allocating a UTF-16 string. + byte[] utf8JsonBytes = JsonSerializer.SerializeToUtf8Bytes(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage); + + int charCount = Encoding.UTF8.GetCharCount(utf8JsonBytes); + char[] rented = ArrayPool.Shared.Rent(charCount); + try + { + int charsWritten = Encoding.UTF8.GetChars(utf8JsonBytes, 0, utf8JsonBytes.Length, rented, 0); + await _serverInput.WriteAsync(rented, 0, charsWritten).ConfigureAwait(false); + await _serverInput.WriteAsync('\n').ConfigureAwait(false); + await _serverInput.FlushAsync(cancellationToken).ConfigureAwait(false); + } + finally + { + ArrayPool.Shared.Return(rented); + } + } + catch (Exception ex) + { + LogTransportSendFailed(Name, id, ex); + throw new IOException("Failed to send message.", ex); + } + } + + /// + public override ValueTask DisposeAsync() => + CleanupAsync(cancellationToken: CancellationToken.None); + + private async Task ReadMessagesAsync(CancellationToken cancellationToken) + { + Exception? error = null; + try + { + LogTransportEnteringReadMessagesLoop(Name); + + if (_serverOutputStream is not null) + { + await ReadMessagesFromStreamAsync(_serverOutputStream, cancellationToken).ConfigureAwait(false); + return; + } + + TextReader serverOutput = _serverOutput ?? throw new InvalidOperationException("No output stream configured."); + while (true) + { + if (await serverOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) is not string line) + { + LogTransportEndOfStream(Name); + break; + } + + if (string.IsNullOrWhiteSpace(line)) + { + continue; + } + + LogTransportReceivedMessageSensitive(Name, line); + + await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false); + } + } + catch (OperationCanceledException) + { + LogTransportReadMessagesCancelled(Name); + } + catch (Exception ex) + { + error = ex; + LogTransportReadMessagesFailed(Name, ex); + } + finally + { + _readTask = null; + await CleanupAsync(error, cancellationToken).ConfigureAwait(false); + } + } + + private async Task ReadMessagesFromStreamAsync(Stream stream, CancellationToken cancellationToken) + { + byte[] buffer = System.Buffers.ArrayPool.Shared.Rent(16 * 1024); + try + { + using var lineStream = new MemoryStream(); + + while (true) + { + int bytesRead = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); + if (bytesRead == 0) + { + LogTransportEndOfStream(Name); + break; + } + + int offset = 0; + while (offset < bytesRead) + { + int newlineIndex = Array.IndexOf(buffer, (byte)'\n', offset, bytesRead - offset); + if (newlineIndex < 0) + { + lineStream.Write(buffer, offset, bytesRead - offset); + break; + } + + int partLength = newlineIndex - offset; + if (partLength > 0) + { + lineStream.Write(buffer, offset, partLength); + } + + offset = newlineIndex + 1; + + if (!lineStream.TryGetBuffer(out ArraySegment segment)) + { + throw new InvalidOperationException("Expected MemoryStream to expose its buffer."); + } + + ReadOnlyMemory lineBytes = new(segment.Array!, segment.Offset, (int)lineStream.Length); + + if (!lineBytes.IsEmpty && lineBytes.Span[^1] == (byte)'\r') + { + lineBytes = lineBytes[..^1]; + } + + lineStream.SetLength(0); + + if (McpTextUtilities.IsWhiteSpace(lineBytes.Span)) + { + continue; + } + + await ProcessMessageAsync(lineBytes, cancellationToken).ConfigureAwait(false); + } + } + } + finally + { + System.Buffers.ArrayPool.Shared.Return(buffer); + } + } + + private async Task ProcessMessageAsync(ReadOnlyMemory lineBytes, CancellationToken cancellationToken) + { + ReadOnlySpan span = lineBytes.Span; + + string? lineForLogs = null; + if (Logger.IsEnabled(LogLevel.Trace)) + { +lineForLogs = McpTextUtilities.GetStringFromUtf8(span); + } + + if (lineForLogs is not null) + { + LogTransportReceivedMessageSensitive(Name, lineForLogs); + } + + try + { + var reader = new Utf8JsonReader(span); + var message = (JsonRpcMessage?)JsonSerializer.Deserialize(ref reader, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))); + if (message != null) + { + await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + } + else if (lineForLogs is not null) + { + LogTransportMessageParseUnexpectedTypeSensitive(Name, lineForLogs); + } + } + catch (JsonException ex) + { + if (Logger.IsEnabled(LogLevel.Trace) && lineForLogs is not null) + { + LogTransportMessageParseFailedSensitive(Name, lineForLogs, ex); + } + else + { + LogTransportMessageParseFailed(Name, ex); + } + } + } + + private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken) + { + try + { + var message = (JsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))); + if (message != null) + { + await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + } + else + { + LogTransportMessageParseUnexpectedTypeSensitive(Name, line); + } + } + catch (JsonException ex) + { + if (Logger.IsEnabled(LogLevel.Trace)) + { + LogTransportMessageParseFailedSensitive(Name, line, ex); + } + else + { + LogTransportMessageParseFailed(Name, ex); + } + } + } + + protected virtual async ValueTask CleanupAsync(Exception? error = null, CancellationToken cancellationToken = default) + { + LogTransportShuttingDown(Name); + + await _shutdownCts.CancelAsync().ConfigureAwait(false); + + if (Interlocked.Exchange(ref _readTask, null) is Task readTask) + { + try + { + await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + } + catch (Exception ex) + { + LogTransportCleanupReadTaskFailed(Name, ex); + } + } + + SetDisconnected(error); + LogTransportShutDown(Name); + } +} diff --git a/src/ModelContextProtocol.Core/Internal/JsonTypeInfoHttpContent.cs b/src/ModelContextProtocol.Core/Internal/JsonTypeInfoHttpContent.cs new file mode 100644 index 000000000..bd0e14400 --- /dev/null +++ b/src/ModelContextProtocol.Core/Internal/JsonTypeInfoHttpContent.cs @@ -0,0 +1,45 @@ +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol.Internal; + +internal sealed class JsonTypeInfoHttpContent : HttpContent +{ + private readonly T _value; + private readonly JsonTypeInfo _typeInfo; + + public JsonTypeInfoHttpContent(T value, JsonTypeInfo typeInfo) + { + _value = value; + _typeInfo = typeInfo; + + // Match StringContent's default behavior (application/json; charset=utf-8). + Headers.ContentType = new MediaTypeHeaderValue("application/json") + { + CharSet = "utf-8", + }; + } + +#if NET + protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context, CancellationToken cancellationToken) => + JsonSerializer.SerializeAsync(stream, _value, _typeInfo, cancellationToken); + + protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context) => + SerializeToStreamAsync(stream, context, CancellationToken.None); +#else + // HttpContent.SerializeToStreamAsync does not provide a CancellationToken on non-NET TFMs. + // Cancellation can still abort the underlying HTTP request, but it won't interrupt serialization itself. + protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context) => + JsonSerializer.SerializeAsync(stream, _value, _typeInfo, CancellationToken.None); +#endif + + protected override bool TryComputeLength(out long length) + { + // Intentionally unknown length to avoid buffering the entire JSON payload just to compute Content-Length. + length = 0; + return false; + } +} diff --git a/src/ModelContextProtocol.Core/McpJsonUtilities.cs b/src/ModelContextProtocol.Core/McpJsonUtilities.cs index b3d98dd0e..58010bae4 100644 --- a/src/ModelContextProtocol.Core/McpJsonUtilities.cs +++ b/src/ModelContextProtocol.Core/McpJsonUtilities.cs @@ -28,7 +28,17 @@ public static partial class McpJsonUtilities /// /// /// - public static JsonSerializerOptions DefaultOptions { get; } = CreateDefaultOptions(); + public static JsonSerializerOptions DefaultOptions { get; } = CreateOptionsCore(); + + /// + /// Creates MCP serialization options. + /// + /// + /// When , deserializing a "type":"text" content block will materialize a + /// instead of a . + /// + public static JsonSerializerOptions CreateOptions(bool materializeUtf8TextContentBlocks = false) => + CreateOptionsCore(materializeUtf8TextContentBlocks); /// /// Creates default options to use for MCP-related serialization. @@ -36,7 +46,7 @@ public static partial class McpJsonUtilities /// The configured options. [UnconditionalSuppressMessage("ReflectionAnalysis", "IL3050:RequiresDynamicCode", Justification = "Converter is guarded by IsReflectionEnabledByDefault check.")] [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access", Justification = "Converter is guarded by IsReflectionEnabledByDefault check.")] - private static JsonSerializerOptions CreateDefaultOptions() + private static JsonSerializerOptions CreateOptionsCore(bool materializeUtf8TextContentBlocks = false) { // Copy the configuration from the source generated context. JsonSerializerOptions options = new(JsonContext.Default.Options); @@ -44,6 +54,12 @@ private static JsonSerializerOptions CreateDefaultOptions() // Chain with all supported types from MEAI. options.TypeInfoResolverChain.Add(AIJsonUtilities.DefaultOptions.TypeInfoResolver!); + // Override the per-type converter if requested. + if (materializeUtf8TextContentBlocks) + { + options.Converters.Insert(0, new ContentBlock.Converter(materializeUtf8TextContentBlocks: true)); + } + // Add a converter for user-defined enums, if reflection is enabled by default. if (JsonSerializer.IsReflectionEnabledByDefault) { diff --git a/src/ModelContextProtocol.Core/McpTextUtilities.cs b/src/ModelContextProtocol.Core/McpTextUtilities.cs new file mode 100644 index 000000000..4f9159543 --- /dev/null +++ b/src/ModelContextProtocol.Core/McpTextUtilities.cs @@ -0,0 +1,235 @@ +using System.Buffers.Text; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Text; +using System.Text.Json; + +namespace ModelContextProtocol.Core; + +/// +/// Provides helpers for working with UTF-8 data across all target frameworks. +/// +public static class McpTextUtilities +{ + /// + /// Decodes the provided UTF-8 bytes into a . + /// Uses a pointer-based overload on TFM netstandard2.0. + /// (The specific method differs by target framework.) + /// + public static string GetStringFromUtf8(ReadOnlySpan utf8Bytes) + { +#if NET + return Encoding.UTF8.GetString(utf8Bytes); +#else + if (utf8Bytes.IsEmpty) + { + return string.Empty; + } + + unsafe + { + fixed (byte* p = utf8Bytes) + { + return Encoding.UTF8.GetString(p, utf8Bytes.Length); + } + } +#endif + } + + /// + /// Encodes the provided binary data into a base64 string. + /// Uses a span-based overload on TFM net. + /// + public static string GetBase64String(ReadOnlyMemory data) + { +#if NET + return Convert.ToBase64String(data.Span); +#else + if (MemoryMarshal.TryGetArray(data, out ArraySegment segment)) + { + return Convert.ToBase64String(segment.Array!, segment.Offset, segment.Count); + } + + return Convert.ToBase64String(data.ToArray()); +#endif + } + + /// + /// Determines whether the provided UTF-8 bytes consist only of whitespace characters + /// commonly found in MCP transports (space, tab, carriage return). + /// + public static bool IsWhiteSpace(ReadOnlySpan utf8Bytes) + { + for (int i = 0; i < utf8Bytes.Length; i++) + { + byte b = utf8Bytes[i]; + if (b != (byte)' ' && b != (byte)'\t' && b != (byte)'\r') + { + return false; + } + } + + return true; + } + + internal static byte[] UnescapeJsonStringToUtf8(ReadOnlySpan escaped) + { + // Two-pass: first compute output length, then write, to avoid intermediate buffers/copies. + int outputLength = 0; + for (int i = 0; i < escaped.Length; i++) + { + byte b = escaped[i]; + if (b != (byte)'\\') + { + outputLength++; + continue; + } + + if (++i >= escaped.Length) + { + throw new JsonException(); + } + + switch (escaped[i]) + { + case (byte)'"': + case (byte)'\\': + case (byte)'/': + case (byte)'b': + case (byte)'f': + case (byte)'n': + case (byte)'r': + case (byte)'t': + outputLength++; + break; + + case (byte)'u': + outputLength += GetUtf8ByteCountForEscapedUnicode(escaped, ref i); + break; + + default: + throw new JsonException(); + } + } + + byte[] result = new byte[outputLength]; + int dst = 0; + + for (int i = 0; i < escaped.Length; i++) + { + byte b = escaped[i]; + if (b != (byte)'\\') + { + result[dst++] = b; + continue; + } + + if (++i >= escaped.Length) + { + throw new JsonException(); + } + + byte esc = escaped[i]; + switch (esc) + { + case (byte)'"': result[dst++] = (byte)'"'; break; + case (byte)'\\': result[dst++] = (byte)'\\'; break; + case (byte)'/': result[dst++] = (byte)'/'; break; + case (byte)'b': result[dst++] = 0x08; break; + case (byte)'f': result[dst++] = 0x0C; break; + case (byte)'n': result[dst++] = 0x0A; break; + case (byte)'r': result[dst++] = 0x0D; break; + case (byte)'t': result[dst++] = 0x09; break; + + case (byte)'u': + uint scalar = ReadEscapedUnicodeScalar(escaped, ref i); + WriteUtf8Scalar(scalar, result, ref dst); + break; + + default: + throw new JsonException(); + } + } + + Debug.Assert(dst == result.Length); + return result; + } + + internal static int GetUtf8ByteCountForEscapedUnicode(ReadOnlySpan escaped, ref int i) + { + uint scalar = ReadEscapedUnicodeScalar(escaped, ref i); + return scalar <= 0x7F ? 1 : + scalar <= 0x7FF ? 2 : + scalar <= 0xFFFF ? 3 : + 4; + } + + internal static uint ReadEscapedUnicodeScalar(ReadOnlySpan escaped, ref int i) + { + // i points at 'u'. + if (i + 4 >= escaped.Length) + { + throw new JsonException(); + } + + uint codeUnit = (uint)(FromHex(escaped[i + 1]) << 12 | + FromHex(escaped[i + 2]) << 8 | + FromHex(escaped[i + 3]) << 4 | + FromHex(escaped[i + 4])); + i += 4; + + // Surrogate pair: \uD800-\uDBFF followed by \uDC00-\uDFFF + if (codeUnit is >= 0xD800 and <= 0xDBFF) + { + int lookahead = i + 1; + if (lookahead + 5 < escaped.Length && escaped[lookahead] == (byte)'\\' && escaped[lookahead + 1] == (byte)'u') + { + uint low = (uint)(FromHex(escaped[lookahead + 2]) << 12 | + FromHex(escaped[lookahead + 3]) << 8 | + FromHex(escaped[lookahead + 4]) << 4 | + FromHex(escaped[lookahead + 5])); + + if (low is >= 0xDC00 and <= 0xDFFF) + { + i = lookahead + 5; + return 0x10000u + ((codeUnit - 0xD800u) << 10) + (low - 0xDC00u); + } + } + } + + return codeUnit; + } + + internal static int FromHex(byte b) + { + if ((uint)(b - '0') <= 9) return b - '0'; + if ((uint)((b | 0x20) - 'a') <= 5) return (b | 0x20) - 'a' + 10; + throw new JsonException(); + } + + internal static void WriteUtf8Scalar(uint scalar, byte[] destination, ref int dst) + { + if (scalar <= 0x7F) + { + destination[dst++] = (byte)scalar; + } + else if (scalar <= 0x7FF) + { + destination[dst++] = (byte)(0xC0 | (scalar >> 6)); + destination[dst++] = (byte)(0x80 | (scalar & 0x3F)); + } + else if (scalar <= 0xFFFF) + { + destination[dst++] = (byte)(0xE0 | (scalar >> 12)); + destination[dst++] = (byte)(0x80 | ((scalar >> 6) & 0x3F)); + destination[dst++] = (byte)(0x80 | (scalar & 0x3F)); + } + else + { + destination[dst++] = (byte)(0xF0 | (scalar >> 18)); + destination[dst++] = (byte)(0x80 | ((scalar >> 12) & 0x3F)); + destination[dst++] = (byte)(0x80 | ((scalar >> 6) & 0x3F)); + destination[dst++] = (byte)(0x80 | (scalar & 0x3F)); + } + } +} diff --git a/src/ModelContextProtocol.Core/Protocol/BlobResourceContents.cs b/src/ModelContextProtocol.Core/Protocol/BlobResourceContents.cs index 904e71b00..3eea610e9 100644 --- a/src/ModelContextProtocol.Core/Protocol/BlobResourceContents.cs +++ b/src/ModelContextProtocol.Core/Protocol/BlobResourceContents.cs @@ -1,5 +1,9 @@ +using System.Buffers.Text; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Text; using System.Text.Json.Serialization; +using ModelContextProtocol.Core; namespace ModelContextProtocol.Protocol; @@ -10,7 +14,8 @@ namespace ModelContextProtocol.Protocol; /// /// is used when binary data needs to be exchanged through /// the Model Context Protocol. The binary data is represented as a base64-encoded string -/// in the property. +/// in the and properties and as raw bytes in +/// the property. /// /// /// This class inherits from , which also has a sibling implementation @@ -24,20 +29,120 @@ namespace ModelContextProtocol.Protocol; [DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class BlobResourceContents : ResourceContents { + private ReadOnlyMemory _decodedData; + private ReadOnlyMemory _blobUtf8; + private string? _blob; + + /// Initializes a new instance of the class. + [SetsRequiredMembers] + public BlobResourceContents() + { + Blob = string.Empty; + Uri = string.Empty; + } + /// /// Gets or sets the base64-encoded string representing the binary data of the item. /// [JsonPropertyName("blob")] - public required string Blob { get; set; } + public required string Blob + { + get + => _blob ??= !_blobUtf8.IsEmpty + ? McpTextUtilities.GetStringFromUtf8(_blobUtf8.Span) + // encode _decodedData back to base64 if needed + : McpTextUtilities.GetBase64String(_decodedData); + set + { + _blob = value; + _blobUtf8 = Encoding.UTF8.GetBytes(value); + _decodedData = default; // Invalidate cache + } + } + + /// + /// Gets or sets the base64-encoded UTF-8 bytes representing the binary data of the item. + /// + /// + /// This is a zero-copy representation of the wire payload of this item. Setting this value will invalidate any cached value of . + /// + [JsonIgnore] + public ReadOnlyMemory BlobUtf8 + { + get => _blobUtf8.IsEmpty + ? _blob is null + ? _decodedData.IsEmpty + ? ReadOnlyMemory.Empty + : EncodeToUtf8(_decodedData) + : Encoding.UTF8.GetBytes(_blob) + : _blobUtf8; + set + { + _blob = null; + _blobUtf8 = value; + _decodedData = default; // Invalidate cache + } + } - [DebuggerBrowsable(DebuggerBrowsableState.Never)] - private string DebuggerDisplay + private ReadOnlyMemory EncodeToUtf8(ReadOnlyMemory decodedData) + { + int maxLength = Base64.GetMaxEncodedToUtf8Length(decodedData.Length); + byte[] buffer = new byte[maxLength]; + if (Base64.EncodeToUtf8(decodedData.Span, buffer, out _, out int bytesWritten) == System.Buffers.OperationStatus.Done) + { + return buffer.AsMemory(0, bytesWritten); + } + else + { + throw new FormatException("Failed to encode base64 data"); + } + } + + [JsonIgnore] + internal bool HasBlobUtf8 => !_blobUtf8.IsEmpty; + + internal ReadOnlySpan GetBlobUtf8Span() => _blobUtf8.Span; + + /// + /// Gets the decoded data represented by . + /// + /// + /// Accessing this member will decode the value in and cache the result. + /// Subsequent accesses return the cached value unless is modified. + /// + [JsonIgnore] + public ReadOnlyMemory DecodedData { get { - string lengthDisplay = DebuggerDisplayHelper.GetBase64LengthDisplay(Blob); - string mimeInfo = MimeType is not null ? $", MimeType = {MimeType}" : ""; - return $"Uri = \"{Uri}\"{mimeInfo}, Length = {lengthDisplay}"; + if (_decodedData.IsEmpty) + { + if (_blob is not null) + { + // Decode from string representation + _decodedData = Convert.FromBase64String(_blob); + return _decodedData; + } + // Decode directly from UTF-8 base64 bytes without string intermediate + int maxLength = Base64.GetMaxDecodedFromUtf8Length(BlobUtf8.Length); + byte[] buffer = new byte[maxLength]; + if (Base64.DecodeFromUtf8(BlobUtf8.Span, buffer, out _, out int bytesWritten) == System.Buffers.OperationStatus.Done) + { + _decodedData = bytesWritten == maxLength ? buffer : buffer.AsMemory(0, bytesWritten).ToArray(); + } + else + { + throw new FormatException("Invalid base64 data"); + } + } + + return _decodedData; + } + set + { + _blob = null; + _blobUtf8 = default; + _decodedData = value; } } } diff --git a/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs b/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs index 04c98bc84..5bc2e452b 100644 --- a/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs +++ b/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs @@ -1,3 +1,5 @@ +using System.Buffers; +using System.Buffers.Text; using System.ComponentModel; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; @@ -13,8 +15,8 @@ namespace ModelContextProtocol.Protocol; /// /// /// The class is a fundamental type in the MCP that can represent different forms of content -/// based on the property. Derived types like , , -/// and provide the type-specific content. +/// based on the property. Derived types like , , +/// , and provide the type-specific content. /// /// /// This class is used throughout the MCP for representing content in messages, tool responses, @@ -71,6 +73,16 @@ private protected ContentBlock() [EditorBrowsable(EditorBrowsableState.Never)] public class Converter : JsonConverter { + private readonly bool _materializeUtf8TextContentBlocks; + + /// Initializes a new instance of the class. + public Converter() + { + } + + internal Converter(bool materializeUtf8TextContentBlocks) => + _materializeUtf8TextContentBlocks = materializeUtf8TextContentBlocks; + /// public override ContentBlock? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { @@ -85,9 +97,9 @@ public class Converter : JsonConverter } string? type = null; - string? text = null; + ReadOnlyMemory? utf8Text = null; string? name = null; - string? data = null; + ReadOnlyMemory? dataUtf8 = null; string? mimeType = null; string? uri = null; string? description = null; @@ -120,7 +132,9 @@ public class Converter : JsonConverter break; case "text": - text = reader.GetString(); + // Always read the JSON string token into UTF-8 bytes directly (including unescaping) without + // allocating an intermediate UTF-16 string. The choice of materialized type happens later. + utf8Text = ReadUtf8StringValueAsBytes(ref reader); break; case "name": @@ -128,7 +142,7 @@ public class Converter : JsonConverter break; case "data": - data = reader.GetString(); + dataUtf8 = ReadUtf8StringValueAsBytes(ref reader); break; case "mimeType": @@ -204,20 +218,25 @@ public class Converter : JsonConverter ContentBlock block = type switch { - "text" => new TextContentBlock - { - Text = text ?? throw new JsonException("Text contents must be provided for 'text' type."), - }, + "text" => _materializeUtf8TextContentBlocks + ? new Utf8TextContentBlock + { + Utf8Text = utf8Text ?? throw new JsonException("Text contents must be provided for 'text' type."), + } + : new TextContentBlock + { + Utf8Text = utf8Text ?? throw new JsonException("Text contents must be provided for 'text' type."), + }, "image" => new ImageContentBlock { - Data = data ?? throw new JsonException("Image data must be provided for 'image' type."), + DataUtf8 = dataUtf8 ?? throw new JsonException("Image data must be provided for 'image' type."), MimeType = mimeType ?? throw new JsonException("MIME type must be provided for 'image' type."), }, "audio" => new AudioContentBlock { - Data = data ?? throw new JsonException("Audio data must be provided for 'audio' type."), + DataUtf8 = dataUtf8 ?? throw new JsonException("Audio data must be provided for 'audio' type."), MimeType = mimeType ?? throw new JsonException("MIME type must be provided for 'audio' type."), }, @@ -259,6 +278,24 @@ public class Converter : JsonConverter return block; } + internal static ReadOnlyMemory ReadUtf8StringValueAsBytes(ref Utf8JsonReader reader) + { + if (reader.TokenType != JsonTokenType.String) + { + throw new JsonException(); + } + + // If the JSON string contained no escape sequences, STJ exposes the UTF-8 bytes directly. + if (!reader.ValueIsEscaped) + { + return reader.HasValueSequence ? reader.ValueSequence.ToArray() : reader.ValueSpan.ToArray(); + } + + // The value is escaped (e.g. contains \uXXXX or \n); unescape into UTF-8 bytes. + ReadOnlySpan escaped = reader.HasValueSequence ? reader.ValueSequence.ToArray() : reader.ValueSpan; + return Core.McpTextUtilities.UnescapeJsonStringToUtf8(escaped); + } + /// public override void Write(Utf8JsonWriter writer, ContentBlock value, JsonSerializerOptions options) { @@ -274,17 +311,43 @@ public override void Write(Utf8JsonWriter writer, ContentBlock value, JsonSerial switch (value) { + case Utf8TextContentBlock utf8TextContent: + writer.WriteString("text", utf8TextContent.Utf8Text.Span); + break; + case TextContentBlock textContent: - writer.WriteString("text", textContent.Text); + // Prefer UTF-8 bytes to avoid materializing a UTF-16 string for serialization. + if (!textContent.Utf8Text.IsEmpty) + { + writer.WriteString("text", textContent.Utf8Text.Span); + } + else + { + writer.WriteString("text", textContent.Text); + } break; case ImageContentBlock imageContent: - writer.WriteString("data", imageContent.Data); + if (imageContent.HasDataUtf8) + { + writer.WriteString("data", imageContent.GetDataUtf8Span()); + } + else + { + writer.WriteString("data", imageContent.Data); + } writer.WriteString("mimeType", imageContent.MimeType); break; case AudioContentBlock audioContent: - writer.WriteString("data", audioContent.Data); + if (audioContent.HasDataUtf8) + { + writer.WriteString("data", audioContent.GetDataUtf8Span()); + } + else + { + writer.WriteString("data", audioContent.Data); + } writer.WriteString("mimeType", audioContent.MimeType); break; @@ -359,23 +422,118 @@ public override void Write(Utf8JsonWriter writer, ContentBlock value, JsonSerial [DebuggerDisplay("Text = \"{Text}\"")] public sealed class TextContentBlock : ContentBlock { + private string? _text; + private ReadOnlyMemory _utf8Text; + /// public override string Type => "text"; + /// + /// Gets or sets the UTF-8 encoded text content. + /// + /// + /// This enables avoiding intermediate UTF-16 string materialization when deserializing JSON. + /// Setting this value will invalidate any cached value of . + /// + [JsonIgnore] + public ReadOnlyMemory Utf8Text + { + get => _utf8Text; + set + { + _utf8Text = value; + _text = null; // Invalidate cache + } + } + /// /// Gets or sets the text content of the message. /// + /// + /// The getter lazily materializes and caches a UTF-16 string from . + /// The setter updates . + /// [JsonPropertyName("text")] - public required string Text { get; set; } + public string Text + { + get => _text ??= Core.McpTextUtilities.GetStringFromUtf8(_utf8Text.Span); + set + { + _text = value; + _utf8Text = string.IsNullOrEmpty(value) ? null : System.Text.Encoding.UTF8.GetBytes(value); + } + } /// public override string ToString() => Text ?? ""; } +/// +/// Represents text provided to or from an LLM in pre-encoded UTF-8 form. +/// +/// +/// This type exists to avoid materializing UTF-16 strings in hot paths when the text content is already +/// available as UTF-8 bytes (for example, JSON serialized tool results). +/// +[DebuggerDisplay("Utf8TextLength = {Utf8Text.Length}")] +public sealed class Utf8TextContentBlock : ContentBlock +{ + /// + [JsonPropertyName("type")] + public override string Type => "text"; + + /// Gets or sets the UTF-8 encoded text content. + [JsonIgnore] + public required ReadOnlyMemory Utf8Text { get; set; } + + /// Gets the UTF-16 string representation of . + [JsonPropertyName("text")] + public string Text + { + get + { + return Core.McpTextUtilities.GetStringFromUtf8(Utf8Text.Span); + } + } + + /// Converts a to a . + public static implicit operator TextContentBlock(Utf8TextContentBlock utf8) + { + Throw.IfNull(utf8); + + return new TextContentBlock + { + Text = utf8.Text, + Annotations = utf8.Annotations, + Meta = utf8.Meta, + }; + } + + /// Converts a to a . + public static implicit operator Utf8TextContentBlock(TextContentBlock text) + { + Throw.IfNull(text); + + return new Utf8TextContentBlock + { + Utf8Text = System.Text.Encoding.UTF8.GetBytes(text.Text), + Annotations = text.Annotations, + Meta = text.Meta, + }; + } + + /// + public override string ToString() => Text; +} + /// Represents an image provided to or from an LLM. [DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class ImageContentBlock : ContentBlock { + private ReadOnlyMemory _dataUtf8; + private ReadOnlyMemory _decodedData; + private string? _data; + /// public override string Type => "image"; @@ -383,7 +541,73 @@ public sealed class ImageContentBlock : ContentBlock /// Gets or sets the base64-encoded image data. /// [JsonPropertyName("data")] - public required string Data { get; set; } + public string Data + { + get => _data ??= !_dataUtf8.IsEmpty + ? Core.McpTextUtilities.GetStringFromUtf8(_dataUtf8.Span) + : string.Empty; + set + { + _data = value; + _dataUtf8 = System.Text.Encoding.UTF8.GetBytes(value); + _decodedData = default; // Invalidate cache + } + } + + /// + /// Gets or sets the base64-encoded UTF-8 bytes representing the value of . + /// + [JsonIgnore] + public ReadOnlyMemory DataUtf8 + { + get => _dataUtf8.IsEmpty + ? _data is null + ? ReadOnlyMemory.Empty + : System.Text.Encoding.UTF8.GetBytes(_data) + : _dataUtf8; + set + { + _data = null; + _dataUtf8 = value; + _decodedData = default; // Invalidate cache + } + } + + /// + /// Gets the decoded image data represented by . + /// + /// + /// Accessing this member will decode the value in and cache the result. + /// Subsequent accesses return the cached value unless or is modified. + /// + [JsonIgnore] + public ReadOnlyMemory DecodedData + { + get + { + if (_decodedData.IsEmpty) + { + if (_data is not null) + { + _decodedData = Convert.FromBase64String(_data); + return _decodedData; + } + + int maxLength = Base64.GetMaxDecodedFromUtf8Length(DataUtf8.Length); + byte[] buffer = new byte[maxLength]; + if (Base64.DecodeFromUtf8(DataUtf8.Span, buffer, out _, out int bytesWritten) == OperationStatus.Done) + { + _decodedData = bytesWritten == maxLength ? buffer : buffer.AsMemory(0, bytesWritten).ToArray(); + } + else + { + throw new FormatException("Invalid base64 data"); + } + } + + return _decodedData; + } + } /// /// Gets or sets the MIME type (or "media type") of the content, specifying the format of the data. @@ -394,6 +618,10 @@ public sealed class ImageContentBlock : ContentBlock [JsonPropertyName("mimeType")] public required string MimeType { get; set; } + internal bool HasDataUtf8 => !_dataUtf8.IsEmpty; + + internal ReadOnlySpan GetDataUtf8Span() => _dataUtf8.Span; + [DebuggerBrowsable(DebuggerBrowsableState.Never)] private string DebuggerDisplay => $"MimeType = {MimeType}, Length = {DebuggerDisplayHelper.GetBase64LengthDisplay(Data)}"; } @@ -402,6 +630,10 @@ public sealed class ImageContentBlock : ContentBlock [DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class AudioContentBlock : ContentBlock { + private ReadOnlyMemory _dataUtf8; + private ReadOnlyMemory _decodedData; + private string? _data; + /// public override string Type => "audio"; @@ -409,7 +641,73 @@ public sealed class AudioContentBlock : ContentBlock /// Gets or sets the base64-encoded audio data. /// [JsonPropertyName("data")] - public required string Data { get; set; } + public string Data + { + get => _data ??= !_dataUtf8.IsEmpty + ? Core.McpTextUtilities.GetStringFromUtf8(_dataUtf8.Span) + : string.Empty; + set + { + _data = value; + _dataUtf8 = System.Text.Encoding.UTF8.GetBytes(value); + _decodedData = default; // Invalidate cache + } + } + + /// + /// Gets or sets the base64-encoded UTF-8 bytes representing the value of . + /// + [JsonIgnore] + public ReadOnlyMemory DataUtf8 + { + get => _dataUtf8.IsEmpty + ? _data is null + ? ReadOnlyMemory.Empty + : System.Text.Encoding.UTF8.GetBytes(_data) + : _dataUtf8; + set + { + _data = null; + _dataUtf8 = value; + _decodedData = default; // Invalidate cache + } + } + + /// + /// Gets the decoded audio data represented by . + /// + /// + /// Accessing this member will decode the value in and cache the result. + /// Subsequent accesses return the cached value unless or is modified. + /// + [JsonIgnore] + public ReadOnlyMemory DecodedData + { + get + { + if (_decodedData.IsEmpty) + { + if (_data is not null) + { + _decodedData = Convert.FromBase64String(_data); + return _decodedData; + } + + int maxLength = Base64.GetMaxDecodedFromUtf8Length(DataUtf8.Length); + byte[] buffer = new byte[maxLength]; + if (Base64.DecodeFromUtf8(DataUtf8.Span, buffer, out _, out int bytesWritten) == OperationStatus.Done) + { + _decodedData = bytesWritten == maxLength ? buffer : buffer.AsMemory(0, bytesWritten).ToArray(); + } + else + { + throw new FormatException("Invalid base64 data"); + } + } + + return _decodedData; + } + } /// /// Gets or sets the MIME type (or "media type") of the content, specifying the format of the data. @@ -420,6 +718,10 @@ public sealed class AudioContentBlock : ContentBlock [JsonPropertyName("mimeType")] public required string MimeType { get; set; } + internal bool HasDataUtf8 => !_dataUtf8.IsEmpty; + + internal ReadOnlySpan GetDataUtf8Span() => _dataUtf8.Span; + [DebuggerBrowsable(DebuggerBrowsableState.Never)] private string DebuggerDisplay => $"MimeType = {MimeType}, Length = {DebuggerDisplayHelper.GetBase64LengthDisplay(Data)}"; } diff --git a/src/ModelContextProtocol.Core/Protocol/DebuggerDisplayHelper.cs b/src/ModelContextProtocol.Core/Protocol/DebuggerDisplayHelper.cs index 77f719309..8498b95ae 100644 --- a/src/ModelContextProtocol.Core/Protocol/DebuggerDisplayHelper.cs +++ b/src/ModelContextProtocol.Core/Protocol/DebuggerDisplayHelper.cs @@ -27,4 +27,30 @@ internal static string GetBase64LengthDisplay(string base64Data) return "invalid base64"; } + + /// + /// Gets the decoded length of base64 data (encoded as UTF-8 bytes) for debugger display. + /// + internal static string GetBase64LengthDisplay(ReadOnlySpan base64Utf8Data) + { +#if NET + if (System.Buffers.Text.Base64.IsValid(base64Utf8Data, out int decodedLength)) + { + return $"{decodedLength} bytes"; + } +#else + int len = base64Utf8Data.Length; + if (len != 0 && (len & 3) == 0) + { + int padding = 0; + if (base64Utf8Data[^1] == (byte)'=') padding++; + if (len > 1 && base64Utf8Data[^2] == (byte)'=') padding++; + + int decodedLength = (len / 4) * 3 - padding; + return $"{decodedLength} bytes"; + } +#endif + + return "invalid base64"; + } } diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs index a0220b09d..e944958de 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs @@ -78,36 +78,83 @@ public sealed class Converter : JsonConverter throw new JsonException("Expected StartObject token"); } - using var doc = JsonDocument.ParseValue(ref reader); - var root = doc.RootElement; + // We need to determine the concrete message type without round-tripping the payload + // through a UTF-16 string (e.g. JsonElement.GetRawText()). + var lookahead = reader; + + bool hasId = false; + bool hasMethod = false; + bool hasError = false; + bool hasResult = false; + bool foundJsonRpc = false; + + // Scan the top-level object using a copy of the reader. + while (lookahead.Read()) + { + if (lookahead.TokenType == JsonTokenType.EndObject) + { + break; + } + + if (lookahead.TokenType != JsonTokenType.PropertyName) + { + throw new JsonException("Expected PropertyName token"); + } + + bool isJsonRpc = lookahead.ValueTextEquals("jsonrpc"u8); + bool isId = lookahead.ValueTextEquals("id"u8); + bool isMethod = lookahead.ValueTextEquals("method"u8); + bool isError = lookahead.ValueTextEquals("error"u8); + bool isResult = lookahead.ValueTextEquals("result"u8); + + if (!lookahead.Read()) + { + throw new JsonException("Unexpected end of JSON"); + } + + if (isJsonRpc) + { + foundJsonRpc = lookahead.TokenType == JsonTokenType.String && lookahead.ValueTextEquals("2.0"u8); + } + else if (isId) + { + hasId = true; + } + else if (isMethod) + { + hasMethod = true; + } + else if (isError) + { + hasError = true; + } + else if (isResult) + { + hasResult = true; + } + + SkipValue(ref lookahead); + } // All JSON-RPC messages must have a jsonrpc property with value "2.0" - if (!root.TryGetProperty("jsonrpc", out var versionProperty) || - versionProperty.GetString() != "2.0") + if (!foundJsonRpc) { throw new JsonException("Invalid or missing jsonrpc version"); } - // Determine the message type based on the presence of id, method, and error properties - bool hasId = root.TryGetProperty("id", out _); - bool hasMethod = root.TryGetProperty("method", out _); - bool hasError = root.TryGetProperty("error", out _); - - var rawText = root.GetRawText(); - // Messages with an id but no method are responses if (hasId && !hasMethod) { // Messages with an error property are error responses if (hasError) { - return JsonSerializer.Deserialize(rawText, options.GetTypeInfo()); + return JsonSerializer.Deserialize(ref reader, options.GetTypeInfo()); } // Messages with a result property are success responses - if (root.TryGetProperty("result", out _)) + if (hasResult) { - return JsonSerializer.Deserialize(rawText, options.GetTypeInfo()); + return JsonSerializer.Deserialize(ref reader, options.GetTypeInfo()); } throw new JsonException("Response must have either result or error"); @@ -116,18 +163,38 @@ public sealed class Converter : JsonConverter // Messages with a method but no id are notifications if (hasMethod && !hasId) { - return JsonSerializer.Deserialize(rawText, options.GetTypeInfo()); + return JsonSerializer.Deserialize(ref reader, options.GetTypeInfo()); } // Messages with both method and id are requests if (hasMethod && hasId) { - return JsonSerializer.Deserialize(rawText, options.GetTypeInfo()); + return JsonSerializer.Deserialize(ref reader, options.GetTypeInfo()); } throw new JsonException("Invalid JSON-RPC message format"); } + private static void SkipValue(ref Utf8JsonReader reader) + { + if (reader.TokenType is JsonTokenType.StartObject or JsonTokenType.StartArray) + { + int depth = 0; + do + { + if (reader.TokenType is JsonTokenType.StartObject or JsonTokenType.StartArray) + { + depth++; + } + else if (reader.TokenType is JsonTokenType.EndObject or JsonTokenType.EndArray) + { + depth--; + } + } + while (depth > 0 && reader.Read()); + } + } + /// public override void Write(Utf8JsonWriter writer, JsonRpcMessage value, JsonSerializerOptions options) { diff --git a/src/ModelContextProtocol.Core/Protocol/ResourceContents.cs b/src/ModelContextProtocol.Core/Protocol/ResourceContents.cs index 9c295a1f8..cebac139d 100644 --- a/src/ModelContextProtocol.Core/Protocol/ResourceContents.cs +++ b/src/ModelContextProtocol.Core/Protocol/ResourceContents.cs @@ -78,7 +78,7 @@ public class Converter : JsonConverter string? uri = null; string? mimeType = null; - string? blob = null; + ReadOnlyMemory? blobUtf8 = null; string? text = null; JsonObject? meta = null; @@ -104,7 +104,7 @@ public class Converter : JsonConverter break; case "blob": - blob = reader.GetString(); + blobUtf8 = ContentBlock.Converter.ReadUtf8StringValueAsBytes(ref reader); break; case "text": @@ -121,13 +121,13 @@ public class Converter : JsonConverter } } - if (blob is not null) + if (blobUtf8 is not null) { return new BlobResourceContents { Uri = uri ?? string.Empty, MimeType = mimeType, - Blob = blob, + BlobUtf8 = blobUtf8.Value, Meta = meta, }; } @@ -162,7 +162,14 @@ public override void Write(Utf8JsonWriter writer, ResourceContents value, JsonSe Debug.Assert(value is BlobResourceContents or TextResourceContents); if (value is BlobResourceContents blobResource) { - writer.WriteString("blob", blobResource.Blob); + if (blobResource.HasBlobUtf8) + { + writer.WriteString("blob", blobResource.GetBlobUtf8Span()); + } + else + { + writer.WriteString("blob", blobResource.Blob); + } } else if (value is TextResourceContents textResource) { diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs index fcd855de9..6e5d5a048 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs @@ -391,7 +391,12 @@ public override async ValueTask ReadAsync( DataContent dc => new() { - Contents = [new BlobResourceContents { Uri = request.Params!.Uri, MimeType = dc.MediaType, Blob = dc.Base64Data.ToString() }], + Contents = [new BlobResourceContents + { + Uri = request.Params!.Uri, + MimeType = dc.MediaType, + DecodedData = dc.Data + }], }, string text => new() @@ -420,7 +425,7 @@ public override async ValueTask ReadAsync( { Uri = request.Params!.Uri, MimeType = dc.MediaType, - Blob = dc.Base64Data.ToString() + DecodedData = dc.Data }, _ => throw new InvalidOperationException($"Unsupported AIContent type '{ac.GetType()}' returned from resource function."), @@ -442,4 +447,5 @@ public override async ValueTask ReadAsync( _ => throw new InvalidOperationException($"Unsupported result type '{result.GetType()}' returned from resource function."), }; } + } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs index 571bc3c04..e38536d0e 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs @@ -282,9 +282,14 @@ public override async ValueTask InvokeAsync( CallToolResult callToolResponse => callToolResponse, - _ => new() + _ => structuredContent is not null ? new() { - Content = [new TextContentBlock { Text = JsonSerializer.Serialize(result, AIFunction.JsonSerializerOptions.GetTypeInfo(typeof(object))) }], + Content = [], + StructuredContent = structuredContent, + } : new() + { + // Avoid staging the JSON payload as a UTF-16 string. + Content = [new Utf8TextContentBlock { Utf8Text = JsonSerializer.SerializeToUtf8Bytes(result, AIFunction.JsonSerializerOptions.GetTypeInfo(typeof(object))) }], StructuredContent = structuredContent, }, }; diff --git a/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs b/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs index 7caabf686..43e59efe2 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs @@ -316,13 +316,25 @@ public async ValueTask> ElicitAsync( return new ElicitResult { Action = raw.Action, Content = default }; } - JsonObject obj = []; - foreach (var kvp in raw.Content) + using var stream = new MemoryStream(); + using (var writer = new Utf8JsonWriter(stream)) { - obj[kvp.Key] = JsonNode.Parse(kvp.Value.GetRawText()); + writer.WriteStartObject(); + foreach (var kvp in raw.Content) + { + writer.WritePropertyName(kvp.Key); + kvp.Value.WriteTo(writer); + } + writer.WriteEndObject(); + writer.Flush(); + } + + if (!stream.TryGetBuffer(out ArraySegment segment)) + { + throw new InvalidOperationException("Expected MemoryStream to expose its buffer."); } - T? typed = JsonSerializer.Deserialize(obj, serializerOptions.GetTypeInfo()); + T? typed = JsonSerializer.Deserialize(new ReadOnlySpan(segment.Array!, segment.Offset, (int)stream.Length), serializerOptions.GetTypeInfo()); return new ElicitResult { Action = raw.Action, Content = typed }; } diff --git a/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs index 7747d7f18..470de023a 100644 --- a/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs @@ -1,5 +1,6 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Core; using ModelContextProtocol.Protocol; using System.Text; using System.Text.Json; @@ -20,11 +21,13 @@ public class StreamServerTransport : TransportBase private readonly ILogger _logger; - private readonly TextReader _inputReader; + private readonly Stream _inputStream; private readonly Stream _outputStream; private readonly SemaphoreSlim _sendLock = new(1, 1); - private readonly CancellationTokenSource _shutdownCts = new(); + + // Intentionally not disposed; once this transport instance is collectable, CTS finalization will clean up. + private CancellationTokenSource _shutdownCts = new(); private readonly Task _readLoopCompleted; private int _disposed = 0; @@ -45,11 +48,7 @@ public StreamServerTransport(Stream inputStream, Stream outputStream, string? se _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; -#if NET - _inputReader = new StreamReader(inputStream, Encoding.UTF8); -#else - _inputReader = new CancellableStreamReader(inputStream, Encoding.UTF8); -#endif + _inputStream = inputStream; _outputStream = outputStream; SetConnected(); @@ -87,52 +86,104 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation private async Task ReadMessagesAsync() { - CancellationToken shutdownToken = _shutdownCts.Token; + //CancellationToken shutdownToken = _shutdownCts.Token; // the cts field is not read-only, will be defused Exception? error = null; try { LogTransportEnteringReadMessagesLoop(Name); - while (!shutdownToken.IsCancellationRequested) + byte[] buffer = System.Buffers.ArrayPool.Shared.Rent(16 * 1024); + try { - var line = await _inputReader.ReadLineAsync(shutdownToken).ConfigureAwait(false); - if (string.IsNullOrWhiteSpace(line)) + using var lineStream = new MemoryStream(); + + while (!_shutdownCts.Token.IsCancellationRequested) { - if (line is null) + int bytesRead = await _inputStream.ReadAsync(buffer, 0, buffer.Length, _shutdownCts.Token).ConfigureAwait(false); + if (bytesRead == 0) { LogTransportEndOfStream(Name); break; } - continue; - } - - LogTransportReceivedMessageSensitive(Name, line); - - try - { - if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) is JsonRpcMessage message) - { - await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false); - } - else + int offset = 0; + while (offset < bytesRead) { - LogTransportMessageParseUnexpectedTypeSensitive(Name, line); + int newlineIndex = Array.IndexOf(buffer, (byte)'\n', offset, bytesRead - offset); + if (newlineIndex < 0) + { + lineStream.Write(buffer, offset, bytesRead - offset); + break; + } + + int partLength = newlineIndex - offset; + if (partLength > 0) + { + lineStream.Write(buffer, offset, partLength); + } + + offset = newlineIndex + 1; + + if (!lineStream.TryGetBuffer(out ArraySegment segment)) + { + throw new InvalidOperationException("Expected MemoryStream to expose its buffer."); + } + + ReadOnlySpan lineBytes = new(segment.Array!, segment.Offset, (int)lineStream.Length); + + if (!lineBytes.IsEmpty && lineBytes[^1] == (byte)'\r') + { + lineBytes = lineBytes[..^1]; + } + + lineStream.SetLength(0); + + if (McpTextUtilities.IsWhiteSpace(lineBytes)) + { + continue; + } + + string? lineForLogs = null; + if (Logger.IsEnabled(LogLevel.Trace)) + { +lineForLogs = McpTextUtilities.GetStringFromUtf8(lineBytes); + } + if (lineForLogs is not null) + { + LogTransportReceivedMessageSensitive(Name, lineForLogs); + } + + try + { + var reader = new Utf8JsonReader(lineBytes); + if (JsonSerializer.Deserialize(ref reader, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) is JsonRpcMessage message) + { + await WriteMessageAsync(message, _shutdownCts.Token).ConfigureAwait(false); + } + else if (lineForLogs is not null) + { + LogTransportMessageParseUnexpectedTypeSensitive(Name, lineForLogs); + } + } + catch (JsonException ex) + { + if (Logger.IsEnabled(LogLevel.Trace) && lineForLogs is not null) + { + LogTransportMessageParseFailedSensitive(Name, lineForLogs, ex); + } + else + { + LogTransportMessageParseFailed(Name, ex); + } + + // Continue reading even if we fail to parse a message + } } } - catch (JsonException ex) - { - if (Logger.IsEnabled(LogLevel.Trace)) - { - LogTransportMessageParseFailedSensitive(Name, line, ex); - } - else - { - LogTransportMessageParseFailed(Name, ex); - } - - // Continue reading even if we fail to parse a message - } + } + finally + { + System.Buffers.ArrayPool.Shared.Return(buffer); } } catch (OperationCanceledException) @@ -164,11 +215,11 @@ public override async ValueTask DisposeAsync() // Signal to the stdin reading loop to stop. await _shutdownCts.CancelAsync().ConfigureAwait(false); - _shutdownCts.Dispose(); + CanceledTokenSource.Defuse(ref _shutdownCts, dispose: true); // Dispose of stdin/out. Cancellation may not be able to wake up operations // synchronously blocked in a syscall; we need to forcefully close the handle / file descriptor. - _inputReader?.Dispose(); + _inputStream?.Dispose(); _outputStream?.Dispose(); // Make sure the work has quiesced. diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index c99b1fa39..dfbae1d7f 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -1,3 +1,4 @@ +using ModelContextProtocol.Core; using ModelContextProtocol.Protocol; using System.IO.Pipelines; using System.Security.Claims; @@ -33,7 +34,7 @@ public sealed class StreamableHttpServerTransport : ITransport SingleReader = true, SingleWriter = false, }); - private readonly CancellationTokenSource _disposeCts = new(); + private CancellationTokenSource _disposeCts = new(); private int _getRequestStarted; @@ -157,7 +158,7 @@ public async ValueTask DisposeAsync() } finally { - _disposeCts.Dispose(); + CanceledTokenSource.Defuse(ref _disposeCts); } } } diff --git a/src/ModelContextProtocol.Core/UriTemplate.cs b/src/ModelContextProtocol.Core/UriTemplate.cs index 447ec004c..1cf21e791 100644 --- a/src/ModelContextProtocol.Core/UriTemplate.cs +++ b/src/ModelContextProtocol.Core/UriTemplate.cs @@ -388,6 +388,8 @@ private static void AppendJoin(ref DefaultInterpolatedStringHandler builder, str } } + public static ReadOnlySpan HexDigits => "0123456789ABCDEF"; + private static string Encode(string value, bool allowReserved) { if (!allowReserved) @@ -423,36 +425,60 @@ private static string Encode(string value, bool allowReserved) } else { - AppendHex(ref builder, c); + if (c <= 0x7F) + { + AppendHexAscii(ref builder, c); + } + else if (char.IsHighSurrogate(c) && i + 1 < value.Length && char.IsLowSurrogate(value[i + 1])) + { + AppendHexUtf8(ref builder, value.AsSpan(i, 2)); + i++; + } + else + { + AppendHexUtf8(ref builder, value.AsSpan(i, 1)); + } } } return builder.ToStringAndClear(); - static void AppendHex(ref DefaultInterpolatedStringHandler builder, char c) + static void AppendHexAscii(ref DefaultInterpolatedStringHandler builder, char c) { - ReadOnlySpan hexDigits = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F']; + builder.AppendFormatted('%'); + builder.AppendFormatted(HexDigits[c >> 4]); + builder.AppendFormatted(HexDigits[c & 0xF]); + } - if (c <= 0x7F) + static void AppendHexUtf8(ref DefaultInterpolatedStringHandler builder, ReadOnlySpan chars) + { + Span utf8 = stackalloc byte[4]; + +#if NET + int bytesWritten = Encoding.UTF8.GetBytes(chars, utf8); + for (int j = 0; j < bytesWritten; j++) { + byte b = utf8[j]; builder.AppendFormatted('%'); - builder.AppendFormatted(hexDigits[c >> 4]); - builder.AppendFormatted(hexDigits[c & 0xF]); + builder.AppendFormatted(HexDigits[b >> 4]); + builder.AppendFormatted(HexDigits[b & 0xF]); } - else - { -#if NET - Span utf8 = stackalloc byte[Encoding.UTF8.GetMaxByteCount(1)]; - foreach (byte b in utf8.Slice(0, new Rune(c).EncodeToUtf8(utf8))) #else - foreach (byte b in Encoding.UTF8.GetBytes([c])) -#endif + unsafe { + fixed (char* pChars = chars) + fixed (byte* pUtf8 = utf8) { - builder.AppendFormatted('%'); - builder.AppendFormatted(hexDigits[b >> 4]); - builder.AppendFormatted(hexDigits[b & 0xF]); + int bytesWritten = Encoding.UTF8.GetBytes(pChars, chars.Length, pUtf8, utf8.Length); + for (int j = 0; j < bytesWritten; j++) + { + byte b = utf8[j]; + builder.AppendFormatted('%'); + builder.AppendFormatted(HexDigits[b >> 4]); + builder.AppendFormatted(HexDigits[b & 0xF]); + } } } +#endif } } diff --git a/tests/Common/Utils/ProcessStartInfoUtilities.cs b/tests/Common/Utils/ProcessStartInfoUtilities.cs new file mode 100644 index 000000000..84ebcb1e7 --- /dev/null +++ b/tests/Common/Utils/ProcessStartInfoUtilities.cs @@ -0,0 +1,138 @@ +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace ModelContextProtocol.Tests.Utils; + +internal static class ProcessStartInfoUtilities +{ + private static bool IsWindows => +#if NET + OperatingSystem.IsWindows(); +#else + RuntimeInformation.IsOSPlatform(OSPlatform.Windows); +#endif + + public static ProcessStartInfo CreateOnPath( + string fileName, + string? arguments = null, + bool redirectStandardInput = false, + bool redirectStandardOutput = true, + bool redirectStandardError = true, + bool useShellExecute = false, + bool createNoWindow = true) + { + string resolved = FindOnPath(fileName) ?? throw new InvalidOperationException($"{fileName} was not found on PATH."); + + if (IsWindows && !useShellExecute && + (resolved.EndsWith(".cmd", StringComparison.OrdinalIgnoreCase) || resolved.EndsWith(".bat", StringComparison.OrdinalIgnoreCase))) + { + // Batch files require cmd.exe when UseShellExecute=false. + return new ProcessStartInfo + { + FileName = "cmd.exe", + Arguments = $"/d /s /c \"\"{resolved}\" {arguments ?? string.Empty}\"", + RedirectStandardInput = redirectStandardInput, + RedirectStandardOutput = redirectStandardOutput, + RedirectStandardError = redirectStandardError, + UseShellExecute = false, + CreateNoWindow = createNoWindow, + }; + } + + return new ProcessStartInfo + { + FileName = resolved, + Arguments = arguments ?? string.Empty, + RedirectStandardInput = redirectStandardInput, + RedirectStandardOutput = redirectStandardOutput, + RedirectStandardError = redirectStandardError, + UseShellExecute = useShellExecute, + CreateNoWindow = createNoWindow, + }; + } + + public static string? FindOnPath(string fileName) + { + if (Path.IsPathRooted(fileName)) + { + return File.Exists(fileName) ? fileName : null; + } + + string? path = Environment.GetEnvironmentVariable("PATH"); + if (string.IsNullOrEmpty(path)) + { + return null; + } + + string[] extensions; + if (IsWindows) + { + // Match cmd.exe resolution semantics by honoring PATHEXT. + string? pathext = Environment.GetEnvironmentVariable("PATHEXT"); + if (string.IsNullOrWhiteSpace(pathext)) + { + extensions = [".EXE", ".CMD", ".BAT"]; + } + else + { + string[] raw = pathext.Split(';'); + var list = new List(raw.Length); + foreach (string ext in raw) + { + string trimmed = ext.Trim(); + if (!string.IsNullOrEmpty(trimmed)) + { + list.Add(trimmed); + } + } + + extensions = list.ToArray(); + } + } + else + { + extensions = []; + } + + bool hasExtension = Path.HasExtension(fileName); + + foreach (string dir in path.Split(Path.PathSeparator)) + { + if (string.IsNullOrWhiteSpace(dir)) + { + continue; + } + + string trimmedDir = dir.Trim().Trim('"'); + + if (!IsWindows || hasExtension) + { + string fullPath = Path.Combine(trimmedDir, fileName); + if (File.Exists(fullPath)) + { + return fullPath; + } + + continue; + } + + foreach (string ext in extensions) + { + string fullPath = Path.Combine(trimmedDir, fileName + ext); + if (File.Exists(fullPath)) + { + return fullPath; + } + } + + // Also consider no-extension in case it exists directly on disk. + string noExtPath = Path.Combine(trimmedDir, fileName); + if (File.Exists(noExtPath)) + { + return noExtPath; + } + } + + return null; + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs index 8c0055fe8..bf8d485a1 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs @@ -117,15 +117,7 @@ private void StartConformanceServer() private async Task<(bool Success, string Output, string Error)> RunNpxConformanceTests() { - var startInfo = new ProcessStartInfo - { - FileName = "npx", - Arguments = $"-y @modelcontextprotocol/conformance server --url {_serverUrl}", - RedirectStandardOutput = true, - RedirectStandardError = true, - UseShellExecute = false, - CreateNoWindow = true - }; + var startInfo = CreateNpxStartInfo($"-y @modelcontextprotocol/conformance server --url {_serverUrl}"); var outputBuilder = new StringBuilder(); var errorBuilder = new StringBuilder(); @@ -167,17 +159,9 @@ private static bool IsNodeInstalled() { try { - var startInfo = new ProcessStartInfo - { - FileName = "npx", // Check specifically for npx because windows seems unable to find it - Arguments = "--version", - RedirectStandardOutput = true, - RedirectStandardError = true, - UseShellExecute = false, - CreateNoWindow = true - }; - - using var process = Process.Start(startInfo); + // Check specifically for npx because on Windows the npm install provides npx.cmd and + // CreateProcess won't resolve it from just "npx". + using var process = Process.Start(CreateNpxStartInfo("--version")); if (process == null) { return false; @@ -191,4 +175,7 @@ private static bool IsNodeInstalled() return false; } } + + private static ProcessStartInfo CreateNpxStartInfo(string npxArguments) => + ProcessStartInfoUtilities.CreateOnPath("npx", npxArguments); } diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 1321c5f62..b1d5e5f59 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -382,7 +382,7 @@ private static void ConfigureResources(McpServerOptions options) { try { - var startIndexAsString = Encoding.UTF8.GetString(Convert.FromBase64String(request.Params.Cursor)); + var startIndexAsString = Core.McpTextUtilities.GetStringFromUtf8(Convert.FromBase64String(request.Params.Cursor)); startIndex = Convert.ToInt32(startIndexAsString); } catch (Exception e) diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index a29c30587..e3e860b85 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -221,7 +221,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { try { - var startIndexAsString = Encoding.UTF8.GetString(Convert.FromBase64String(requestParams.Cursor)); + var startIndexAsString = Core.McpTextUtilities.GetStringFromUtf8(Convert.FromBase64String(requestParams.Cursor)); startIndex = Convert.ToInt32(startIndexAsString); } catch (Exception e) diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientToolTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientToolTests.cs index c1b2bcf25..f8cd62e0b 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientToolTests.cs @@ -1,9 +1,10 @@ +using System.Text; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Client; +using ModelContextProtocol.Core; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; -using System.Text; using System.Text.Json; using System.Text.Json.Nodes; @@ -11,6 +12,9 @@ namespace ModelContextProtocol.Tests.Client; public class McpClientToolTests : ClientServerTestBase { + private static string GetUtf8String(ReadOnlyMemory bytes) => + McpTextUtilities.GetStringFromUtf8(bytes.Span); + public McpClientToolTests(ITestOutputHelper outputHelper) : base(outputHelper) { @@ -23,6 +27,10 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer private class TestTools { + private static readonly string Base64FakeImageData = Convert.ToBase64String("fake-image-data"u8.ToArray()); + private static readonly string Base64FakeAudioData = Convert.ToBase64String("fake-audio-data"u8.ToArray()); + private static readonly string Base64ImageData = Convert.ToBase64String("image-data"u8.ToArray()); + // Tool that returns only text content [McpServerTool] public static TextContentBlock TextOnlyTool() => @@ -37,13 +45,13 @@ public static TextContentBlock TextOnlyTool() => [McpServerTool] public static ImageContentBlock ImageTool() => new() - { Data = Convert.ToBase64String(Encoding.UTF8.GetBytes("fake-image-data")), MimeType = "image/png" }; + { Data = Base64FakeImageData, MimeType = "image/png" }; // Tool that returns audio content as single ContentBlock [McpServerTool] public static AudioContentBlock AudioTool() => new() - { Data = Convert.ToBase64String(Encoding.UTF8.GetBytes("fake-audio-data")), MimeType = "audio/mp3" }; + { Data = Base64FakeAudioData, MimeType = "audio/mp3" }; // Tool that returns embedded resource [McpServerTool] @@ -56,15 +64,15 @@ public static EmbeddedResourceBlock EmbeddedResourceTool() => public static IEnumerable MixedContentTool() { yield return new TextContent("Description of the image"); - yield return new DataContent(Encoding.UTF8.GetBytes("fake-image-data"), "image/png"); + yield return new DataContent("fake-image-data"u8.ToArray(), "image/png"); } // Tool that returns multiple images using IEnumerable [McpServerTool] public static IEnumerable MultipleImagesTool() { - yield return new DataContent(Encoding.UTF8.GetBytes("image1"), "image/png"); - yield return new DataContent(Encoding.UTF8.GetBytes("image2"), "image/jpeg"); + yield return new DataContent("image1"u8.ToArray(), "image/png"); + yield return new DataContent("image2"u8.ToArray(), "image/jpeg"); } // Tool that returns audio + text using IEnumerable @@ -72,7 +80,7 @@ public static IEnumerable MultipleImagesTool() public static IEnumerable AudioWithTextTool() { yield return new TextContent("Audio transcription"); - yield return new DataContent(Encoding.UTF8.GetBytes("fake-audio"), "audio/wav"); + yield return new DataContent("fake-audio"u8.ToArray(), "audio/wav"); } // Tool that returns embedded resource + text using IEnumerable @@ -88,9 +96,9 @@ public static IEnumerable ResourceWithTextTool() public static IEnumerable AllContentTypesTool() { yield return new TextContent("Mixed content"); - yield return new DataContent(Encoding.UTF8.GetBytes("image"), "image/png"); - yield return new DataContent(Encoding.UTF8.GetBytes("audio"), "audio/mp3"); - yield return new DataContent(Encoding.UTF8.GetBytes("blob"), "application/octet-stream"); + yield return new DataContent("image"u8.ToArray(), "image/png"); + yield return new DataContent("audio"u8.ToArray(), "audio/mp3"); + yield return new DataContent("blob"u8.ToArray(), "application/octet-stream"); } // Tool that returns content that can't be converted to AIContent (ResourceLinkBlock) @@ -103,7 +111,7 @@ public static ResourceLinkBlock ResourceLinkTool() => [McpServerTool] public static IEnumerable MixedWithNonConvertibleTool() { - yield return new ImageContentBlock { Data = Convert.ToBase64String(Encoding.UTF8.GetBytes("image-data")), MimeType = "image/png" }; + yield return new ImageContentBlock { Data = Base64ImageData, MimeType = "image/png" }; yield return new ResourceLinkBlock { Uri = "file://linked.txt", Name = "linked.txt" }; } @@ -152,7 +160,7 @@ public static EmbeddedResourceBlock BinaryResourceTool() => Resource = new BlobResourceContents { Uri = "data://blob", - Blob = Convert.ToBase64String(Encoding.UTF8.GetBytes("binary-data")), + Blob = Convert.ToBase64String("binary-data"u8.ToArray()), MimeType = "application/octet-stream" } }; @@ -207,7 +215,7 @@ public async Task ImageTool_ReturnsSingleDataContent() var dataContent = Assert.IsType(result); Assert.Equal("image/png", dataContent.MediaType); - Assert.Equal("fake-image-data", Encoding.UTF8.GetString(dataContent.Data.ToArray())); + Assert.Equal("fake-image-data", GetUtf8String(dataContent.Data)); } [Fact] @@ -221,7 +229,7 @@ public async Task AudioTool_ReturnsSingleDataContent() var dataContent = Assert.IsType(result); Assert.Equal("audio/mp3", dataContent.MediaType); - Assert.Equal("fake-audio-data", Encoding.UTF8.GetString(dataContent.Data.ToArray())); + Assert.Equal("fake-audio-data", GetUtf8String(dataContent.Data)); } [Fact] @@ -270,11 +278,11 @@ public async Task MultipleImagesTool_ReturnsAIContentArray() var dataContent0 = Assert.IsType(aiContents[0]); Assert.Equal("image/png", dataContent0.MediaType); - Assert.Equal("image1", Encoding.UTF8.GetString(dataContent0.Data.ToArray())); + Assert.Equal("image1", GetUtf8String(dataContent0.Data)); var dataContent1 = Assert.IsType(aiContents[1]); Assert.Equal("image/jpeg", dataContent1.MediaType); - Assert.Equal("image2", Encoding.UTF8.GetString(dataContent1.Data.ToArray())); + Assert.Equal("image2", GetUtf8String(dataContent1.Data)); } [Fact] @@ -479,7 +487,7 @@ public async Task BinaryResourceTool_ReturnsSingleDataContent() var dataContent = Assert.IsType(result); Assert.Equal("application/octet-stream", dataContent.MediaType); - Assert.Equal("binary-data", Encoding.UTF8.GetString(dataContent.Data.ToArray())); + Assert.Equal("binary-data", GetUtf8String(dataContent.Data)); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 018e12dbe..2d4ce7620 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -262,7 +262,7 @@ public async Task ReadResource_Stdio_BinaryResource(string clientId) Assert.Single(result.Contents); BlobResourceContents blobResource = Assert.IsType(result.Contents[0]); - Assert.NotNull(blobResource.Blob); + Assert.False(string.IsNullOrEmpty(blobResource.Blob)); } // Not supported by "everything" server version on npx diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index 55a3b4932..0bd4b6d19 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -163,6 +163,7 @@ public class LoggingStream : Stream { private readonly Stream _innerStream; private readonly Action _logAction; + private readonly MemoryStream _pending = new(); public LoggingStream(Stream innerStream, Action logAction) { @@ -172,11 +173,90 @@ public LoggingStream(Stream innerStream, Action logAction) public override void Write(byte[] buffer, int offset, int count) { - var data = Encoding.UTF8.GetString(buffer, offset, count); - _logAction(data); + Log(buffer, offset, count); _innerStream.Write(buffer, offset, count); } + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + Log(buffer, offset, count); + return _innerStream.WriteAsync(buffer, offset, count, cancellationToken); + } + +#if NET + public override void Write(ReadOnlySpan buffer) + { + Log(buffer); + _innerStream.Write(buffer); + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + Log(buffer.Span); + return _innerStream.WriteAsync(buffer, cancellationToken); + } +#endif + + private void Log(byte[] buffer, int offset, int count) + { + int end = offset + count; + int segmentStart = offset; + + for (int i = offset; i < end; i++) + { + if (buffer[i] == (byte)'\n') + { + int segmentLen = i - segmentStart + 1; + if (segmentLen > 0) + { + _pending.Write(buffer, segmentStart, segmentLen); + } + + FlushPending(); + segmentStart = i + 1; + } + } + + if (segmentStart < end) + { + _pending.Write(buffer, segmentStart, end - segmentStart); + } + } + +#if NET + private void Log(ReadOnlySpan buffer) + { + while (!buffer.IsEmpty) + { + int newlineIndex = buffer.IndexOf((byte)'\n'); + if (newlineIndex < 0) + { + _pending.Write(buffer); + return; + } + + int len = newlineIndex + 1; + _pending.Write(buffer.Slice(0, len)); + FlushPending(); + buffer = buffer.Slice(len); + } + } +#endif + + private void FlushPending() + { + if (_pending.Length == 0) + { + return; + } + + // MemoryStream created by this test is expandable, so GetBuffer() is supported and avoids a ToArray() allocation. + byte[] buffer = _pending.GetBuffer(); + _logAction(Core.McpTextUtilities.GetStringFromUtf8(buffer.AsSpan(0, (int)_pending.Length))); + + _pending.SetLength(0); + } + public override bool CanRead => _innerStream.CanRead; public override bool CanSeek => _innerStream.CanSeek; public override bool CanWrite => _innerStream.CanWrite; diff --git a/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs b/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs index 7a019c896..88dcc6ef4 100644 --- a/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs +++ b/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs @@ -1,4 +1,5 @@ using System.Diagnostics; +using ModelContextProtocol.Tests.Utils; namespace ModelContextProtocol.Tests; @@ -18,15 +19,14 @@ public EverythingSseServerFixture(int port) public async Task StartAsync() { - var processStartInfo = new ProcessStartInfo - { - FileName = "docker", - Arguments = $"run -p {_port}:3001 --name {_containerName} --rm tzolov/mcp-everything-server:v1", - RedirectStandardInput = true, - RedirectStandardOutput = true, - RedirectStandardError = true, - UseShellExecute = false, - }; + var processStartInfo = ProcessStartInfoUtilities.CreateOnPath( + "docker", + $"run -p {_port}:3001 --name {_containerName} --rm tzolov/mcp-everything-server:v1", + redirectStandardInput: true, + redirectStandardOutput: true, + redirectStandardError: true, + useShellExecute: false, + createNoWindow: true); _ = Process.Start(processStartInfo) ?? throw new InvalidOperationException($"Could not start process for {processStartInfo.FileName} with '{processStartInfo.Arguments}'."); @@ -40,12 +40,13 @@ public async ValueTask DisposeAsync() { // Stop the container - var stopInfo = new ProcessStartInfo - { - FileName = "docker", - Arguments = $"stop {_containerName}", - UseShellExecute = false - }; + var stopInfo = ProcessStartInfoUtilities.CreateOnPath( + "docker", + $"stop {_containerName}", + redirectStandardOutput: false, + redirectStandardError: false, + useShellExecute: false, + createNoWindow: true); using var stopProcess = Process.Start(stopInfo) ?? throw new InvalidOperationException($"Could not stop process for {stopInfo.FileName} with '{stopInfo.Arguments}'."); @@ -63,13 +64,13 @@ private static bool CheckIsDockerAvailable() #if NET try { - ProcessStartInfo processStartInfo = new() - { - FileName = "docker", - // "docker info" returns a non-zero exit code if docker engine is not running. - Arguments = "info", - UseShellExecute = false, - }; + ProcessStartInfo processStartInfo = ProcessStartInfoUtilities.CreateOnPath( + "docker", + "info", + redirectStandardOutput: false, + redirectStandardError: false, + useShellExecute: false, + createNoWindow: true); using var process = Process.Start(processStartInfo); process?.WaitForExit(); diff --git a/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs index 0113b77f3..53530f063 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs @@ -1,3 +1,4 @@ +using ModelContextProtocol.Core; using ModelContextProtocol.Protocol; using System.Text.Json; @@ -5,6 +6,19 @@ namespace ModelContextProtocol.Tests.Protocol; public class ContentBlockTests { + private static string GetUtf8String(ReadOnlyMemory bytes) => + McpTextUtilities.GetStringFromUtf8(bytes.Span); + + private static JsonSerializerOptions GetOptions(bool materializeUtf8TextContentBlocks) => + materializeUtf8TextContentBlocks + ? McpJsonUtilities.CreateOptions(materializeUtf8TextContentBlocks: true) + : McpJsonUtilities.DefaultOptions; + + private static string AssertTextBlock(ContentBlock contentBlock, bool materializeUtf8TextContentBlocks) => + materializeUtf8TextContentBlocks + ? Assert.IsType(contentBlock).Text + : Assert.IsType(contentBlock).Text; + [Fact] public void ResourceLinkBlock_SerializationRoundTrip_PreservesAllProperties() { @@ -81,8 +95,10 @@ public void ResourceLinkBlock_DeserializationWithoutName_ThrowsJsonException() Assert.Contains("Name must be provided for 'resource_link' type", exception.Message); } - [Fact] - public void Deserialize_IgnoresUnknownArrayProperty() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void Deserialize_IgnoresUnknownArrayProperty(bool materializeUtf8TextContentBlocks) { // This is a regression test where a server returned an unexpected response with // `structuredContent` as an array nested inside a content block. This should be @@ -97,15 +113,17 @@ public void Deserialize_IgnoresUnknownArrayProperty() ] }"; - var contentBlock = JsonSerializer.Deserialize(responseJson, McpJsonUtilities.DefaultOptions); + var options = GetOptions(materializeUtf8TextContentBlocks); + var contentBlock = JsonSerializer.Deserialize(responseJson, options); Assert.NotNull(contentBlock); - var textBlock = Assert.IsType(contentBlock); - Assert.Contains("1234567890", textBlock.Text); + Assert.Contains("1234567890", AssertTextBlock(contentBlock, materializeUtf8TextContentBlocks)); } - [Fact] - public void Deserialize_IgnoresUnknownObjectProperties() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void Deserialize_IgnoresUnknownObjectProperties(bool materializeUtf8TextContentBlocks) { string responseJson = @"{ ""type"": ""text"", @@ -118,15 +136,17 @@ public void Deserialize_IgnoresUnknownObjectProperties() } }"; - var contentBlock = JsonSerializer.Deserialize(responseJson, McpJsonUtilities.DefaultOptions); + var options = GetOptions(materializeUtf8TextContentBlocks); + var contentBlock = JsonSerializer.Deserialize(responseJson, options); Assert.NotNull(contentBlock); - var textBlock = Assert.IsType(contentBlock); - Assert.Contains("Sample text", textBlock.Text); + Assert.Contains("Sample text", AssertTextBlock(contentBlock, materializeUtf8TextContentBlocks)); } - [Fact] - public void ToolResultContentBlock_WithError_SerializationRoundtrips() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void ToolResultContentBlock_WithError_SerializationRoundtrips(bool materializeUtf8TextContentBlocks) { ToolResultContentBlock toolResult = new() { @@ -135,19 +155,21 @@ public void ToolResultContentBlock_WithError_SerializationRoundtrips() IsError = true }; - var json = JsonSerializer.Serialize(toolResult, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + var options = GetOptions(materializeUtf8TextContentBlocks); + var json = JsonSerializer.Serialize(toolResult, options); + var deserialized = JsonSerializer.Deserialize(json, options); var result = Assert.IsType(deserialized); Assert.Equal("call_123", result.ToolUseId); Assert.True(result.IsError); Assert.Single(result.Content); - var textBlock = Assert.IsType(result.Content[0]); - Assert.Equal("Error: City not found", textBlock.Text); + Assert.Equal("Error: City not found", AssertTextBlock(result.Content[0], materializeUtf8TextContentBlocks)); } - [Fact] - public void ToolResultContentBlock_WithStructuredContent_SerializationRoundtrips() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void ToolResultContentBlock_WithStructuredContent_SerializationRoundtrips(bool materializeUtf8TextContentBlocks) { ToolResultContentBlock toolResult = new() { @@ -160,22 +182,24 @@ public void ToolResultContentBlock_WithStructuredContent_SerializationRoundtrips IsError = false }; - var json = JsonSerializer.Serialize(toolResult, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + var options = GetOptions(materializeUtf8TextContentBlocks); + var json = JsonSerializer.Serialize(toolResult, options); + var deserialized = JsonSerializer.Deserialize(json, options); var result = Assert.IsType(deserialized); Assert.Equal("call_123", result.ToolUseId); Assert.Single(result.Content); - var textBlock = Assert.IsType(result.Content[0]); - Assert.Equal("Result data", textBlock.Text); + Assert.Equal("Result data", AssertTextBlock(result.Content[0], materializeUtf8TextContentBlocks)); Assert.NotNull(result.StructuredContent); Assert.Equal(18, result.StructuredContent.Value.GetProperty("temperature").GetInt32()); Assert.Equal("cloudy", result.StructuredContent.Value.GetProperty("condition").GetString()); Assert.False(result.IsError); } - [Fact] - public void ToolResultContentBlock_SerializationRoundTrip() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void ToolResultContentBlock_SerializationRoundTrip(bool materializeUtf8TextContentBlocks) { ToolResultContentBlock toolResult = new() { @@ -189,14 +213,14 @@ public void ToolResultContentBlock_SerializationRoundTrip() IsError = false }; - var json = JsonSerializer.Serialize(toolResult, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + var options = GetOptions(materializeUtf8TextContentBlocks); + var json = JsonSerializer.Serialize(toolResult, options); + var deserialized = JsonSerializer.Deserialize(json, options); var result = Assert.IsType(deserialized); Assert.Equal("call_123", result.ToolUseId); Assert.Equal(2, result.Content.Count); - var textBlock = Assert.IsType(result.Content[0]); - Assert.Equal("Result data", textBlock.Text); + Assert.Equal("Result data", AssertTextBlock(result.Content[0], materializeUtf8TextContentBlocks)); var imageBlock = Assert.IsType(result.Content[1]); Assert.Equal("base64data", imageBlock.Data); Assert.Equal("image/png", imageBlock.MimeType); @@ -225,4 +249,52 @@ public void ToolUseContentBlock_SerializationRoundTrip() Assert.Equal("Paris", result.Input.GetProperty("city").GetString()); Assert.Equal("metric", result.Input.GetProperty("units").GetString()); } -} \ No newline at end of file + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void Utf8TextContentBlock_SerializesAsText_AndDeserializesAsTextContentBlock(bool materializeUtf8TextContentBlocks) + { + // Utf8TextContentBlock is an optimization for write paths; the wire format is still a normal "text" block. + ContentBlock original = new Utf8TextContentBlock + { + Utf8Text = "Sample text"u8.ToArray() + }; + + var options = GetOptions(materializeUtf8TextContentBlocks); + + string json = JsonSerializer.Serialize(original, options); + Assert.Contains("\"type\":\"text\"", json); + Assert.Contains("\"text\":\"Sample text\"", json); + + ContentBlock? deserialized = JsonSerializer.Deserialize(json, options); + Assert.NotNull(deserialized); + Assert.Equal("Sample text", AssertTextBlock(deserialized, materializeUtf8TextContentBlocks)); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void ToolResultContentBlock_WithUtf8TextContent_SerializationRoundtrips(bool materializeUtf8TextContentBlocks) + { + ToolResultContentBlock toolResult = new() + { + ToolUseId = "call_123", + Content = + [ + new Utf8TextContentBlock { Utf8Text = "Result data"u8.ToArray() } + ], + IsError = false + }; + + var options = GetOptions(materializeUtf8TextContentBlocks); + var json = JsonSerializer.Serialize(toolResult, options); + var deserialized = JsonSerializer.Deserialize(json, options); + + var result = Assert.IsType(deserialized); + Assert.Equal("call_123", result.ToolUseId); + Assert.Single(result.Content); + Assert.Equal("Result data", AssertTextBlock(result.Content[0], materializeUtf8TextContentBlocks)); + Assert.False(result.IsError); + } +} diff --git a/tests/ModelContextProtocol.Tests/Protocol/CreateMessageRequestParamsTests.cs b/tests/ModelContextProtocol.Tests/Protocol/CreateMessageRequestParamsTests.cs index f57faf1d8..6b1f82751 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/CreateMessageRequestParamsTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/CreateMessageRequestParamsTests.cs @@ -5,8 +5,10 @@ namespace ModelContextProtocol.Tests.Protocol; public class CreateMessageRequestParamsTests { - [Fact] - public void WithTools_SerializationRoundtrips() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void WithTools_SerializationRoundtrips(bool materializeUtf8TextContentBlocks) { CreateMessageRequestParams requestParams = new() { @@ -39,8 +41,9 @@ public void WithTools_SerializationRoundtrips() ToolChoice = new ToolChoice { Mode = "auto" } }; - var json = JsonSerializer.Serialize(requestParams, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + var options = TextMaterializationTestHelpers.GetOptions(materializeUtf8TextContentBlocks); + var json = JsonSerializer.Serialize(requestParams, options); + var deserialized = JsonSerializer.Deserialize(json, options); Assert.NotNull(deserialized); Assert.Equal(1000, deserialized.MaxTokens); @@ -48,8 +51,7 @@ public void WithTools_SerializationRoundtrips() Assert.Single(deserialized.Messages); Assert.Equal(Role.User, deserialized.Messages[0].Role); Assert.Single(deserialized.Messages[0].Content); - var textContent = Assert.IsType(deserialized.Messages[0].Content[0]); - Assert.Equal("What's the weather in Paris?", textContent.Text); + Assert.Equal("What's the weather in Paris?", TextMaterializationTestHelpers.GetText(deserialized.Messages[0].Content[0], materializeUtf8TextContentBlocks)); Assert.NotNull(deserialized.Tools); Assert.Single(deserialized.Tools); Assert.Equal("get_weather", deserialized.Tools[0].Name); @@ -63,8 +65,10 @@ public void WithTools_SerializationRoundtrips() Assert.Equal("auto", deserialized.ToolChoice.Mode); } - [Fact] - public void WithToolChoiceRequired_SerializationRoundtrips() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void WithToolChoiceRequired_SerializationRoundtrips(bool materializeUtf8TextContentBlocks) { CreateMessageRequestParams requestParams = new() { @@ -95,8 +99,9 @@ public void WithToolChoiceRequired_SerializationRoundtrips() ToolChoice = new ToolChoice { Mode = "required" } }; - var json = JsonSerializer.Serialize(requestParams, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + var options = TextMaterializationTestHelpers.GetOptions(materializeUtf8TextContentBlocks); + var json = JsonSerializer.Serialize(requestParams, options); + var deserialized = JsonSerializer.Deserialize(json, options); Assert.NotNull(deserialized); Assert.Equal(1000, deserialized.MaxTokens); @@ -104,8 +109,7 @@ public void WithToolChoiceRequired_SerializationRoundtrips() Assert.Single(deserialized.Messages); Assert.Equal(Role.User, deserialized.Messages[0].Role); Assert.Single(deserialized.Messages[0].Content); - var textContent = Assert.IsType(deserialized.Messages[0].Content[0]); - Assert.Equal("What's the weather?", textContent.Text); + Assert.Equal("What's the weather?", TextMaterializationTestHelpers.GetText(deserialized.Messages[0].Content[0], materializeUtf8TextContentBlocks)); Assert.NotNull(deserialized.Tools); Assert.Single(deserialized.Tools); Assert.Equal("get_weather", deserialized.Tools[0].Name); @@ -115,8 +119,10 @@ public void WithToolChoiceRequired_SerializationRoundtrips() Assert.Equal("required", deserialized.ToolChoice.Mode); } - [Fact] - public void WithToolChoiceNone_SerializationRoundtrips() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void WithToolChoiceNone_SerializationRoundtrips(bool materializeUtf8TextContentBlocks) { CreateMessageRequestParams requestParams = new() { @@ -147,8 +153,9 @@ public void WithToolChoiceNone_SerializationRoundtrips() ToolChoice = new ToolChoice { Mode = "none" } }; - var json = JsonSerializer.Serialize(requestParams, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + var options = TextMaterializationTestHelpers.GetOptions(materializeUtf8TextContentBlocks); + var json = JsonSerializer.Serialize(requestParams, options); + var deserialized = JsonSerializer.Deserialize(json, options); Assert.NotNull(deserialized); Assert.Equal(1000, deserialized.MaxTokens); @@ -156,8 +163,7 @@ public void WithToolChoiceNone_SerializationRoundtrips() Assert.Single(deserialized.Messages); Assert.Equal(Role.User, deserialized.Messages[0].Role); Assert.Single(deserialized.Messages[0].Content); - var textContent = Assert.IsType(deserialized.Messages[0].Content[0]); - Assert.Equal("What's the weather in Paris?", textContent.Text); + Assert.Equal("What's the weather in Paris?", TextMaterializationTestHelpers.GetText(deserialized.Messages[0].Content[0], materializeUtf8TextContentBlocks)); Assert.NotNull(deserialized.Tools); Assert.Single(deserialized.Tools); Assert.Equal("get_weather", deserialized.Tools[0].Name); diff --git a/tests/ModelContextProtocol.Tests/Protocol/CreateMessageResultTests.cs b/tests/ModelContextProtocol.Tests/Protocol/CreateMessageResultTests.cs index 67ab5f4f9..22ce71d88 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/CreateMessageResultTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/CreateMessageResultTests.cs @@ -6,8 +6,10 @@ namespace ModelContextProtocol.Tests.Protocol; public class CreateMessageResultTests { - [Fact] - public void CreateMessageResult_WithSingleContent_Serializes() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void CreateMessageResult_WithSingleContent_Serializes(bool materializeUtf8TextContentBlocks) { CreateMessageResult result = new() { @@ -17,12 +19,13 @@ public void CreateMessageResult_WithSingleContent_Serializes() StopReason = "endTurn" }; - var json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + var options = TextMaterializationTestHelpers.GetOptions(materializeUtf8TextContentBlocks); + var json = JsonSerializer.Serialize(result, options); + var deserialized = JsonSerializer.Deserialize(json, options); Assert.NotNull(deserialized); Assert.Single(deserialized.Content); - Assert.IsType(deserialized.Content[0]); + Assert.Equal("Hello", TextMaterializationTestHelpers.GetText(deserialized.Content[0], materializeUtf8TextContentBlocks)); } [Fact] @@ -60,8 +63,10 @@ public void CreateMessageResult_WithMultipleToolUses_Serializes() Assert.Equal("call_2", ((ToolUseContentBlock)deserialized.Content[1]).Id); } - [Fact] - public void CreateMessageResult_WithMixedContent_Serializes() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void CreateMessageResult_WithMixedContent_Serializes(bool materializeUtf8TextContentBlocks) { CreateMessageResult result = new() { @@ -80,12 +85,13 @@ public void CreateMessageResult_WithMixedContent_Serializes() StopReason = "toolUse" }; - var json = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + var options = TextMaterializationTestHelpers.GetOptions(materializeUtf8TextContentBlocks); + var json = JsonSerializer.Serialize(result, options); + var deserialized = JsonSerializer.Deserialize(json, options); Assert.NotNull(deserialized); Assert.Equal(2, deserialized.Content.Count); - Assert.IsType(deserialized.Content[0]); + Assert.Equal("Let me check that.", TextMaterializationTestHelpers.GetText(deserialized.Content[0], materializeUtf8TextContentBlocks)); Assert.IsType(deserialized.Content[1]); } diff --git a/tests/ModelContextProtocol.Tests/Protocol/ResourceContentsTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ResourceContentsTests.cs index 4f9890f7b..e507b8352 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ResourceContentsTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ResourceContentsTests.cs @@ -1,3 +1,4 @@ +using ModelContextProtocol.Core; using ModelContextProtocol.Protocol; using System.Text.Json; @@ -5,6 +6,9 @@ namespace ModelContextProtocol.Tests.Protocol; public static class ResourceContentsTests { + private static string GetUtf8String(ReadOnlyMemory bytes) => + McpTextUtilities.GetStringFromUtf8(bytes.Span); + [Fact] public static void TextResourceContents_UnknownArrayProperty_IsIgnored() { diff --git a/tests/ModelContextProtocol.Tests/Protocol/SamplingMessageTests.cs b/tests/ModelContextProtocol.Tests/Protocol/SamplingMessageTests.cs index 9765d1be3..1d7922adb 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/SamplingMessageTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/SamplingMessageTests.cs @@ -5,8 +5,10 @@ namespace ModelContextProtocol.Tests.Protocol; public class SamplingMessageTests { - [Fact] - public void WithToolResults_SerializationRoundtrips() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void WithToolResults_SerializationRoundtrips(bool materializeUtf8TextContentBlocks) { SamplingMessage message = new() { @@ -24,8 +26,9 @@ public void WithToolResults_SerializationRoundtrips() ] }; - var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + var options = TextMaterializationTestHelpers.GetOptions(materializeUtf8TextContentBlocks); + var json = JsonSerializer.Serialize(message, options); + var deserialized = JsonSerializer.Deserialize(json, options); Assert.NotNull(deserialized); Assert.Equal(Role.User, deserialized.Role); @@ -35,12 +38,13 @@ public void WithToolResults_SerializationRoundtrips() Assert.Equal("call_123", toolResult.ToolUseId); Assert.Single(toolResult.Content); - var textBlock = Assert.IsType(toolResult.Content[0]); - Assert.Equal("Weather in Paris: 18°C, partly cloudy", textBlock.Text); + Assert.Equal("Weather in Paris: 18°C, partly cloudy", TextMaterializationTestHelpers.GetText(toolResult.Content[0], materializeUtf8TextContentBlocks)); } - [Fact] - public void WithMultipleToolResults_SerializationRoundtrips() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void WithMultipleToolResults_SerializationRoundtrips(bool materializeUtf8TextContentBlocks) { SamplingMessage message = new() { @@ -60,8 +64,9 @@ public void WithMultipleToolResults_SerializationRoundtrips() ] }; - var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + var options = TextMaterializationTestHelpers.GetOptions(materializeUtf8TextContentBlocks); + var json = JsonSerializer.Serialize(message, options); + var deserialized = JsonSerializer.Deserialize(json, options); Assert.NotNull(deserialized); Assert.Equal(Role.User, deserialized.Role); @@ -70,18 +75,18 @@ public void WithMultipleToolResults_SerializationRoundtrips() var toolResult1 = Assert.IsType(deserialized.Content[0]); Assert.Equal("call_abc123", toolResult1.ToolUseId); Assert.Single(toolResult1.Content); - var textBlock1 = Assert.IsType(toolResult1.Content[0]); - Assert.Equal("Weather in Paris: 18°C, partly cloudy", textBlock1.Text); + Assert.Equal("Weather in Paris: 18°C, partly cloudy", TextMaterializationTestHelpers.GetText(toolResult1.Content[0], materializeUtf8TextContentBlocks)); var toolResult2 = Assert.IsType(deserialized.Content[1]); Assert.Equal("call_def456", toolResult2.ToolUseId); Assert.Single(toolResult2.Content); - var textBlock2 = Assert.IsType(toolResult2.Content[0]); - Assert.Equal("Weather in London: 15°C, rainy", textBlock2.Text); + Assert.Equal("Weather in London: 15°C, rainy", TextMaterializationTestHelpers.GetText(toolResult2.Content[0], materializeUtf8TextContentBlocks)); } - [Fact] - public void WithToolResultOnly_SerializationRoundtrips() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void WithToolResultOnly_SerializationRoundtrips(bool materializeUtf8TextContentBlocks) { SamplingMessage message = new() { @@ -96,8 +101,9 @@ public void WithToolResultOnly_SerializationRoundtrips() ] }; - var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions); - var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + var options = TextMaterializationTestHelpers.GetOptions(materializeUtf8TextContentBlocks); + var json = JsonSerializer.Serialize(message, options); + var deserialized = JsonSerializer.Deserialize(json, options); Assert.NotNull(deserialized); Assert.Equal(Role.User, deserialized.Role); @@ -105,7 +111,6 @@ public void WithToolResultOnly_SerializationRoundtrips() var toolResult = Assert.IsType(deserialized.Content[0]); Assert.Equal("call_123", toolResult.ToolUseId); Assert.Single(toolResult.Content); - var textBlock = Assert.IsType(toolResult.Content[0]); - Assert.Equal("Result", textBlock.Text); + Assert.Equal("Result", TextMaterializationTestHelpers.GetText(toolResult.Content[0], materializeUtf8TextContentBlocks)); } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Protocol/UnknownPropertiesTests.cs b/tests/ModelContextProtocol.Tests/Protocol/UnknownPropertiesTests.cs index fd3117de2..21e12e54c 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/UnknownPropertiesTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/UnknownPropertiesTests.cs @@ -1,3 +1,5 @@ +using System.Runtime.InteropServices; +using System.Text; using System.Text.Json; using ModelContextProtocol.Protocol; @@ -9,8 +11,10 @@ namespace ModelContextProtocol.Tests.Protocol; /// public class UnknownPropertiesTests { - [Fact] - public void ContentBlock_DeserializationWithUnknownProperty_SkipsProperty() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void ContentBlock_DeserializationWithUnknownProperty_SkipsProperty(bool materializeUtf8TextContentBlocks) { // Arrange - JSON with unknown "unknownField" property const string Json = """ @@ -22,16 +26,18 @@ public void ContentBlock_DeserializationWithUnknownProperty_SkipsProperty() """; // Act - var deserialized = JsonSerializer.Deserialize(Json, McpJsonUtilities.DefaultOptions); + var options = TextMaterializationTestHelpers.GetOptions(materializeUtf8TextContentBlocks); + var deserialized = JsonSerializer.Deserialize(Json, options); // Assert Assert.NotNull(deserialized); - var textBlock = Assert.IsType(deserialized); - Assert.Equal("Hello, world!", textBlock.Text); + Assert.Equal("Hello, world!", TextMaterializationTestHelpers.GetText(deserialized, materializeUtf8TextContentBlocks)); } - [Fact] - public void ContentBlock_DeserializationWithStructuredContentInContent_SkipsProperty() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void ContentBlock_DeserializationWithStructuredContentInContent_SkipsProperty(bool materializeUtf8TextContentBlocks) { // Arrange - This was the actual bug case: structuredContent incorrectly placed // inside a ContentBlock instead of at CallToolResult level @@ -46,12 +52,12 @@ public void ContentBlock_DeserializationWithStructuredContentInContent_SkipsProp """; // Act - Should not throw - var deserialized = JsonSerializer.Deserialize(Json, McpJsonUtilities.DefaultOptions); + var options = TextMaterializationTestHelpers.GetOptions(materializeUtf8TextContentBlocks); + var deserialized = JsonSerializer.Deserialize(Json, options); // Assert Assert.NotNull(deserialized); - var textBlock = Assert.IsType(deserialized); - Assert.Equal("Result text", textBlock.Text); + Assert.Equal("Result text", TextMaterializationTestHelpers.GetText(deserialized, materializeUtf8TextContentBlocks)); } [Fact] @@ -195,8 +201,10 @@ public void PrimitiveSchemaDefinition_DeserializationWithUnknownProperty_SkipsPr Assert.Equal("A test string", stringSchema.Description); } - [Fact] - public void CallToolResult_WithContentBlockContainingUnknownProperties_Succeeds() + [Theory] + [InlineData(false)] + [InlineData(true)] + public void CallToolResult_WithContentBlockContainingUnknownProperties_Succeeds(bool materializeUtf8TextContentBlocks) { // Arrange - Simulates the real-world bug scenario: a malformed response where // structuredContent was incorrectly nested inside content blocks @@ -216,13 +224,13 @@ public void CallToolResult_WithContentBlockContainingUnknownProperties_Succeeds( """; // Act - Should not throw an exception - var deserialized = JsonSerializer.Deserialize(Json, McpJsonUtilities.DefaultOptions); + var options = TextMaterializationTestHelpers.GetOptions(materializeUtf8TextContentBlocks); + var deserialized = JsonSerializer.Deserialize(Json, options); // Assert Assert.NotNull(deserialized); Assert.Single(deserialized.Content); - var textBlock = Assert.IsType(deserialized.Content[0]); - Assert.Equal("Tool executed successfully", textBlock.Text); + Assert.Equal("Tool executed successfully", TextMaterializationTestHelpers.GetText(deserialized.Content[0], materializeUtf8TextContentBlocks)); Assert.False(deserialized.IsError); } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index e1e1011b4..d92a28051 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -17,6 +17,13 @@ namespace ModelContextProtocol.Tests.Server; public partial class McpServerToolTests { + private static string GetText(ContentBlock content) => content switch + { + TextContentBlock text => text.Text, + Utf8TextContentBlock utf8 => utf8.Text, + _ => throw new XunitException($"Expected a text content block, got '{content.GetType()}'."), + }; + private static JsonRpcRequest CreateTestJsonRpcRequest() { return new JsonRpcRequest @@ -64,7 +71,7 @@ public async Task SupportsMcpServer() var result = await tool.InvokeAsync( new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); - Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); + Assert.Equal("42", GetText(result.Content[0])); } [Fact] @@ -93,7 +100,7 @@ public async Task SupportsCtorInjection() Assert.NotNull(result); Assert.NotNull(result.Content); Assert.Single(result.Content); - Assert.Equal("True True True True", Assert.IsType(result.Content[0]).Text); + Assert.Equal("True True True True", GetText(result.Content[0])); } private sealed class HasCtorWithSpecialParameters @@ -174,7 +181,7 @@ public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime var result = await tool.InvokeAsync( new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Services = services }, TestContext.Current.CancellationToken); - Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); + Assert.Equal("42", GetText(result.Content[0])); } [Fact] @@ -195,7 +202,7 @@ public async Task SupportsOptionalServiceFromDI() var result = await tool.InvokeAsync( new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); - Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); + Assert.Equal("42", GetText(result.Content[0])); } [Fact] @@ -210,7 +217,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() var result = await tool1.InvokeAsync( new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); - Assert.Equal("""{"disposals":1}""", (result.Content[0] as TextContentBlock)?.Text); + Assert.Equal("""{"disposals":1}""", GetText(result.Content[0])); } [Fact] @@ -225,7 +232,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() var result = await tool1.InvokeAsync( new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); - Assert.Equal("""{"asyncDisposals":1}""", (result.Content[0] as TextContentBlock)?.Text); + Assert.Equal("""{"asyncDisposals":1}""", GetText(result.Content[0])); } [Fact] @@ -244,7 +251,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable var result = await tool1.InvokeAsync( new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Services = services }, TestContext.Current.CancellationToken); - Assert.Equal("""{"asyncDisposals":1,"disposals":0}""", (result.Content[0] as TextContentBlock)?.Text); + Assert.Equal("""{"asyncDisposals":1,"disposals":0}""", GetText(result.Content[0])); } @@ -268,12 +275,12 @@ public async Task CanReturnCollectionOfAIContent() Assert.Equal(3, result.Content.Count); - Assert.Equal("text", (result.Content[0] as TextContentBlock)?.Text); + Assert.Equal("text", GetText(result.Content[0])); - Assert.Equal("1234", (result.Content[1] as ImageContentBlock)?.Data); + Assert.Equal("1234", Assert.IsType(result.Content[1]).Data); Assert.Equal("image/png", (result.Content[1] as ImageContentBlock)?.MimeType); - Assert.Equal("1234", (result.Content[2] as AudioContentBlock)?.Data); + Assert.Equal("1234", Assert.IsType(result.Content[2]).Data); Assert.Equal("audio/wav", (result.Content[2] as AudioContentBlock)?.MimeType); } @@ -367,7 +374,7 @@ public async Task CanReturnCollectionOfStrings() new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Single(result.Content); - Assert.Equal("""["42","43"]""", Assert.IsType(result.Content[0]).Text); + Assert.Equal("""["42","43"]""", GetText(result.Content[0])); } [Fact] @@ -430,7 +437,7 @@ public async Task CanReturnCallToolResult() Assert.Same(response, result); Assert.Equal(2, result.Content.Count); - Assert.Equal("text", Assert.IsType(result.Content[0]).Text); + Assert.Equal("text", GetText(result.Content[0])); Assert.Equal("1234", Assert.IsType(result.Content[1]).Data); } diff --git a/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs index d14c376c1..9a102ad78 100644 --- a/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs @@ -17,15 +17,14 @@ public async Task SigInt_DisposesTestServerWithHosting_Gracefully() { using var process = new Process { - StartInfo = new ProcessStartInfo - { - FileName = "dotnet", - Arguments = "TestServerWithHosting.dll", - RedirectStandardInput = true, - RedirectStandardOutput = true, - UseShellExecute = false, - CreateNoWindow = true, - } + StartInfo = ProcessStartInfoUtilities.CreateOnPath( + "dotnet", + "TestServerWithHosting.dll", + redirectStandardInput: true, + redirectStandardOutput: true, + redirectStandardError: true, + useShellExecute: false, + createNoWindow: true) }; process.Start(); diff --git a/tests/ModelContextProtocol.Tests/TextMaterializationTestHelpers.cs b/tests/ModelContextProtocol.Tests/TextMaterializationTestHelpers.cs new file mode 100644 index 000000000..91e10a383 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/TextMaterializationTestHelpers.cs @@ -0,0 +1,23 @@ +using System.Text.Json; +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Tests; + +internal static class TextMaterializationTestHelpers +{ + /// + /// Gets the JsonSerializerOptions for tests, depending on whether UTF-8 text content blocks are materialized. + /// + internal static JsonSerializerOptions GetOptions(bool materializeUtf8TextContentBlocks) => + materializeUtf8TextContentBlocks + ? McpJsonUtilities.CreateOptions(materializeUtf8TextContentBlocks: true) + : McpJsonUtilities.DefaultOptions; + + /// + /// Gets the text from a ContentBlock, depending on whether UTF-8 text content blocks are materialized. + /// + internal static string GetText(ContentBlock contentBlock, bool materializeUtf8TextContentBlocks) => + materializeUtf8TextContentBlocks + ? Assert.IsType(contentBlock).Text + : Assert.IsType(contentBlock).Text; +} diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index cbe44da15..dc32a0c2a 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -66,7 +66,7 @@ public async Task SendMessageAsync_Should_Send_Message() await transport.SendMessageAsync(message, TestContext.Current.CancellationToken); - var result = Encoding.UTF8.GetString(output.ToArray()).Trim(); + var result = Core.McpTextUtilities.GetStringFromUtf8(output.GetBuffer().AsSpan(0, (int)output.Length)).Trim(); var expected = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions); Assert.Equal(expected, result); @@ -153,7 +153,7 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() await transport.SendMessageAsync(chineseMessage, TestContext.Current.CancellationToken); // Verify Chinese characters preserved but encoded - var chineseResult = Encoding.UTF8.GetString(output.ToArray()).Trim(); + var chineseResult = Core.McpTextUtilities.GetStringFromUtf8(output.GetBuffer().AsSpan(0, (int)output.Length)).Trim(); var expectedChinese = JsonSerializer.Serialize(chineseMessage, McpJsonUtilities.DefaultOptions); Assert.Equal(expectedChinese, chineseResult); Assert.Contains(JsonSerializer.Serialize(chineseText, McpJsonUtilities.DefaultOptions), chineseResult); @@ -175,7 +175,7 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() await transport.SendMessageAsync(emojiMessage, TestContext.Current.CancellationToken); // Verify emoji preserved - might be as either direct characters or escape sequences - var emojiResult = Encoding.UTF8.GetString(output.ToArray()).Trim(); + var emojiResult = Core.McpTextUtilities.GetStringFromUtf8(output.GetBuffer().AsSpan(0, (int)output.Length)).Trim(); var expectedEmoji = JsonSerializer.Serialize(emojiMessage, McpJsonUtilities.DefaultOptions); Assert.Equal(expectedEmoji, emojiResult);