diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 8d96ba0bc..0c6fc67e5 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -43,6 +43,7 @@ import { CompleteResultSchema, CreateMessageRequestSchema, CreateMessageResultSchema, + CreateMessageResultWithToolsSchema, CreateTaskResultSchema, ElicitRequestSchema, ElicitResultSchema, @@ -458,8 +459,10 @@ export class Client< return taskValidationResult.data; } - // For non-task requests, validate against CreateMessageResultSchema - const validationResult = safeParse(CreateMessageResultSchema, result); + // For non-task requests, validate against appropriate schema based on tools presence + const hasTools = params.tools || params.toolChoice; + const resultSchema = hasTools ? CreateMessageResultWithToolsSchema : CreateMessageResultSchema; + const validationResult = safeParse(resultSchema, result); if (!validationResult.success) { const errorMessage = validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); diff --git a/test/integration/test/client/client.test.ts b/test/integration/test/client/client.test.ts index 5574a2d84..66f38e4cd 100644 --- a/test/integration/test/client/client.test.ts +++ b/test/integration/test/client/client.test.ts @@ -4132,3 +4132,129 @@ describe('getSupportedElicitationModes', () => { expect(result.supportsUrlMode).toBe(false); }); }); + +describe('Client sampling validation with tools', () => { + test('should validate array content with tool_use when request includes tools', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } }); + + // Handler returns array content with tool_use - should validate with CreateMessageResultWithToolsSchema + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + stopReason: 'toolUse', + content: [{ type: 'tool_use', id: 'call_1', name: 'test_tool', input: { arg: 'value' } }] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const result = await server.createMessage({ + messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }], + maxTokens: 100, + tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }] + }); + + expect(result.stopReason).toBe('toolUse'); + expect(Array.isArray(result.content)).toBe(true); + expect((result.content as Array<{ type: string }>)[0].type).toBe('tool_use'); + }); + + test('should validate single content when request includes tools', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } }); + + // Handler returns single content (text) - should still validate with CreateMessageResultWithToolsSchema + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'No tool needed' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const result = await server.createMessage({ + messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }], + maxTokens: 100, + tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }] + }); + + expect((result.content as { type: string }).type).toBe('text'); + }); + + test('should validate single content when request has no tools', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); + + // Handler returns single content - should validate with CreateMessageResultSchema + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const result = await server.createMessage({ + messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }], + maxTokens: 100 + }); + + expect((result.content as { type: string }).type).toBe('text'); + }); + + test('should reject array content when request has no tools', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); + + // Handler returns array content - should fail validation with CreateMessageResultSchema + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: [{ type: 'text', text: 'Array response' }] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + await expect( + server.createMessage({ + messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }], + maxTokens: 100 + }) + ).rejects.toThrow('Invalid sampling result'); + }); + + test('should validate array content when request includes toolChoice', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } }); + + // Handler returns array content with tool_use + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + stopReason: 'toolUse', + content: [{ type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} }] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const result = await server.createMessage({ + messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }], + maxTokens: 100, + tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }], + toolChoice: { mode: 'auto' } + }); + + expect(result.stopReason).toBe('toolUse'); + expect(Array.isArray(result.content)).toBe(true); + }); +});