diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2f5d03a772..f68b349d9c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index cfd850b3a3..1bc4ee58c8 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 55d4b62e96..c228736847 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -16,11 +16,13 @@ import logging from typing import AsyncGenerator +from typing import TYPE_CHECKING from typing import Union from google.genai import types from ..utils.context_utils import Aclosing +from ..utils.transcription_utils import join_fragment from ..utils.variant_utils import GoogleLLMVariant from .base_llm_connection import BaseLlmConnection from .llm_response import LlmResponse @@ -28,7 +30,7 @@ logger = logging.getLogger('google_adk.' + __name__) RealtimeInput = Union[types.Blob, types.ActivityStart, types.ActivityEnd] -from typing import TYPE_CHECKING + if TYPE_CHECKING: from google.genai import live @@ -181,13 +183,16 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: # generation_complete, causing transcription to appear after # tool_call in the session log. if message.server_content.input_transcription: - if message.server_content.input_transcription.text: - self._input_transcription_text += ( - message.server_content.input_transcription.text + if ( + new_input_transcription_chunk := message.server_content.input_transcription.text + ): + existing = self._input_transcription_text + self._input_transcription_text = join_fragment( + existing, new_input_transcription_chunk ) yield LlmResponse( input_transcription=types.Transcription( - text=message.server_content.input_transcription.text, + text=new_input_transcription_chunk, finished=False, ), partial=True, @@ -204,13 +209,16 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: ) self._input_transcription_text = '' if message.server_content.output_transcription: - if message.server_content.output_transcription.text: - self._output_transcription_text += ( - message.server_content.output_transcription.text + if ( + new_output_transcription_chunk := message.server_content.output_transcription.text + ): + existing = self._output_transcription_text + self._output_transcription_text = join_fragment( + existing, new_output_transcription_chunk ) yield LlmResponse( output_transcription=types.Transcription( - text=message.server_content.output_transcription.text, + text=new_output_transcription_chunk, finished=False, ), partial=True, diff --git a/src/google/adk/utils/transcription_utils.py b/src/google/adk/utils/transcription_utils.py new file mode 100644 index 0000000000..f4a9257932 --- /dev/null +++ b/src/google/adk/utils/transcription_utils.py @@ -0,0 +1,66 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for transcription text handling.""" + +PUNCTUATION_CHARS = {'.', '!', '?', ';', ':', "'"} + + +def join_fragment(existing_tx: str, new_chunk: str) -> str: + """Join transcription fragments preserving proper spacing. + + Handles three special cases: + - Leading punctuation on the new chunk (attach without space). + - Leading apostrophe followed by 's' (contraction, attach without space). + - Leading apostrophe followed by other letters (plural possessive): + move apostrophe to end of existing and insert a space before remainder. + + Also avoids introducing double spaces when `existing` already ends + with whitespace. + """ + new_stripped = new_chunk.strip() + if not new_stripped: + return existing_tx + + # If the existing text ends with an apostrophe and the new fragment + # starts with 's' (continuation of a contraction), attach without + # a space: "That'" + "s great" -> "That's great". + if existing_tx.rstrip().endswith("'") and new_stripped[0].lower() == 's': + return existing_tx.rstrip() + new_stripped + + # Leading apostrophe handling when the new fragment itself starts + # with an apostrophe (e.g. "'s great" or "'job?"). + if new_stripped[0] == "'" and len(new_stripped) > 1: + remainder = new_stripped[1:] + # contraction like "'s": attach directly (That's) + if remainder[0].lower() == 's': + # If existing already ends with an apostrophe, don't add another. + if existing_tx.rstrip().endswith("'"): + return existing_tx.rstrip() + remainder + return existing_tx.rstrip() + "'" + remainder + # possessive like "'job": attach apostrophe to previous word, + # then add a space before the remainder (parents' job) + base = existing_tx.rstrip() + if not base.endswith("'"): + base = base + "'" + return base + ' ' + remainder + + # Leading punctuation attaches without a space (e.g. ? , !) + if new_stripped[0] in PUNCTUATION_CHARS: + return existing_tx.rstrip() + new_stripped + + # Default: add a space if existing doesn't already end with whitespace + if existing_tx and not existing_tx.endswith((' ', '\t', '\n')): + return existing_tx + ' ' + new_stripped + return existing_tx + new_stripped diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 190007603c..66b78aa4e6 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -593,3 +593,80 @@ async def mock_receive_generator(): assert responses[2].output_transcription.text == 'How can I help?' assert responses[2].output_transcription.finished is True assert responses[2].partial is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize('tx_direction', ['input', 'output']) +@pytest.mark.parametrize( + 'fragments', + [ + ('That', "'s great", "That's great"), + ("That'", 's great', "That's great"), + ("That's", 'great', "That's great"), + ("That's", ' great', "That's great"), + ("That's ", 'great', "That's great"), + ('Great', '! Good to hear', 'Great! Good to hear'), + ('Great!', 'Good to hear', 'Great! Good to hear'), + ('Great! ', 'Good to hear', 'Great! Good to hear'), + ('Great! Good', 'to hear', 'Great! Good to hear'), + ('Great! Good ', 'to hear', 'Great! Good to hear'), + ('Great! Good', ' to hear', 'Great! Good to hear'), + ("Is that parents' job", '?', "Is that parents' job?"), + ("Is that parents' ", 'job?', "Is that parents' job?"), + ("Is that parents'", 'job?', "Is that parents' job?"), + ('Is that parents', "'job?", "Is that parents' job?"), + ('Is that parents', " 'job?", "Is that parents' job?"), + ], +) +async def test_receive_final_transcription_space_between_fragments( + gemini_connection, mock_gemini_session, tx_direction, fragments +): + """Test receive final transcription fragments are joined with a space between words.""" + fragment1, fragment2, expected = fragments + + def _create_mock_transcription_message( + text: str | None, finished: bool, direction: str + ) -> mock.Mock: + msg = mock.Mock() + msg.usage_metadata = None + msg.server_content = mock.Mock() + msg.server_content.model_turn = None + msg.server_content.interrupted = False + msg.server_content.turn_complete = False + msg.server_content.generation_complete = False + msg.tool_call = None + msg.session_resumption_update = None + + transcription = types.Transcription(text=text, finished=finished) + if direction == 'input': + msg.server_content.input_transcription = transcription + msg.server_content.output_transcription = None + else: + msg.server_content.input_transcription = None + msg.server_content.output_transcription = transcription + return msg + + message1 = _create_mock_transcription_message(fragment1, False, tx_direction) + message2 = _create_mock_transcription_message(fragment2, False, tx_direction) + message3 = _create_mock_transcription_message(None, True, tx_direction) + + async def mock_receive_generator(): + yield message1 + yield message2 + yield message3 + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_connection.receive()] + + # find the finished transcription response + attr_name = f'{tx_direction}_transcription' + finished_resps = [ + r + for r in responses + if getattr(r, attr_name) and getattr(r, attr_name).finished + ] + assert finished_resps, 'Expected finished transcription response' + transcription = getattr(finished_resps[0], attr_name) + assert transcription.text == expected