diff --git a/src/index.ts b/src/index.ts index d74e0c60..2ed22216 100644 --- a/src/index.ts +++ b/src/index.ts @@ -109,6 +109,8 @@ export { } from './lib/stream-transformers.js'; // Tool creation helpers export { tool } from './lib/tool.js'; +// Real-time tool event broadcasting +export { ToolEventBroadcaster } from './lib/tool-event-broadcaster.js'; export { hasApprovalRequiredTools, hasExecuteFunction, diff --git a/src/lib/model-result.ts b/src/lib/model-result.ts index 644cd7aa..51158da2 100644 --- a/src/lib/model-result.ts +++ b/src/lib/model-result.ts @@ -15,6 +15,7 @@ import type { TurnContext, UnsentToolResult, } from './tool-types.js'; +import { ToolEventBroadcaster } from './tool-event-broadcaster.js'; import { betaResponsesSend } from '../funcs/betaResponsesSend.js'; import { @@ -134,7 +135,11 @@ export class ModelResult { private initPromise: Promise | null = null; private toolExecutionPromise: Promise | null = null; private finalResponse: models.OpenResponsesNonStreamingResponse | null = null; - private preliminaryResults: Map = new Map(); + private toolEventBroadcaster: ToolEventBroadcaster<{ + type: 'preliminary_result'; + toolCallId: string; + result: InferToolEventsUnion; + }> | null = null; private allToolExecutionRounds: Array<{ round: number; toolCalls: ParsedToolCall[]; @@ -402,10 +407,16 @@ export class ModelResult { /** * Execute all tools in a single round * Returns the tool results for API submission + * @param broadcaster - Optional broadcaster for real-time preliminary result streaming */ private async executeToolRound( toolCalls: ParsedToolCall[], - turnContext: TurnContext + turnContext: TurnContext, + broadcaster?: ToolEventBroadcaster<{ + type: 'preliminary_result'; + toolCallId: string; + result: InferToolEventsUnion; + }> ): Promise { const toolResults: models.OpenResponsesFunctionCallOutput[] = []; @@ -413,12 +424,18 @@ export class ModelResult { const tool = this.options.tools?.find((t) => t.function.name === toolCall.name); if 
(!tool || !hasExecuteFunction(tool)) continue; - const result = await executeTool(tool, toolCall, turnContext); + // Create callback for real-time preliminary results + const onPreliminaryResult = broadcaster + ? (callId: string, result: unknown) => { + broadcaster.push({ + type: 'preliminary_result' as const, + toolCallId: callId, + result: result as InferToolEventsUnion, + }); + } + : undefined; - // Store preliminary results for streaming - if (result.preliminaryResults && result.preliminaryResults.length > 0) { - this.preliminaryResults.set(toolCall.id, result.preliminaryResults); - } + const result = await executeTool(tool, toolCall, turnContext, onPreliminaryResult); toolResults.push({ type: 'function_call_output' as const, @@ -914,6 +931,127 @@ export class ModelResult { return this.toolExecutionPromise; } + /** + * Execute tools with real-time broadcasting of preliminary results. + * This is used by streaming methods that want real-time tool events. + * Unlike executeToolsIfNeeded, this creates a new broadcaster and passes it through. 
+ */ + private async executeToolsWithBroadcast( + broadcaster: ToolEventBroadcaster<{ + type: 'preliminary_result'; + toolCallId: string; + result: InferToolEventsUnion; + }> + ): Promise { + try { + await this.initStream(); + + // If resuming from approval and still pending, don't continue + if (this.isResumingFromApproval && this.currentState?.status === 'awaiting_approval') { + return; + } + + // Get initial response + let currentResponse = await this.getInitialResponse(); + + // Save initial response to state + await this.saveResponseToState(currentResponse); + + // Check if tools should be executed + const hasToolCalls = currentResponse.output.some( + (item) => hasTypeProperty(item) && item.type === 'function_call' + ); + + if (!this.options.tools?.length || !hasToolCalls) { + this.finalResponse = currentResponse; + await this.markStateComplete(); + return; + } + + // Extract and check tool calls + const toolCalls = extractToolCallsFromResponse(currentResponse); + + // Check for approval requirements + if (await this.handleApprovalCheck(toolCalls, 0, currentResponse)) { + return; // Paused for approval + } + + if (!this.hasExecutableToolCalls(toolCalls)) { + this.finalResponse = currentResponse; + await this.markStateComplete(); + return; + } + + // Main execution loop + let currentRound = 0; + + while (true) { + // Check for external interruption + if (await this.checkForInterruption(currentResponse)) { + return; + } + + // Check stop conditions + if (await this.shouldStopExecution()) { + break; + } + + const currentToolCalls = extractToolCallsFromResponse(currentResponse); + if (currentToolCalls.length === 0) { + break; + } + + // Check for approval requirements + if (await this.handleApprovalCheck(currentToolCalls, currentRound + 1, currentResponse)) { + return; + } + + if (!this.hasExecutableToolCalls(currentToolCalls)) { + break; + } + + // Build turn context + const turnContext: TurnContext = { numberOfTurns: currentRound + 1 }; + + // Resolve async 
functions for this turn + await this.resolveAsyncFunctionsForTurn(turnContext); + + // Execute tools WITH broadcaster for real-time events + const toolResults = await this.executeToolRound(currentToolCalls, turnContext, broadcaster); + + // Track execution round + this.allToolExecutionRounds.push({ + round: currentRound, + toolCalls: currentToolCalls, + response: currentResponse, + toolResults, + }); + + // Save tool results to state + await this.saveToolResultsToState(toolResults); + + // Apply nextTurnParams + await this.applyNextTurnParams(currentToolCalls); + + // Make follow-up request + currentResponse = await this.makeFollowupRequest(currentResponse, toolResults); + + // Save new response to state + await this.saveResponseToState(currentResponse); + + currentRound++; + } + + // Validate and finalize + this.validateFinalResponse(currentResponse); + this.finalResponse = currentResponse; + await this.markStateComplete(); + } finally { + // Always complete the broadcaster when done + broadcaster.complete(); + } + } + /** * Internal helper to get the text after tool execution */ @@ -958,7 +1096,7 @@ export class ModelResult { /** * Stream all response events as they arrive. * Multiple consumers can iterate over this stream concurrently. - * Includes preliminary tool result events after tool execution. + * Preliminary tool results are streamed in REAL-TIME as generator tools yield. 
*/ getFullResponsesStream(): AsyncIterableIterator>> { return async function* (this: ModelResult) { @@ -967,27 +1105,32 @@ export class ModelResult { throw new Error('Stream not initialized'); } + // Create broadcaster for real-time tool events + this.toolEventBroadcaster = new ToolEventBroadcaster(); + const toolEventConsumer = this.toolEventBroadcaster.createConsumer(); + + // Start tool execution in background (doesn't block) + const executionPromise = this.executeToolsWithBroadcast(this.toolEventBroadcaster); + const consumer = this.reusableStream.createConsumer(); - // Yield original events directly + // Yield original API events for await (const event of consumer) { yield event; } - // After stream completes, check if tools were executed and emit preliminary results - await this.executeToolsIfNeeded(); - - // Emit all preliminary results as new event types - for (const [toolCallId, results] of this.preliminaryResults) { - for (const result of results) { - yield { - type: 'tool.preliminary_result' as const, - toolCallId, - result: result as InferToolEventsUnion, - timestamp: Date.now(), - }; - } + // Yield tool preliminary results as they arrive (real-time!) + for await (const event of toolEventConsumer) { + yield { + type: 'tool.preliminary_result' as const, + toolCallId: event.toolCallId, + result: event.result, + timestamp: Date.now(), + }; } + + // Ensure execution completed (handles errors) + await executionPromise; }.call(this); } @@ -1065,7 +1208,7 @@ export class ModelResult { /** * Stream tool call argument deltas and preliminary results. - * This filters the full event stream to yield: + * Preliminary results are streamed in REAL-TIME as generator tools yield. 
/**
 * A push-based event broadcaster that supports multiple concurrent consumers.
 *
 * Producers call {@link ToolEventBroadcaster.push} for each event and
 * {@link ToolEventBroadcaster.complete} once when finished. Every consumer
 * created with {@link ToolEventBroadcaster.createConsumer} keeps its own read
 * position and replays the buffer from index 0, so a consumer that joins late
 * still observes every event that was pushed before it joined.
 *
 * Events are retained in the internal buffer for the lifetime of the
 * broadcaster (the buffer is never trimmed).
 *
 * @template T - The event type being broadcast
 */
export class ToolEventBroadcaster<T> {
  /** Every event pushed so far, in push order. */
  private readonly buffer: T[] = [];
  /** Live (not finished, not cancelled) consumers keyed by their id. */
  private readonly consumers = new Map<number, ConsumerState>();
  private nextConsumerId = 0;
  private isComplete = false;
  private completionError: Error | null = null;

  /**
   * Append an event and wake every consumer that is waiting for data.
   * Calls made after {@link ToolEventBroadcaster.complete} are ignored.
   *
   * @param event - The event to broadcast
   */
  push(event: T): void {
    if (this.isComplete) {
      return;
    }
    this.buffer.push(event);
    this.notifyWaitingConsumers();
  }

  /**
   * Mark the broadcaster as complete - no more events will be accepted.
   *
   * Only the first call has any effect. This keeps an error passed to an
   * earlier `complete(error)` from being erased by a later bare `complete()`
   * (for example one issued from a `finally` block).
   *
   * @param error - Optional failure to surface to consumers once they have
   *   drained all buffered events
   */
  complete(error?: Error): void {
    if (this.isComplete) {
      return;
    }
    this.isComplete = true;
    this.completionError = error ?? null;
    this.notifyWaitingConsumers();
  }

  /**
   * Create a new consumer that independently iterates over all events.
   *
   * The consumer starts at position 0 regardless of when it is created.
   * After the buffered events are drained and the broadcaster is complete,
   * `next()` either reports `done` or, if completion carried an error, throws
   * that error once; any later `next()` reports `done`.
   *
   * `return()` and `throw()` cancel the consumer and also settle any `next()`
   * call that is still waiting for data, which then reports `done`.
   *
   * @returns An async iterator usable with `for await`
   */
  createConsumer(): AsyncIterableIterator<T> {
    const consumerId = this.nextConsumerId++;
    const state: ConsumerState = {
      position: 0,
      waiters: [],
      cancelled: false,
    };
    this.consumers.set(consumerId, state);

    // Unregister the consumer and release every next() call still parked on it.
    const detach = (): void => {
      state.cancelled = true;
      this.consumers.delete(consumerId);
      this.wake(state);
    };

    // Implemented as a closure with a loop (instead of a method that calls
    // `this.next()` recursively) so it works even when invoked detached from
    // the iterator object and never grows the promise chain.
    const next = async (): Promise<IteratorResult<T>> => {
      for (;;) {
        if (state.cancelled) {
          return { done: true, value: undefined };
        }

        // Deliver the next buffered event if this consumer is behind.
        if (state.position < this.buffer.length) {
          // The bounds check above guarantees the slot exists; a cast is used
          // rather than `!` because T itself may legitimately include undefined.
          const value = this.buffer[state.position] as T;
          state.position += 1;
          return { done: false, value };
        }

        // Caught up and nothing more will arrive: finish (or fail) exactly once.
        if (this.isComplete) {
          detach();
          if (this.completionError) {
            throw this.completionError;
          }
          return { done: true, value: undefined };
        }

        // Park until push(), complete(), or cancellation wakes us, then
        // re-evaluate all of the conditions above.
        await new Promise<void>((resolve) => {
          state.waiters.push(resolve);
        });
      }
    };

    const iterator: AsyncIterableIterator<T> = {
      next,

      async return(): Promise<IteratorResult<T>> {
        detach();
        return { done: true, value: undefined };
      },

      async throw(e?: unknown): Promise<IteratorResult<T>> {
        detach();
        throw e;
      },

      [Symbol.asyncIterator]() {
        return iterator;
      },
    };

    return iterator;
  }

  /**
   * Wake all consumers that are waiting so they re-check the buffer and the
   * completion state. Errors are not delivered here; a woken consumer throws
   * the completion error itself once it finds the buffer drained.
   */
  private notifyWaitingConsumers(): void {
    for (const consumer of this.consumers.values()) {
      this.wake(consumer);
    }
  }

  /** Resolve (in FIFO order) every pending wait registered on `state`. */
  private wake(state: ConsumerState): void {
    const pending = state.waiters.splice(0);
    for (const resolve of pending) {
      resolve();
    }
  }
}

/** Per-consumer bookkeeping held by {@link ToolEventBroadcaster}. */
interface ConsumerState {
  /** Index into the broadcaster buffer of the next event to deliver. */
  position: number;
  /** Resolvers for next() calls currently waiting for data, oldest first. */
  waiters: Array<() => void>;
  /** True once the consumer finished, failed, or was cancelled. */
  cancelled: boolean;
}
timestamps.push(Date.now()); + } + } + + // Should have received 3 preliminary events (not counting final output) + expect(receivedEvents.length).toBeGreaterThanOrEqual(3); + + // Verify the events contain expected progress data + expect(receivedEvents[0]).toHaveProperty('progress', 25); + expect(receivedEvents[1]).toHaveProperty('progress', 50); + expect(receivedEvents[2]).toHaveProperty('progress', 75); + + // Events should have arrived over time (not all at once) + // Since we have 50ms delays, there should be measurable time between events + if (timestamps.length > 1) { + const timeDiff = timestamps[timestamps.length - 1]! - timestamps[0]!; + expect(timeDiff).toBeGreaterThan(50); // At least 50ms between first and last + } + }, 30000); + + it('should stream preliminary results via getFullResponsesStream', async () => { + const receivedPreliminaryEvents: unknown[] = []; + + const progressTool = { + type: ToolType.Function, + function: { + name: 'streaming_progress', + description: 'Stream progress updates', + inputSchema: z.object({ input: z.string() }), + eventSchema: z.object({ step: z.number() }), + outputSchema: z.object({ done: z.boolean() }), + execute: async function* (_params: { input: string }) { + yield { step: 1 }; + yield { step: 2 }; + yield { step: 3 }; + yield { done: true }; + }, + }, + }; + + const response = await client.callModel({ + model: 'openai/gpt-4o', + input: 'Run streaming progress', + tools: [progressTool], + stopWhen: stepCountIs(2), + }); + + for await (const event of response.getFullResponsesStream()) { + if (event.type === 'tool.preliminary_result') { + receivedPreliminaryEvents.push(event); + } + } + + // Should have received preliminary events + expect(receivedPreliminaryEvents.length).toBeGreaterThanOrEqual(3); + + // Each event should have toolCallId and result + for (const event of receivedPreliminaryEvents) { + expect(event).toHaveProperty('toolCallId'); + expect(event).toHaveProperty('result'); + 
expect(event).toHaveProperty('timestamp'); + } + }, 30000); + }); + describe('Manual Tool Execution', () => { it('should define tool without execute function', () => { const manualTool = { diff --git a/tests/unit/tool-event-broadcaster.test.ts b/tests/unit/tool-event-broadcaster.test.ts new file mode 100644 index 00000000..13a34558 --- /dev/null +++ b/tests/unit/tool-event-broadcaster.test.ts @@ -0,0 +1,255 @@ +import { describe, expect, it } from 'vitest'; +import { ToolEventBroadcaster } from '../../src/lib/tool-event-broadcaster.js'; + +describe('ToolEventBroadcaster', () => { + describe('single consumer', () => { + it('should deliver events to a single consumer', async () => { + const broadcaster = new ToolEventBroadcaster(); + const consumer = broadcaster.createConsumer(); + + broadcaster.push(1); + broadcaster.push(2); + broadcaster.push(3); + broadcaster.complete(); + + const results: number[] = []; + for await (const event of consumer) { + results.push(event); + } + + expect(results).toEqual([1, 2, 3]); + }); + + it('should handle empty stream', async () => { + const broadcaster = new ToolEventBroadcaster(); + const consumer = broadcaster.createConsumer(); + + broadcaster.complete(); + + const results: string[] = []; + for await (const event of consumer) { + results.push(event); + } + + expect(results).toEqual([]); + }); + + it('should handle consumer cancellation via return()', async () => { + const broadcaster = new ToolEventBroadcaster(); + const consumer = broadcaster.createConsumer(); + + broadcaster.push(1); + broadcaster.push(2); + + // Get first value + const first = await consumer.next(); + expect(first.done).toBe(false); + expect(first.value).toBe(1); + + // Cancel consumer + await consumer.return!(); + + // Should be done now + const after = await consumer.next(); + expect(after.done).toBe(true); + }); + }); + + describe('multiple consumers', () => { + it('should deliver same events to multiple consumers', async () => { + const broadcaster = new 
ToolEventBroadcaster(); + const consumer1 = broadcaster.createConsumer(); + const consumer2 = broadcaster.createConsumer(); + + broadcaster.push('a'); + broadcaster.push('b'); + broadcaster.complete(); + + const results1: string[] = []; + const results2: string[] = []; + + await Promise.all([ + (async () => { + for await (const e of consumer1) results1.push(e); + })(), + (async () => { + for await (const e of consumer2) results2.push(e); + })(), + ]); + + expect(results1).toEqual(['a', 'b']); + expect(results2).toEqual(['a', 'b']); + }); + + it('should allow consumers at different read positions', async () => { + const broadcaster = new ToolEventBroadcaster(); + const consumer1 = broadcaster.createConsumer(); + + broadcaster.push(1); + broadcaster.push(2); + + // Consumer 1 reads first event + const first = await consumer1.next(); + expect(first.value).toBe(1); + + // Consumer 2 joins after events pushed + const consumer2 = broadcaster.createConsumer(); + + broadcaster.push(3); + broadcaster.complete(); + + // Consumer 1 continues from position 1 + const remaining1: number[] = []; + for await (const e of consumer1) remaining1.push(e); + expect(remaining1).toEqual([2, 3]); + + // Consumer 2 gets all events from position 0 + const all2: number[] = []; + for await (const e of consumer2) all2.push(e); + expect(all2).toEqual([1, 2, 3]); + }); + }); + + describe('async waiting', () => { + it('should wait for events when consumer is ahead of buffer', async () => { + const broadcaster = new ToolEventBroadcaster(); + const consumer = broadcaster.createConsumer(); + + // Start consuming before events arrive + const consumePromise = (async () => { + const results: number[] = []; + for await (const event of consumer) { + results.push(event); + } + return results; + })(); + + // Push events after consumer starts waiting + await new Promise((r) => setTimeout(r, 10)); + broadcaster.push(1); + await new Promise((r) => setTimeout(r, 10)); + broadcaster.push(2); + 
broadcaster.complete(); + + const results = await consumePromise; + expect(results).toEqual([1, 2]); + }); + + it('should handle rapid push/consume interleaving', async () => { + const broadcaster = new ToolEventBroadcaster(); + const consumer = broadcaster.createConsumer(); + + const received: number[] = []; + const consumePromise = (async () => { + for await (const event of consumer) { + received.push(event); + } + })(); + + // Push events with minimal delay + for (let i = 0; i < 10; i++) { + broadcaster.push(i); + await new Promise((r) => setTimeout(r, 1)); + } + broadcaster.complete(); + + await consumePromise; + expect(received).toEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + }); + }); + + describe('error handling', () => { + it('should propagate errors to consumers', async () => { + const broadcaster = new ToolEventBroadcaster(); + const consumer = broadcaster.createConsumer(); + + broadcaster.push(1); + broadcaster.complete(new Error('Test error')); + + const results: number[] = []; + let caughtError: Error | null = null; + + try { + for await (const event of consumer) { + results.push(event); + } + } catch (e) { + caughtError = e as Error; + } + + expect(results).toEqual([1]); + expect(caughtError).not.toBeNull(); + expect(caughtError!.message).toBe('Test error'); + }); + + it('should propagate errors to waiting consumers', async () => { + const broadcaster = new ToolEventBroadcaster(); + const consumer = broadcaster.createConsumer(); + + // Start consuming (will wait) + const consumePromise = (async () => { + const results: number[] = []; + for await (const event of consumer) { + results.push(event); + } + return results; + })(); + + // Complete with error while consumer is waiting + await new Promise((r) => setTimeout(r, 10)); + broadcaster.complete(new Error('Async error')); + + await expect(consumePromise).rejects.toThrow('Async error'); + }); + }); + + describe('ignore after complete', () => { + it('should ignore pushes after complete', async () => { + 
const broadcaster = new ToolEventBroadcaster(); + const consumer = broadcaster.createConsumer(); + + broadcaster.push(1); + broadcaster.complete(); + broadcaster.push(2); // Should be ignored + + const results: number[] = []; + for await (const event of consumer) { + results.push(event); + } + + expect(results).toEqual([1]); + }); + }); + + describe('typed events', () => { + it('should work with typed tool events', async () => { + type ToolEvent = + | { type: 'delta'; content: string } + | { type: 'preliminary_result'; toolCallId: string; result: unknown }; + + const broadcaster = new ToolEventBroadcaster(); + const consumer = broadcaster.createConsumer(); + + broadcaster.push({ type: 'delta', content: 'test' }); + broadcaster.push({ + type: 'preliminary_result', + toolCallId: 'call_123', + result: { progress: 50 }, + }); + broadcaster.complete(); + + const events: ToolEvent[] = []; + for await (const event of consumer) { + events.push(event); + } + + expect(events).toHaveLength(2); + expect(events[0]).toEqual({ type: 'delta', content: 'test' }); + expect(events[1]).toEqual({ + type: 'preliminary_result', + toolCallId: 'call_123', + result: { progress: 50 }, + }); + }); + }); +});