From 4cee25d85339622e5e837a220806c3c1ec6a6a06 Mon Sep 17 00:00:00 2001
From: Ayman Hamed
Date: Sun, 16 Nov 2025 12:39:41 +0200
Subject: [PATCH 1/9] Add native Ollama LLM support

---
 .../hello_world_ollama_native/README.md       |  99 ++++++
 .../hello_world_ollama_native/__init__.py     |  15 +
 .../hello_world_ollama_native/agent.py        |  89 +++++
 .../samples/hello_world_ollama_native/main.py |  77 +++++
 src/google/adk/models/ollama_llm.py           | 279 ++++++++++++++++
 tests/unittests/models/test_ollama.py         | 314 ++++++++++++++++++
 6 files changed, 873 insertions(+)
 create mode 100644 contributing/samples/hello_world_ollama_native/README.md
 create mode 100755 contributing/samples/hello_world_ollama_native/__init__.py
 create mode 100755 contributing/samples/hello_world_ollama_native/agent.py
 create mode 100755 contributing/samples/hello_world_ollama_native/main.py
 create mode 100644 src/google/adk/models/ollama_llm.py
 create mode 100644 tests/unittests/models/test_ollama.py

diff --git a/contributing/samples/hello_world_ollama_native/README.md b/contributing/samples/hello_world_ollama_native/README.md
new file mode 100644
index 0000000000..dfe6bc4716
--- /dev/null
+++ b/contributing/samples/hello_world_ollama_native/README.md
@@ -0,0 +1,99 @@
+# Using Ollama Models with ADK (Native Integration)
+
+## Model Choice
+
+If your agent uses tools, choose an Ollama model that supports **function calling**.
+Tool support can be verified with:
+
+```bash
+ollama show mistral-small3.1
+```
+    Model
+      architecture        mistral3
+      parameters          24.0B
+      context length      131072
+      embedding length    5120
+      quantization        Q4_K_M
+
+    Capabilities
+      completion
+      vision
+      tools
+
+Models must list `tools` under Capabilities.
+Models without tool support will not execute ADK functions correctly.
+
+To inspect or customize a model template:
+```bash
+ollama show --modelfile llama3.1 > model_file_to_modify
+```
+Then create a modified model:
+
+    ollama create llama3.1-modified -f model_file_to_modify
+
+
+## Native Ollama Provider in ADK
+
+ADK includes a native Ollama model class that communicates directly with the Ollama server at:
+
+    http://localhost:11434/api/chat
+
+No LiteLLM provider, API keys, or OpenAI proxy endpoints are needed.
+
+### Example agent
+```python
+import random
+from google.adk.agents.llm_agent import Agent
+from google.adk.models.ollama_llm import Ollama
+
+def roll_die(sides: int) -> int:
+    return random.randint(1, sides)
+
+def check_prime(numbers: list[int]) -> str:
+    primes = []
+    for number in numbers:
+        number = int(number)
+        if number <= 1:
+            continue
+        for i in range(2, int(number ** 0.5) + 1):
+            if number % i == 0:
+                break
+        else:
+            primes.append(number)
+    return "No prime numbers found." if not primes else f"{', '.join(map(str, primes))} are prime numbers."
+
+root_agent = Agent(
+    model=Ollama(model="llama3.1"),
+    name="dice_agent",
+    description="Agent that rolls dice and checks primes using native Ollama.",
+    instruction="Always use the provided tools.",
+    tools=[roll_die, check_prime],
+)
+```
+## Connecting to a remote Ollama server
+
+Default Ollama endpoint:
+
+    http://localhost:11434
+
+Override using an environment variable:
+```bash
+export OLLAMA_API_BASE="http://192.168.1.20:11434"
+```
+Or pass explicitly in code:
+```python
+Ollama(model="llama3.1", host="http://192.168.1.20:11434")
+```
+
+
+## Running the Example with ADK Web
+
+Start the ADK Web UI:
+
+    adk web hello_world_ollama_native
+
+The interface will be available in your browser, allowing interactive testing of tool calls.
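
To sanity-check that the server and model are reachable before wiring them into an agent, you can call the same `/api/chat` endpoint the native `Ollama` class uses under the hood. A minimal sketch, assuming Ollama is running locally and `llama3.1` has already been pulled:

```python
import json
import urllib.request

# Same endpoint and payload shape the native Ollama class sends.
payload = {
    "model": "llama3.1",
    "messages": [{"role": "user", "content": "Say hello in one word."}],
    "stream": False,
}
request = urllib.request.Request(
    "http://localhost:11434/api/chat",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(request, timeout=30) as response:
    body = json.loads(response.read().decode("utf-8"))

# A successful reply carries the generated text under message.content.
print(body["message"]["content"])
```

If this prints a reply, the agent example above should work against the same server.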
+ + + + diff --git a/contributing/samples/hello_world_ollama_native/__init__.py b/contributing/samples/hello_world_ollama_native/__init__.py new file mode 100755 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/hello_world_ollama_native/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/hello_world_ollama_native/agent.py b/contributing/samples/hello_world_ollama_native/agent.py new file mode 100755 index 0000000000..fa4e0bd1d4 --- /dev/null +++ b/contributing/samples/hello_world_ollama_native/agent.py @@ -0,0 +1,89 @@ +# Copyright 2025 Ayman Hamed +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +from google.adk.agents.llm_agent import Agent +from google.adk.models.ollama_llm import Ollama + + +def roll_die(sides: int) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. + """ + return random.randint(1, sides) + + +def check_prime(numbers: list[int]) -> str: + """Check if a given list of numbers are prime. + + Args: + numbers: The list of numbers to check. + + Returns: + A str indicating which number is prime. + """ + primes = set() + for number in numbers: + number = int(number) + if number <= 1: + continue + is_prime = True + for i in range(2, int(number**0.5) + 1): + if number % i == 0: + is_prime = False + break + if is_prime: + primes.add(number) + return ( + "No prime numbers found." + if not primes + else f"{', '.join(str(num) for num in primes)} are prime numbers." + ) + + +root_agent = Agent( + model=Ollama(model='llama3.1'), + name="dice_roll_agent", + description=( + "hello world agent that can roll a dice of any number of sides and" + " check prime numbers." + ), + instruction=""" + You roll dice and answer questions about the outcome of the dice rolls. + You can roll dice of different sizes. + You can use multiple tools in parallel by calling functions in parallel(in one request and in one round). + It is ok to discuss previous dice roles, and comment on the dice rolls. + When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string. + You should never roll a die on your own. + When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string. 
+ You should not check prime numbers before calling the tool. + When you are asked to roll a die and check prime numbers, you should always make the following two function calls: + 1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool. + 2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result. + 2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list. + 3. When you respond, you must include the roll_die result from step 1. + You should always perform the previous 3 steps when asking for a roll and checking prime numbers. + You should not rely on the previous history on prime results. + """, + tools=[ + roll_die, + check_prime, + ], +) diff --git a/contributing/samples/hello_world_ollama_native/main.py b/contributing/samples/hello_world_ollama_native/main.py new file mode 100755 index 0000000000..28fdbbbc92 --- /dev/null +++ b/contributing/samples/hello_world_ollama_native/main.py @@ -0,0 +1,77 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import time +import warnings + +import agent +from dotenv import load_dotenv +from google.adk import Runner +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.cli.utils import logs +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.session import Session +from google.genai import types + +load_dotenv(override=True) +warnings.filterwarnings('ignore', category=UserWarning) +logs.log_to_tmp_folder() + + +async def main(): + app_name = 'my_app' + user_id_1 = 'user1' + session_service = InMemorySessionService() + artifact_service = InMemoryArtifactService() + runner = Runner( + app_name=app_name, + agent=agent.root_agent, + artifact_service=artifact_service, + session_service=session_service, + ) + session_11 = await session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + + async def run_prompt(session: Session, new_message: str): + content = types.Content( + role='user', parts=[types.Part.from_text(text=new_message)] + ) + print('** User says:', content.model_dump(exclude_none=True)) + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ): + if event.content.parts and event.content.parts[0].text: + print(f'** {event.author}: {event.content.parts[0].text}') + + start_time = time.time() + print('Start time:', start_time) + print('------------------------------------') + await run_prompt(session_11, 'Hi, introduce yourself.') + await run_prompt( + session_11, 'Roll a die with 100 sides and check if it is prime' + ) + await run_prompt(session_11, 'Roll it again.') + await run_prompt(session_11, 'What numbers did I get?') + end_time = time.time() + print('------------------------------------') + print('End time:', end_time) + print('Total 
time:', end_time - start_time) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/google/adk/models/ollama_llm.py b/src/google/adk/models/ollama_llm.py new file mode 100644 index 0000000000..18cf60cecf --- /dev/null +++ b/src/google/adk/models/ollama_llm.py @@ -0,0 +1,279 @@ +# Copyright 2025 Ayman Hamed +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import json +import logging +from typing import Any +from typing import AsyncGenerator +from typing import Optional +from typing import Sequence +from typing import Union +import urllib.error +import urllib.request + +from google.genai import types +from pydantic import Field +from typing_extensions import override + +from .base_llm import BaseLlm +from .llm_request import LlmRequest +from .llm_response import LlmResponse + +logger = logging.getLogger('google_adk.' + __name__) + +_CHAT_ENDPOINT = '/api/chat' + + +class Ollama(BaseLlm): + """Native integration for Ollama hosted models.""" + + model: str = 'ollama/llama3.1' + host: str = Field( + default='http://localhost:11434', + description='Base URL of the Ollama server.', + ) + request_timeout: float = Field( + default=120.0, description='Timeout in seconds for Ollama requests.' + ) + + @classmethod + @override + def supported_models(cls) -> list[str]: + return [r'ollama\/.+'] + + @override + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + if stream: + logger.warning( + 'Streaming is not yet supported for Ollama; falling back to unary.' 
+ ) + self._maybe_append_user_content(llm_request) + + payload = self._build_payload(llm_request) + try: + response_json = await asyncio.to_thread(self._post_chat, payload) + except RuntimeError as exc: + logger.error('Failed to call Ollama: %s', exc) + yield LlmResponse(error_code='OLLAMA_ERROR', error_message=str(exc)) + return + + llm_response = self._to_llm_response(response_json) + yield llm_response + + def _build_payload(self, llm_request: LlmRequest) -> dict[str, Any]: + payload: dict[str, Any] = { + 'model': self._extract_model_name(llm_request.model), + 'messages': self._convert_messages(llm_request), + 'stream': False, + } + if tools := self._convert_tools(llm_request): + payload['tools'] = tools + if options := self._convert_options(llm_request): + payload['options'] = options + return payload + + def _extract_model_name(self, request_model: Optional[str]) -> str: + model_name = request_model or self.model + if model_name.startswith("ollama/") or model_name.startswith("ollama_chat/"): + return model_name.split('/', 1)[1] + return model_name + + def _convert_messages(self, llm_request: LlmRequest) -> list[dict[str, str]]: + messages: list[dict[str, str]] = [] + if llm_request.config.system_instruction: + messages.append( + { + 'role': 'system', + 'content': llm_request.config.system_instruction, + } + ) + for content in llm_request.contents: + message_text = self._content_to_text(content) + if not message_text: + continue + role = self._map_role(content.role) + messages.append({'role': role, 'content': message_text}) + return messages + + def _content_to_text(self, content: types.Content) -> str: + parts = content.parts or [] + text_parts: list[str] = [] + for part in parts: + if part.text: + text_parts.append(part.text) + elif part.function_response: + try: + response_json = json.dumps( + part.function_response.response, ensure_ascii=False + ) + except TypeError: + response_json = str(part.function_response.response) + text_parts.append( + f"[tool_response name={part.function_response.name or ''}] " + f"{response_json}" + ) + elif part.function_call: + try: + args_json = json.dumps(part.function_call.args, ensure_ascii=False) + except TypeError: + args_json = str(part.function_call.args) + text_parts.append( + f"[tool_call name={part.function_call.name}] {args_json}" + ) + else: + logger.debug( + 'Skipping unsupported content part for Ollama message: %s', part + ) + return '\n'.join(text_parts) + + def _map_role(self, role: Optional[str]) -> str: + if role in ('model', 'assistant'): + return 'assistant' + if role == 'system': + return 'system' + return 'user' + + def _convert_tools(self, llm_request: LlmRequest) -> list[dict[str, Any]]: + tools_spec: list[dict[str, Any]] = [] + if not llm_request.config.tools: + return tools_spec + for tool in llm_request.config.tools: + function_declarations: Optional[Sequence[types.FunctionDeclaration]] = ( + tool.function_declarations + if isinstance(tool, types.Tool) + else None + ) + if not function_declarations: + continue + for function_declaration in function_declarations: + tools_spec.append( + { + 'type': 'function', + 'function': { + 'name': function_declaration.name, + 'description': function_declaration.description or '', + 'parameters': self._function_parameters_to_json( + function_declaration + ), + }, + } + ) + return tools_spec + + def _function_parameters_to_json( + self, function_declaration: types.FunctionDeclaration + ) -> dict[str, Any]: + if function_declaration.parameters is None: + return {'type': 'object', 
'properties': {}} + try: + return function_declaration.parameters.model_dump(exclude_none=True) + except AttributeError: + # model_dump is not guaranteed depending on the genai version. + try: + return json.loads( + function_declaration.parameters.model_dump_json(exclude_none=True) + ) + except (AttributeError, json.JSONDecodeError, TypeError) as exc: + logger.debug( + ( + 'Failed to convert function parameters, defaulting to empty' + ' schema: %s' + ), + exc, + ) + return {'type': 'object', 'properties': {}} + + def _convert_options(self, llm_request: LlmRequest) -> dict[str, Any]: + options: dict[str, Any] = {} + config = llm_request.config + temperature = getattr(config, 'temperature', None) + if temperature is not None: + options['temperature'] = temperature + top_p = getattr(config, 'top_p', None) + if top_p is not None: + options['top_p'] = top_p + max_output_tokens = getattr(config, 'max_output_tokens', None) + if max_output_tokens is not None: + options['num_predict'] = max_output_tokens + return options + + def _post_chat(self, payload: dict[str, Any]) -> dict[str, Any]: + url = self.host.rstrip('/') + _CHAT_ENDPOINT + data = json.dumps(payload).encode('utf-8') + request = urllib.request.Request( + url, + data=data, + headers={'Content-Type': 'application/json'}, + method='POST', + ) + try: + with urllib.request.urlopen( + request, timeout=self.request_timeout + ) as response: + response_body = response.read().decode('utf-8') + except urllib.error.URLError as exc: + raise RuntimeError(exc.reason) from exc + except urllib.error.HTTPError as exc: + message = exc.read().decode('utf-8', errors='ignore') + raise RuntimeError(f'{exc.code}: {message}') from exc + return json.loads(response_body) + + def _to_llm_response(self, response_json: dict[str, Any]) -> LlmResponse: + if error := response_json.get('error'): + return LlmResponse( + error_code='OLLAMA_ERROR', + error_message=str(error), + ) + + message = response_json.get('message', {}) + parts: list[types.Part] = [] + + content = message.get('content') + if isinstance(content, str) and content.strip(): + parts.append(types.Part.from_text(text=content)) + + for tool_call in message.get('tool_calls', []): + function_payload = tool_call.get('function', {}) + name = function_payload.get('name') + arguments: Union[str, dict[str, Any], None] = function_payload.get( + 'arguments' + ) + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + logger.debug( + 'Failed to parse tool call arguments as JSON: %s', arguments + ) + elif arguments is None: + arguments = {} + function_call = types.FunctionCall(name=name, args=arguments) + if tool_call_id := tool_call.get('id'): + setattr(function_call, 'id', tool_call_id) + parts.append(types.Part(function_call=function_call)) + + if not parts: + return LlmResponse( + error_code='NO_CONTENT', + error_message='Ollama response did not contain model output.', + ) + + return LlmResponse( + content=types.Content(role='model', parts=parts), + ) diff --git a/tests/unittests/models/test_ollama.py b/tests/unittests/models/test_ollama.py new file mode 100644 index 0000000000..f68c20c558 --- /dev/null +++ b/tests/unittests/models/test_ollama.py @@ -0,0 +1,314 @@ +# Copyright 2025 +# Tests for native Ollama integration for Google ADK. 
+ +import json +import pytest +from unittest.mock import Mock, AsyncMock + +from google.adk.models.ollama_llm import Ollama +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.genai import types + + +# +# -------------------------- +# Helpers & Fixtures +# -------------------------- +# + +@pytest.fixture +def mock_ollama_client(): + """Mock HTTP client for Ollama POST calls.""" + class Client: + def __init__(self, response): + self.response = response + self.calls = [] + + def post(self, payload): + self.calls.append(payload) + return self.response + + return Client + + +def mock_response_ok(text="Hello world", tool_calls=None, usage=None): + """Creates a fake Ollama /api/chat response.""" + message = {"content": text} + if tool_calls: + message["tool_calls"] = tool_calls + + resp = {"message": message} + if usage: + resp["usage"] = usage + return resp + + +# +# -------------------------- +# Test: model extraction +# -------------------------- +# + +def test_extract_model_name_basic(): + o = Ollama(model="ollama/mistral") + assert o._extract_model_name("ollama/mistral") == "mistral" + + +def test_extract_model_name_chat_prefix(): + o = Ollama(model="ollama_chat/llama3.1") + assert o._extract_model_name("ollama_chat/llama3.1") == "llama3.1" + + +def test_extract_model_name_no_prefix(): + o = Ollama(model="mistral") + assert o._extract_model_name("mistral") == "mistral" + + +# +# -------------------------- +# Test: message conversion +# -------------------------- +# + +def test_convert_messages_basic(): + o = Ollama() + + req = LlmRequest( + contents=[ + types.Content( + role="user", + parts=[types.Part.from_text("Hi")] + ) + ] + ) + + msgs = o._convert_messages(req) + assert msgs[0]["role"] == "user" + assert msgs[0]["content"] == "Hi" + + +def test_convert_messages_with_system(): + o = Ollama() + + req = LlmRequest( + contents=[types.Content(role="user", parts=[types.Part.from_text("X")])], + config=types.GenerateContentConfig(system_instruction="SYS") + ) + + msgs = o._convert_messages(req) + assert msgs[0]["role"] == "system" + assert msgs[0]["content"] == "SYS" + assert msgs[1]["content"] == "X" + + +# +# -------------------------- +# Test: _content_to_text +# -------------------------- +# + +def test_content_to_text_basic(): + o = Ollama() + content = types.Content(role="user", parts=[types.Part.from_text("ABC")]) + assert o._content_to_text(content) == "ABC" + + +def test_content_to_text_function_call(): + o = Ollama() + + part = types.Part.from_function_call( + name="add", + args={"x": 1, "y": 2}, + ) + part.function_call.id = "call123" + + content = types.Content(role="assistant", parts=[part]) + txt = o._content_to_text(content) + + assert "[tool_call name=add]" in txt + assert '{"x": 1, "y": 2}' in txt + + +def test_content_to_text_tool_response(): + o = Ollama() + + part = types.Part.from_function_response( + name="add", response={"result": 3} + ) + part.function_response.id = "id123" + + content = types.Content(role="tool", parts=[part]) + txt = o._content_to_text(content) + + assert "[tool_response name=add]" in txt + assert '"result": 3' in txt + + +# +# -------------------------- +# Test: _convert_tools +# -------------------------- +# + +def test_convert_tools_basic(): + o = Ollama() + + req = LlmRequest( + config=types.GenerateContentConfig( + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="add", + description="Add numbers", + parameters=types.Schema( + 
type=types.Type.OBJECT, + properties={"x": types.Schema(type=types.Type.NUMBER)} + ) + ) + ] + ) + ] + ) + ) + + tools = o._convert_tools(req) + assert tools[0]["function"]["name"] == "add" + assert tools[0]["function"]["parameters"]["type"] == "object" + + +# +# -------------------------- +# Test: HTTP call wrapper +# -------------------------- +# + +def test_post_chat_success(monkeypatch): + fake_response = {"message": {"content": "OK"}} + + def fake_urlopen(req, timeout=0): + class Resp: + def read(self): + return json.dumps(fake_response).encode("utf-8") + return Resp() + + monkeypatch.setattr("urllib.request.urlopen", fake_urlopen) + + o = Ollama() + resp = o._post_chat({"model": "x"}) + assert resp["message"]["content"] == "OK" + + +# +# -------------------------- +# Test: _to_llm_response +# -------------------------- +# + +def test_to_llm_response_text(): + o = Ollama() + resp = mock_response_ok("Hi") + + out = o._to_llm_response(resp) + assert isinstance(out, LlmResponse) + assert out.content.parts[0].text == "Hi" + + +def test_to_llm_response_tool_call(): + o = Ollama() + + tool_call = { + "id": "abc", + "function": { + "name": "add", + "arguments": '{"x": 1}' + } + } + + resp = mock_response_ok(tool_calls=[tool_call]) + out = o._to_llm_response(resp) + + fc = out.content.parts[0].function_call + assert fc.name == "add" + assert fc.args == {"x": 1} + assert fc.id == "abc" + + +def test_to_llm_response_usage_metadata(): + o = Ollama() + resp = mock_response_ok( + text="Hi", + usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + ) + + out = o._to_llm_response(resp) + + assert out.usage_metadata.prompt_token_count == 10 + assert out.usage_metadata.candidates_token_count == 5 + assert out.usage_metadata.total_token_count == 15 + + +# +# -------------------------- +# Test: full generate_content_async +# -------------------------- +# + +@pytest.mark.asyncio +async def test_generate_content_async_basic(monkeypatch): + resp = mock_response_ok("Hello!") + + async def fake_thread(fn, *args): + return resp + + monkeypatch.setattr("asyncio.to_thread", fake_thread) + + o = Ollama(model="ollama/mistral") + + req = LlmRequest( + contents=[types.Content(role="user", parts=[types.Part.from_text("Hi")])] + ) + + results = [r async for r in o.generate_content_async(req)] + assert results[0].content.parts[0].text == "Hello!" 
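

# Supplementary sketch (not part of the original patch): the option mapping in
# Ollama._convert_options is not covered by the tests above. A test along these
# lines would pin down the temperature / top_p / max_output_tokens -> num_predict
# translation implemented in ollama_llm.py.
def test_convert_options_maps_generation_config():
    o = Ollama()
    req = LlmRequest(
        config=types.GenerateContentConfig(
            temperature=0.2, top_p=0.9, max_output_tokens=64
        )
    )
    options = o._convert_options(req)
    assert options == {"temperature": 0.2, "top_p": 0.9, "num_predict": 64}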
+ + +@pytest.mark.asyncio +async def test_generate_content_async_error(monkeypatch): + async def fake_thread(fn, *args): + raise RuntimeError("boom") + + monkeypatch.setattr("asyncio.to_thread", fake_thread) + + o = Ollama() + + req = LlmRequest(contents=[types.Content(role="user", parts=[])]) + + results = [r async for r in o.generate_content_async(req)] + assert results[0].error_code == "OLLAMA_ERROR" + + +# +# -------------------------- +# Test: model override +# -------------------------- +# + +@pytest.mark.asyncio +async def test_model_override(monkeypatch): + resp = mock_response_ok("Hello") + + async def fake_thread(fn, *args): + return resp + + monkeypatch.setattr("asyncio.to_thread", fake_thread) + + o = Ollama(model="default") + req = LlmRequest( + model="override", + contents=[types.Content(role="user", parts=[types.Part.from_text("X")])] + ) + + out = [r async for r in o.generate_content_async(req)][0] + assert out.model_version == "override" \ No newline at end of file From fe1ee917eba0dfc51b860d19c4a26a925ecb96ed Mon Sep 17 00:00:00 2001 From: Ayman Hamed Date: Sun, 16 Nov 2025 15:36:59 +0200 Subject: [PATCH 2/9] Fix Ollama integration: add model_version, usage metadata, safe JSON parsing, updated tests --- src/google/adk/models/ollama_llm.py | 262 +++++++++++++++++++------- tests/unittests/models/test_ollama.py | 170 +++++++---------- 2 files changed, 263 insertions(+), 169 deletions(-) diff --git a/src/google/adk/models/ollama_llm.py b/src/google/adk/models/ollama_llm.py index 18cf60cecf..ec114cc625 100644 --- a/src/google/adk/models/ollama_llm.py +++ b/src/google/adk/models/ollama_llm.py @@ -17,11 +17,7 @@ import asyncio import json import logging -from typing import Any -from typing import AsyncGenerator -from typing import Optional -from typing import Sequence -from typing import Union +from typing import Any, AsyncGenerator, Optional, Sequence, Union import urllib.error import urllib.request @@ -33,27 +29,43 @@ from .llm_request import LlmRequest from .llm_response import LlmResponse -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) -_CHAT_ENDPOINT = '/api/chat' +_CHAT_ENDPOINT = "/api/chat" class Ollama(BaseLlm): - """Native integration for Ollama hosted models.""" + """Native integration for Ollama-hosted models. + + This backend talks directly to the Ollama HTTP API: + + POST /api/chat + + It supports: + * `ollama/` names (e.g. `ollama/llama3.2`) + * `ollama_chat/` names for LiteLlm compatibility + * System / user / assistant messages + * Unary generation + * Tool-calling via Ollama `tools` schema + """ + + # Default model name is compatible with Agent(model="ollama/llama3.1") + model: str = "ollama/llama3.1" - model: str = 'ollama/llama3.1' host: str = Field( - default='http://localhost:11434', - description='Base URL of the Ollama server.', + default="http://localhost:11434", + description="Base URL of the Ollama server.", ) request_timeout: float = Field( - default=120.0, description='Timeout in seconds for Ollama requests.' + default=120.0, + description="Timeout in seconds for Ollama requests.", ) @classmethod @override def supported_models(cls) -> list[str]: - return [r'ollama\/.+'] + # Allow any `ollama/...` style name. + return [r"ollama\/.+"] @override async def generate_content_async( @@ -61,63 +73,94 @@ async def generate_content_async( ) -> AsyncGenerator[LlmResponse, None]: if stream: logger.warning( - 'Streaming is not yet supported for Ollama; falling back to unary.' 
+ "Streaming is not yet supported for Ollama; falling back to unary." ) + + # Ensure last user content is appended if needed (BaseLlm helper). self._maybe_append_user_content(llm_request) payload = self._build_payload(llm_request) try: response_json = await asyncio.to_thread(self._post_chat, payload) except RuntimeError as exc: - logger.error('Failed to call Ollama: %s', exc) - yield LlmResponse(error_code='OLLAMA_ERROR', error_message=str(exc)) + logger.error("Failed to call Ollama: %s", exc) + yield LlmResponse(error_code="OLLAMA_ERROR", error_message=str(exc)) return - llm_response = self._to_llm_response(response_json) + llm_response = self._to_llm_response( + response_json, request_model=llm_request.model + ) yield llm_response + # --------------------------------------------------------------------------- + # Payload construction + # --------------------------------------------------------------------------- + def _build_payload(self, llm_request: LlmRequest) -> dict[str, Any]: payload: dict[str, Any] = { - 'model': self._extract_model_name(llm_request.model), - 'messages': self._convert_messages(llm_request), - 'stream': False, + "model": self._extract_model_name(llm_request.model), + "messages": self._convert_messages(llm_request), + "stream": False, } + if tools := self._convert_tools(llm_request): - payload['tools'] = tools + payload["tools"] = tools if options := self._convert_options(llm_request): - payload['options'] = options + payload["options"] = options + return payload def _extract_model_name(self, request_model: Optional[str]) -> str: + """Normalize model name for the Ollama API. + + Supports: + * "ollama/llama3.2" → "llama3.2" + * "ollama_chat/llama3.2" → "llama3.2" + * "llama3.2" → "llama3.2" + """ model_name = request_model or self.model if model_name.startswith("ollama/") or model_name.startswith("ollama_chat/"): - return model_name.split('/', 1)[1] + return model_name.split("/", 1)[1] return model_name def _convert_messages(self, llm_request: LlmRequest) -> list[dict[str, str]]: + """Convert ADK Contents into Ollama chat messages.""" messages: list[dict[str, str]] = [] + + # System instruction → first system message. if llm_request.config.system_instruction: messages.append( { - 'role': 'system', - 'content': llm_request.config.system_instruction, + "role": "system", + "content": llm_request.config.system_instruction, } ) + + # User / assistant / tool messages. for content in llm_request.contents: message_text = self._content_to_text(content) if not message_text: continue role = self._map_role(content.role) - messages.append({'role': role, 'content': message_text}) + messages.append({"role": role, "content": message_text}) + return messages def _content_to_text(self, content: types.Content) -> str: + """Flatten a `Content` into plain text for Ollama. + + Encodes tool calls and tool responses as tagged lines so that the model + can reason about them and generate new tool calls. + """ parts = content.parts or [] text_parts: list[str] = [] + for part in parts: if part.text: text_parts.append(part.text) + elif part.function_response: + # Tool result from a previous call. try: response_json = json.dumps( part.function_response.response, ensure_ascii=False @@ -128,7 +171,9 @@ def _content_to_text(self, content: types.Content) -> str: f"[tool_response name={part.function_response.name or ''}] " f"{response_json}" ) + elif part.function_call: + # A model-issued tool call (arguments as JSON). 
try: args_json = json.dumps(part.function_call.args, ensure_ascii=False) except TypeError: @@ -136,23 +181,28 @@ def _content_to_text(self, content: types.Content) -> str: text_parts.append( f"[tool_call name={part.function_call.name}] {args_json}" ) + else: logger.debug( - 'Skipping unsupported content part for Ollama message: %s', part + "Skipping unsupported content part for Ollama message: %s", part ) - return '\n'.join(text_parts) + + return "\n".join(text_parts) def _map_role(self, role: Optional[str]) -> str: - if role in ('model', 'assistant'): - return 'assistant' - if role == 'system': - return 'system' - return 'user' + if role in ("model", "assistant"): + return "assistant" + if role == "system": + return "system" + # "user", "tool", or anything else defaults to "user". + return "user" def _convert_tools(self, llm_request: LlmRequest) -> list[dict[str, Any]]: + """Convert ADK tool declarations into Ollama tool schema.""" tools_spec: list[dict[str, Any]] = [] if not llm_request.config.tools: return tools_spec + for tool in llm_request.config.tools: function_declarations: Optional[Sequence[types.FunctionDeclaration]] = ( tool.function_declarations @@ -161,26 +211,30 @@ def _convert_tools(self, llm_request: LlmRequest) -> list[dict[str, Any]]: ) if not function_declarations: continue + for function_declaration in function_declarations: tools_spec.append( { - 'type': 'function', - 'function': { - 'name': function_declaration.name, - 'description': function_declaration.description or '', - 'parameters': self._function_parameters_to_json( + "type": "function", + "function": { + "name": function_declaration.name, + "description": function_declaration.description or "", + "parameters": self._function_parameters_to_json( function_declaration ), }, } ) + return tools_spec def _function_parameters_to_json( self, function_declaration: types.FunctionDeclaration ) -> dict[str, Any]: + """Convert function parameters Schema → JSON Schema for Ollama.""" if function_declaration.parameters is None: - return {'type': 'object', 'properties': {}} + return {"type": "object", "properties": {}} + try: return function_declaration.parameters.model_dump(exclude_none=True) except AttributeError: @@ -192,88 +246,152 @@ def _function_parameters_to_json( except (AttributeError, json.JSONDecodeError, TypeError) as exc: logger.debug( ( - 'Failed to convert function parameters, defaulting to empty' - ' schema: %s' + "Failed to convert function parameters, defaulting to empty" + " schema: %s" ), exc, ) - return {'type': 'object', 'properties': {}} + return {"type": "object", "properties": {}} def _convert_options(self, llm_request: LlmRequest) -> dict[str, Any]: + """Map ADK generation config fields to Ollama options.""" options: dict[str, Any] = {} config = llm_request.config - temperature = getattr(config, 'temperature', None) + + temperature = getattr(config, "temperature", None) if temperature is not None: - options['temperature'] = temperature - top_p = getattr(config, 'top_p', None) + options["temperature"] = temperature + + top_p = getattr(config, "top_p", None) if top_p is not None: - options['top_p'] = top_p - max_output_tokens = getattr(config, 'max_output_tokens', None) + options["top_p"] = top_p + + max_output_tokens = getattr(config, "max_output_tokens", None) if max_output_tokens is not None: - options['num_predict'] = max_output_tokens + # Ollama uses `num_predict` to limit generated tokens. 
+ options["num_predict"] = max_output_tokens + return options + # --------------------------------------------------------------------------- + # HTTP call + # --------------------------------------------------------------------------- + def _post_chat(self, payload: dict[str, Any]) -> dict[str, Any]: - url = self.host.rstrip('/') + _CHAT_ENDPOINT - data = json.dumps(payload).encode('utf-8') + """Perform a blocking POST /api/chat call to Ollama.""" + url = self.host.rstrip("/") + _CHAT_ENDPOINT + data = json.dumps(payload).encode("utf-8") request = urllib.request.Request( url, data=data, - headers={'Content-Type': 'application/json'}, - method='POST', + headers={"Content-Type": "application/json"}, + method="POST", ) + try: with urllib.request.urlopen( request, timeout=self.request_timeout ) as response: - response_body = response.read().decode('utf-8') + response_body = response.read().decode("utf-8") except urllib.error.URLError as exc: raise RuntimeError(exc.reason) from exc except urllib.error.HTTPError as exc: - message = exc.read().decode('utf-8', errors='ignore') - raise RuntimeError(f'{exc.code}: {message}') from exc + message = exc.read().decode("utf-8", errors="ignore") + raise RuntimeError(f"{exc.code}: {message}") from exc + return json.loads(response_body) - def _to_llm_response(self, response_json: dict[str, Any]) -> LlmResponse: - if error := response_json.get('error'): + # --------------------------------------------------------------------------- + # Response mapping + # --------------------------------------------------------------------------- + + def _to_llm_response( + self, + response_json: dict[str, Any], + request_model: Optional[str] = None, + ) -> LlmResponse: + """Convert Ollama JSON response → ADK `LlmResponse`.""" + if error := response_json.get("error"): return LlmResponse( - error_code='OLLAMA_ERROR', + error_code="OLLAMA_ERROR", error_message=str(error), ) - message = response_json.get('message', {}) + message = response_json.get("message", {}) or {} parts: list[types.Part] = [] - content = message.get('content') + # 1) Main text content. + content = message.get("content") if isinstance(content, str) and content.strip(): parts.append(types.Part.from_text(text=content)) - for tool_call in message.get('tool_calls', []): - function_payload = tool_call.get('function', {}) - name = function_payload.get('name') + # 2) Tool calls (if any). + for tool_call in message.get("tool_calls", []): + function_payload = tool_call.get("function", {}) or {} + name = function_payload.get("name") arguments: Union[str, dict[str, Any], None] = function_payload.get( - 'arguments' + "arguments" ) + if isinstance(arguments, str): try: arguments = json.loads(arguments) except json.JSONDecodeError: - logger.debug( - 'Failed to parse tool call arguments as JSON: %s', arguments + logger.warning( + "Failed to parse tool call arguments as JSON: %s. " + "Defaulting to empty arguments.", + arguments, ) + arguments = {} elif arguments is None: arguments = {} + function_call = types.FunctionCall(name=name, args=arguments) - if tool_call_id := tool_call.get('id'): - setattr(function_call, 'id', tool_call_id) + if tool_call_id := tool_call.get("id"): + # id is useful for correlating tool_call ↔ tool_response. 
+ setattr(function_call, "id", tool_call_id) + parts.append(types.Part(function_call=function_call)) if not parts: return LlmResponse( - error_code='NO_CONTENT', - error_message='Ollama response did not contain model output.', + error_code="NO_CONTENT", + error_message="Ollama response did not contain model output.", ) - return LlmResponse( - content=types.Content(role='model', parts=parts), + # 3) Usage mapping (Ollama → GenerateContentResponseUsageMetadata). + # Ollama returns: + # prompt_eval_count: tokens in prompt + # eval_count: tokens in completion + prompt_tokens = response_json.get("prompt_eval_count") + completion_tokens = response_json.get("eval_count") + + # Fallback: if someone wraps usage in a dict (e.g. in tests). + if prompt_tokens is None or completion_tokens is None: + usage = response_json.get("usage") or {} + if prompt_tokens is None: + prompt_tokens = usage.get("prompt_tokens") + if completion_tokens is None: + completion_tokens = usage.get("completion_tokens") + + usage_metadata: Optional[types.GenerateContentResponseUsageMetadata] = None + if prompt_tokens is not None and completion_tokens is not None: + total_tokens = response_json.get("total_tokens") + if total_tokens is None: + total_tokens = prompt_tokens + completion_tokens + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=prompt_tokens, + candidates_token_count=completion_tokens, + total_token_count=total_tokens, + ) + + # 4) Model version: prefer Ollama's `model`, fallback to request model. + model_version = response_json.get("model") or self._extract_model_name( + request_model or self.model ) + + return LlmResponse( + content=types.Content(role="model", parts=parts), + model_version=model_version, + usage_metadata=usage_metadata, + ) \ No newline at end of file diff --git a/tests/unittests/models/test_ollama.py b/tests/unittests/models/test_ollama.py index f68c20c558..b75f2a8b85 100644 --- a/tests/unittests/models/test_ollama.py +++ b/tests/unittests/models/test_ollama.py @@ -12,42 +12,23 @@ # -# -------------------------- -# Helpers & Fixtures -# -------------------------- +# ----------------------------------- +# Helpers +# ----------------------------------- # -@pytest.fixture -def mock_ollama_client(): - """Mock HTTP client for Ollama POST calls.""" - class Client: - def __init__(self, response): - self.response = response - self.calls = [] - - def post(self, payload): - self.calls.append(payload) - return self.response - - return Client - - -def mock_response_ok(text="Hello world", tool_calls=None, usage=None): - """Creates a fake Ollama /api/chat response.""" +def mock_response_ok(text="Hello world", tool_calls=None): + """Create a typical Ollama /api/chat response.""" message = {"content": text} if tool_calls: message["tool_calls"] = tool_calls - - resp = {"message": message} - if usage: - resp["usage"] = usage - return resp + return {"message": message} # -# -------------------------- -# Test: model extraction -# -------------------------- +# ----------------------------------- +# Test: model extraction +# ----------------------------------- # def test_extract_model_name_basic(): @@ -66,23 +47,16 @@ def test_extract_model_name_no_prefix(): # -# -------------------------- -# Test: message conversion -# -------------------------- +# ----------------------------------- +# Test: message conversion +# ----------------------------------- # def test_convert_messages_basic(): o = Ollama() - - req = LlmRequest( - contents=[ - types.Content( - role="user", - 
parts=[types.Part.from_text("Hi")] - ) - ] - ) - + req = LlmRequest(contents=[ + types.Content(role="user", parts=[types.Part.from_text("Hi")]) + ]) msgs = o._convert_messages(req) assert msgs[0]["role"] == "user" assert msgs[0]["content"] == "Hi" @@ -90,12 +64,10 @@ def test_convert_messages_basic(): def test_convert_messages_with_system(): o = Ollama() - req = LlmRequest( contents=[types.Content(role="user", parts=[types.Part.from_text("X")])], config=types.GenerateContentConfig(system_instruction="SYS") ) - msgs = o._convert_messages(req) assert msgs[0]["role"] == "system" assert msgs[0]["content"] == "SYS" @@ -103,9 +75,9 @@ def test_convert_messages_with_system(): # -# -------------------------- -# Test: _content_to_text -# -------------------------- +# ----------------------------------- +# Test: content → text +# ----------------------------------- # def test_content_to_text_basic(): @@ -116,72 +88,59 @@ def test_content_to_text_basic(): def test_content_to_text_function_call(): o = Ollama() - - part = types.Part.from_function_call( - name="add", - args={"x": 1, "y": 2}, - ) + part = types.Part.from_function_call(name="add", args={"x": 1, "y": 2}) part.function_call.id = "call123" content = types.Content(role="assistant", parts=[part]) txt = o._content_to_text(content) assert "[tool_call name=add]" in txt - assert '{"x": 1, "y": 2}' in txt + assert '"x": 1' in txt def test_content_to_text_tool_response(): o = Ollama() - - part = types.Part.from_function_response( - name="add", response={"result": 3} - ) - part.function_response.id = "id123" - + part = types.Part.from_function_response(name="add", response={"z": 5}) content = types.Content(role="tool", parts=[part]) txt = o._content_to_text(content) assert "[tool_response name=add]" in txt - assert '"result": 3' in txt + assert '"z": 5' in txt # -# -------------------------- -# Test: _convert_tools -# -------------------------- +# ----------------------------------- +# Test: tool conversion +# ----------------------------------- # def test_convert_tools_basic(): o = Ollama() - req = LlmRequest( config=types.GenerateContentConfig( tools=[ - types.Tool( - function_declarations=[ - types.FunctionDeclaration( - name="add", - description="Add numbers", - parameters=types.Schema( - type=types.Type.OBJECT, - properties={"x": types.Schema(type=types.Type.NUMBER)} - ) + types.Tool(function_declarations=[ + types.FunctionDeclaration( + name="add", + description="Add numbers", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={"x": types.Schema(type=types.Type.NUMBER)} ) - ] - ) + ) + ]) ] ) ) - tools = o._convert_tools(req) assert tools[0]["function"]["name"] == "add" assert tools[0]["function"]["parameters"]["type"] == "object" # -# -------------------------- -# Test: HTTP call wrapper -# -------------------------- +# ----------------------------------- +# Test: POST wrapper +# ----------------------------------- # def test_post_chat_success(monkeypatch): @@ -201,9 +160,9 @@ def read(self): # -# -------------------------- -# Test: _to_llm_response -# -------------------------- +# ----------------------------------- +# Test: _to_llm_response +# ----------------------------------- # def test_to_llm_response_text(): @@ -217,7 +176,6 @@ def test_to_llm_response_text(): def test_to_llm_response_tool_call(): o = Ollama() - tool_call = { "id": "abc", "function": { @@ -235,24 +193,41 @@ def test_to_llm_response_tool_call(): assert fc.id == "abc" +def test_to_llm_response_tool_call_bad_json(): + o = Ollama() + tool_call = { + "id": 
"zzz", + "function": { + "name": "add", + "arguments": "{BAD_JSON" + } + } + + resp = mock_response_ok(tool_calls=[tool_call]) + out = o._to_llm_response(resp) + + fc = out.content.parts[0].function_call + assert fc.args == {} # BAD JSON → fallback to {} + + def test_to_llm_response_usage_metadata(): o = Ollama() - resp = mock_response_ok( - text="Hi", - usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} - ) + resp = mock_response_ok("Hi") + resp["prompt_eval_count"] = 10 + resp["eval_count"] = 5 out = o._to_llm_response(resp) + assert out.usage_metadata is not None assert out.usage_metadata.prompt_token_count == 10 assert out.usage_metadata.candidates_token_count == 5 assert out.usage_metadata.total_token_count == 15 # -# -------------------------- -# Test: full generate_content_async -# -------------------------- +# ----------------------------------- +# async: generate_content_async +# ----------------------------------- # @pytest.mark.asyncio @@ -265,10 +240,9 @@ async def fake_thread(fn, *args): monkeypatch.setattr("asyncio.to_thread", fake_thread) o = Ollama(model="ollama/mistral") - - req = LlmRequest( - contents=[types.Content(role="user", parts=[types.Part.from_text("Hi")])] - ) + req = LlmRequest(contents=[ + types.Content(role="user", parts=[types.Part.from_text("Hi")]) + ]) results = [r async for r in o.generate_content_async(req)] assert results[0].content.parts[0].text == "Hello!" @@ -282,7 +256,6 @@ async def fake_thread(fn, *args): monkeypatch.setattr("asyncio.to_thread", fake_thread) o = Ollama() - req = LlmRequest(contents=[types.Content(role="user", parts=[])]) results = [r async for r in o.generate_content_async(req)] @@ -290,16 +263,19 @@ async def fake_thread(fn, *args): # -# -------------------------- -# Test: model override -# -------------------------- +# ----------------------------------- +# Test: model override +# ----------------------------------- # @pytest.mark.asyncio async def test_model_override(monkeypatch): resp = mock_response_ok("Hello") + resp["model"] = "override" async def fake_thread(fn, *args): + payload = args[0] + assert payload["model"] == "override" # important return resp monkeypatch.setattr("asyncio.to_thread", fake_thread) From 22884612a1e694f68cde5556b2c15f8e99e43e9f Mon Sep 17 00:00:00 2001 From: Ayman Hamed Date: Wed, 19 Nov 2025 07:46:10 +0200 Subject: [PATCH 3/9] Fix formatting and imports for CI --- src/google/adk/models/ollama_llm.py | 95 +++------ tests/unittests/models/test_ollama.py | 289 +++++++++++++------------- 2 files changed, 173 insertions(+), 211 deletions(-) diff --git a/src/google/adk/models/ollama_llm.py b/src/google/adk/models/ollama_llm.py index ec114cc625..8bae0aa7ff 100644 --- a/src/google/adk/models/ollama_llm.py +++ b/src/google/adk/models/ollama_llm.py @@ -17,7 +17,11 @@ import asyncio import json import logging -from typing import Any, AsyncGenerator, Optional, Sequence, Union +from typing import Any +from typing import AsyncGenerator +from typing import Optional +from typing import Sequence +from typing import Union import urllib.error import urllib.request @@ -72,9 +76,7 @@ async def generate_content_async( self, llm_request: LlmRequest, stream: bool = False ) -> AsyncGenerator[LlmResponse, None]: if stream: - logger.warning( - "Streaming is not yet supported for Ollama; falling back to unary." - ) + logger.warning("Streaming is not yet supported for Ollama; falling back to unary.") # Ensure last user content is appended if needed (BaseLlm helper). 
self._maybe_append_user_content(llm_request) @@ -87,9 +89,7 @@ async def generate_content_async( yield LlmResponse(error_code="OLLAMA_ERROR", error_message=str(exc)) return - llm_response = self._to_llm_response( - response_json, request_model=llm_request.model - ) + llm_response = self._to_llm_response(response_json, request_model=llm_request.model) yield llm_response # --------------------------------------------------------------------------- @@ -129,12 +129,10 @@ def _convert_messages(self, llm_request: LlmRequest) -> list[dict[str, str]]: # System instruction → first system message. if llm_request.config.system_instruction: - messages.append( - { - "role": "system", - "content": llm_request.config.system_instruction, - } - ) + messages.append({ + "role": "system", + "content": llm_request.config.system_instruction, + }) # User / assistant / tool messages. for content in llm_request.contents: @@ -162,15 +160,10 @@ def _content_to_text(self, content: types.Content) -> str: elif part.function_response: # Tool result from a previous call. try: - response_json = json.dumps( - part.function_response.response, ensure_ascii=False - ) + response_json = json.dumps(part.function_response.response, ensure_ascii=False) except TypeError: response_json = str(part.function_response.response) - text_parts.append( - f"[tool_response name={part.function_response.name or ''}] " - f"{response_json}" - ) + text_parts.append(f"[tool_response name={part.function_response.name or ''}] {response_json}") elif part.function_call: # A model-issued tool call (arguments as JSON). @@ -178,14 +171,10 @@ def _content_to_text(self, content: types.Content) -> str: args_json = json.dumps(part.function_call.args, ensure_ascii=False) except TypeError: args_json = str(part.function_call.args) - text_parts.append( - f"[tool_call name={part.function_call.name}] {args_json}" - ) + text_parts.append(f"[tool_call name={part.function_call.name}] {args_json}") else: - logger.debug( - "Skipping unsupported content part for Ollama message: %s", part - ) + logger.debug("Skipping unsupported content part for Ollama message: %s", part) return "\n".join(text_parts) @@ -205,32 +194,24 @@ def _convert_tools(self, llm_request: LlmRequest) -> list[dict[str, Any]]: for tool in llm_request.config.tools: function_declarations: Optional[Sequence[types.FunctionDeclaration]] = ( - tool.function_declarations - if isinstance(tool, types.Tool) - else None + tool.function_declarations if isinstance(tool, types.Tool) else None ) if not function_declarations: continue for function_declaration in function_declarations: - tools_spec.append( - { - "type": "function", - "function": { - "name": function_declaration.name, - "description": function_declaration.description or "", - "parameters": self._function_parameters_to_json( - function_declaration - ), - }, - } - ) + tools_spec.append({ + "type": "function", + "function": { + "name": function_declaration.name, + "description": function_declaration.description or "", + "parameters": self._function_parameters_to_json(function_declaration), + }, + }) return tools_spec - def _function_parameters_to_json( - self, function_declaration: types.FunctionDeclaration - ) -> dict[str, Any]: + def _function_parameters_to_json(self, function_declaration: types.FunctionDeclaration) -> dict[str, Any]: """Convert function parameters Schema → JSON Schema for Ollama.""" if function_declaration.parameters is None: return {"type": "object", "properties": {}} @@ -240,15 +221,10 @@ def _function_parameters_to_json( except 
AttributeError: # model_dump is not guaranteed depending on the genai version. try: - return json.loads( - function_declaration.parameters.model_dump_json(exclude_none=True) - ) + return json.loads(function_declaration.parameters.model_dump_json(exclude_none=True)) except (AttributeError, json.JSONDecodeError, TypeError) as exc: logger.debug( - ( - "Failed to convert function parameters, defaulting to empty" - " schema: %s" - ), + "Failed to convert function parameters, defaulting to empty schema: %s", exc, ) return {"type": "object", "properties": {}} @@ -289,9 +265,7 @@ def _post_chat(self, payload: dict[str, Any]) -> dict[str, Any]: ) try: - with urllib.request.urlopen( - request, timeout=self.request_timeout - ) as response: + with urllib.request.urlopen(request, timeout=self.request_timeout) as response: response_body = response.read().decode("utf-8") except urllib.error.URLError as exc: raise RuntimeError(exc.reason) from exc @@ -329,17 +303,14 @@ def _to_llm_response( for tool_call in message.get("tool_calls", []): function_payload = tool_call.get("function", {}) or {} name = function_payload.get("name") - arguments: Union[str, dict[str, Any], None] = function_payload.get( - "arguments" - ) + arguments: Union[str, dict[str, Any], None] = function_payload.get("arguments") if isinstance(arguments, str): try: arguments = json.loads(arguments) except json.JSONDecodeError: logger.warning( - "Failed to parse tool call arguments as JSON: %s. " - "Defaulting to empty arguments.", + "Failed to parse tool call arguments as JSON: %s. Defaulting to empty arguments.", arguments, ) arguments = {} @@ -386,12 +357,10 @@ def _to_llm_response( ) # 4) Model version: prefer Ollama's `model`, fallback to request model. - model_version = response_json.get("model") or self._extract_model_name( - request_model or self.model - ) + model_version = response_json.get("model") or self._extract_model_name(request_model or self.model) return LlmResponse( content=types.Content(role="model", parts=parts), model_version=model_version, usage_metadata=usage_metadata, - ) \ No newline at end of file + ) diff --git a/tests/unittests/models/test_ollama.py b/tests/unittests/models/test_ollama.py index b75f2a8b85..70a7e2dae4 100644 --- a/tests/unittests/models/test_ollama.py +++ b/tests/unittests/models/test_ollama.py @@ -2,14 +2,14 @@ # Tests for native Ollama integration for Google ADK. 
import json -import pytest -from unittest.mock import Mock, AsyncMock +from unittest.mock import AsyncMock +from unittest.mock import Mock -from google.adk.models.ollama_llm import Ollama from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse +from google.adk.models.ollama_llm import Ollama from google.genai import types - +import pytest # # ----------------------------------- @@ -17,12 +17,13 @@ # ----------------------------------- # + def mock_response_ok(text="Hello world", tool_calls=None): - """Create a typical Ollama /api/chat response.""" - message = {"content": text} - if tool_calls: - message["tool_calls"] = tool_calls - return {"message": message} + """Create a typical Ollama /api/chat response.""" + message = {"content": text} + if tool_calls: + message["tool_calls"] = tool_calls + return {"message": message} # @@ -31,19 +32,20 @@ def mock_response_ok(text="Hello world", tool_calls=None): # ----------------------------------- # + def test_extract_model_name_basic(): - o = Ollama(model="ollama/mistral") - assert o._extract_model_name("ollama/mistral") == "mistral" + o = Ollama(model="ollama/mistral") + assert o._extract_model_name("ollama/mistral") == "mistral" def test_extract_model_name_chat_prefix(): - o = Ollama(model="ollama_chat/llama3.1") - assert o._extract_model_name("ollama_chat/llama3.1") == "llama3.1" + o = Ollama(model="ollama_chat/llama3.1") + assert o._extract_model_name("ollama_chat/llama3.1") == "llama3.1" def test_extract_model_name_no_prefix(): - o = Ollama(model="mistral") - assert o._extract_model_name("mistral") == "mistral" + o = Ollama(model="mistral") + assert o._extract_model_name("mistral") == "mistral" # @@ -52,26 +54,25 @@ def test_extract_model_name_no_prefix(): # ----------------------------------- # + def test_convert_messages_basic(): - o = Ollama() - req = LlmRequest(contents=[ - types.Content(role="user", parts=[types.Part.from_text("Hi")]) - ]) - msgs = o._convert_messages(req) - assert msgs[0]["role"] == "user" - assert msgs[0]["content"] == "Hi" + o = Ollama() + req = LlmRequest(contents=[types.Content(role="user", parts=[types.Part.from_text("Hi")])]) + msgs = o._convert_messages(req) + assert msgs[0]["role"] == "user" + assert msgs[0]["content"] == "Hi" def test_convert_messages_with_system(): - o = Ollama() - req = LlmRequest( - contents=[types.Content(role="user", parts=[types.Part.from_text("X")])], - config=types.GenerateContentConfig(system_instruction="SYS") - ) - msgs = o._convert_messages(req) - assert msgs[0]["role"] == "system" - assert msgs[0]["content"] == "SYS" - assert msgs[1]["content"] == "X" + o = Ollama() + req = LlmRequest( + contents=[types.Content(role="user", parts=[types.Part.from_text("X")])], + config=types.GenerateContentConfig(system_instruction="SYS"), + ) + msgs = o._convert_messages(req) + assert msgs[0]["role"] == "system" + assert msgs[0]["content"] == "SYS" + assert msgs[1]["content"] == "X" # @@ -80,32 +81,33 @@ def test_convert_messages_with_system(): # ----------------------------------- # + def test_content_to_text_basic(): - o = Ollama() - content = types.Content(role="user", parts=[types.Part.from_text("ABC")]) - assert o._content_to_text(content) == "ABC" + o = Ollama() + content = types.Content(role="user", parts=[types.Part.from_text("ABC")]) + assert o._content_to_text(content) == "ABC" def test_content_to_text_function_call(): - o = Ollama() - part = types.Part.from_function_call(name="add", args={"x": 1, "y": 2}) - part.function_call.id = 
"call123" + o = Ollama() + part = types.Part.from_function_call(name="add", args={"x": 1, "y": 2}) + part.function_call.id = "call123" - content = types.Content(role="assistant", parts=[part]) - txt = o._content_to_text(content) + content = types.Content(role="assistant", parts=[part]) + txt = o._content_to_text(content) - assert "[tool_call name=add]" in txt - assert '"x": 1' in txt + assert "[tool_call name=add]" in txt + assert '"x": 1' in txt def test_content_to_text_tool_response(): - o = Ollama() - part = types.Part.from_function_response(name="add", response={"z": 5}) - content = types.Content(role="tool", parts=[part]) - txt = o._content_to_text(content) + o = Ollama() + part = types.Part.from_function_response(name="add", response={"z": 5}) + content = types.Content(role="tool", parts=[part]) + txt = o._content_to_text(content) - assert "[tool_response name=add]" in txt - assert '"z": 5' in txt + assert "[tool_response name=add]" in txt + assert '"z": 5' in txt # @@ -114,27 +116,29 @@ def test_content_to_text_tool_response(): # ----------------------------------- # + def test_convert_tools_basic(): - o = Ollama() - req = LlmRequest( - config=types.GenerateContentConfig( - tools=[ - types.Tool(function_declarations=[ - types.FunctionDeclaration( - name="add", - description="Add numbers", - parameters=types.Schema( - type=types.Type.OBJECT, - properties={"x": types.Schema(type=types.Type.NUMBER)} - ) - ) - ]) - ] - ) - ) - tools = o._convert_tools(req) - assert tools[0]["function"]["name"] == "add" - assert tools[0]["function"]["parameters"]["type"] == "object" + o = Ollama() + req = LlmRequest( + config=types.GenerateContentConfig( + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="add", + description="Add numbers", + parameters=types.Schema( + type=types.Type.OBJECT, properties={"x": types.Schema(type=types.Type.NUMBER)} + ), + ) + ] + ) + ] + ) + ) + tools = o._convert_tools(req) + assert tools[0]["function"]["name"] == "add" + assert tools[0]["function"]["parameters"]["type"] == "object" # @@ -143,20 +147,23 @@ def test_convert_tools_basic(): # ----------------------------------- # + def test_post_chat_success(monkeypatch): - fake_response = {"message": {"content": "OK"}} + fake_response = {"message": {"content": "OK"}} + + def fake_urlopen(req, timeout=0): + class Resp: - def fake_urlopen(req, timeout=0): - class Resp: - def read(self): - return json.dumps(fake_response).encode("utf-8") - return Resp() + def read(self): + return json.dumps(fake_response).encode("utf-8") - monkeypatch.setattr("urllib.request.urlopen", fake_urlopen) + return Resp() - o = Ollama() - resp = o._post_chat({"model": "x"}) - assert resp["message"]["content"] == "OK" + monkeypatch.setattr("urllib.request.urlopen", fake_urlopen) + + o = Ollama() + resp = o._post_chat({"model": "x"}) + assert resp["message"]["content"] == "OK" # @@ -165,63 +172,52 @@ def read(self): # ----------------------------------- # + def test_to_llm_response_text(): - o = Ollama() - resp = mock_response_ok("Hi") + o = Ollama() + resp = mock_response_ok("Hi") - out = o._to_llm_response(resp) - assert isinstance(out, LlmResponse) - assert out.content.parts[0].text == "Hi" + out = o._to_llm_response(resp) + assert isinstance(out, LlmResponse) + assert out.content.parts[0].text == "Hi" def test_to_llm_response_tool_call(): - o = Ollama() - tool_call = { - "id": "abc", - "function": { - "name": "add", - "arguments": '{"x": 1}' - } - } + o = Ollama() + tool_call = {"id": "abc", "function": {"name": 
"add", "arguments": '{"x": 1}'}} - resp = mock_response_ok(tool_calls=[tool_call]) - out = o._to_llm_response(resp) + resp = mock_response_ok(tool_calls=[tool_call]) + out = o._to_llm_response(resp) - fc = out.content.parts[0].function_call - assert fc.name == "add" - assert fc.args == {"x": 1} - assert fc.id == "abc" + fc = out.content.parts[0].function_call + assert fc.name == "add" + assert fc.args == {"x": 1} + assert fc.id == "abc" def test_to_llm_response_tool_call_bad_json(): - o = Ollama() - tool_call = { - "id": "zzz", - "function": { - "name": "add", - "arguments": "{BAD_JSON" - } - } + o = Ollama() + tool_call = {"id": "zzz", "function": {"name": "add", "arguments": "{BAD_JSON"}} - resp = mock_response_ok(tool_calls=[tool_call]) - out = o._to_llm_response(resp) + resp = mock_response_ok(tool_calls=[tool_call]) + out = o._to_llm_response(resp) - fc = out.content.parts[0].function_call - assert fc.args == {} # BAD JSON → fallback to {} + fc = out.content.parts[0].function_call + assert fc.args == {} # BAD JSON → fallback to {} def test_to_llm_response_usage_metadata(): - o = Ollama() - resp = mock_response_ok("Hi") - resp["prompt_eval_count"] = 10 - resp["eval_count"] = 5 + o = Ollama() + resp = mock_response_ok("Hi") + resp["prompt_eval_count"] = 10 + resp["eval_count"] = 5 - out = o._to_llm_response(resp) + out = o._to_llm_response(resp) - assert out.usage_metadata is not None - assert out.usage_metadata.prompt_token_count == 10 - assert out.usage_metadata.candidates_token_count == 5 - assert out.usage_metadata.total_token_count == 15 + assert out.usage_metadata is not None + assert out.usage_metadata.prompt_token_count == 10 + assert out.usage_metadata.candidates_token_count == 5 + assert out.usage_metadata.total_token_count == 15 # @@ -230,36 +226,35 @@ def test_to_llm_response_usage_metadata(): # ----------------------------------- # + @pytest.mark.asyncio async def test_generate_content_async_basic(monkeypatch): - resp = mock_response_ok("Hello!") + resp = mock_response_ok("Hello!") - async def fake_thread(fn, *args): - return resp + async def fake_thread(fn, *args): + return resp - monkeypatch.setattr("asyncio.to_thread", fake_thread) + monkeypatch.setattr("asyncio.to_thread", fake_thread) - o = Ollama(model="ollama/mistral") - req = LlmRequest(contents=[ - types.Content(role="user", parts=[types.Part.from_text("Hi")]) - ]) + o = Ollama(model="ollama/mistral") + req = LlmRequest(contents=[types.Content(role="user", parts=[types.Part.from_text("Hi")])]) - results = [r async for r in o.generate_content_async(req)] - assert results[0].content.parts[0].text == "Hello!" + results = [r async for r in o.generate_content_async(req)] + assert results[0].content.parts[0].text == "Hello!" 
@pytest.mark.asyncio async def test_generate_content_async_error(monkeypatch): - async def fake_thread(fn, *args): - raise RuntimeError("boom") + async def fake_thread(fn, *args): + raise RuntimeError("boom") - monkeypatch.setattr("asyncio.to_thread", fake_thread) + monkeypatch.setattr("asyncio.to_thread", fake_thread) - o = Ollama() - req = LlmRequest(contents=[types.Content(role="user", parts=[])]) + o = Ollama() + req = LlmRequest(contents=[types.Content(role="user", parts=[])]) - results = [r async for r in o.generate_content_async(req)] - assert results[0].error_code == "OLLAMA_ERROR" + results = [r async for r in o.generate_content_async(req)] + assert results[0].error_code == "OLLAMA_ERROR" # @@ -268,23 +263,21 @@ async def fake_thread(fn, *args): # ----------------------------------- # + @pytest.mark.asyncio async def test_model_override(monkeypatch): - resp = mock_response_ok("Hello") - resp["model"] = "override" + resp = mock_response_ok("Hello") + resp["model"] = "override" - async def fake_thread(fn, *args): - payload = args[0] - assert payload["model"] == "override" # important - return resp + async def fake_thread(fn, *args): + payload = args[0] + assert payload["model"] == "override" # important + return resp - monkeypatch.setattr("asyncio.to_thread", fake_thread) + monkeypatch.setattr("asyncio.to_thread", fake_thread) - o = Ollama(model="default") - req = LlmRequest( - model="override", - contents=[types.Content(role="user", parts=[types.Part.from_text("X")])] - ) + o = Ollama(model="default") + req = LlmRequest(model="override", contents=[types.Content(role="user", parts=[types.Part.from_text("X")])]) - out = [r async for r in o.generate_content_async(req)][0] - assert out.model_version == "override" \ No newline at end of file + out = [r async for r in o.generate_content_async(req)][0] + assert out.model_version == "override" From b909b566c901a3ba2e1c7e966941ff3be0b28d34 Mon Sep 17 00:00:00 2001 From: Ayman Hamed Date: Wed, 19 Nov 2025 08:00:35 +0200 Subject: [PATCH 4/9] Fix formatting and imports for CI --- .../hello_world_ollama_native/agent.py | 2 +- src/google/adk/models/ollama_llm.py | 32 +++++++++++++++++-- tests/unittests/models/test_ollama.py | 15 +++++++-- 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/contributing/samples/hello_world_ollama_native/agent.py b/contributing/samples/hello_world_ollama_native/agent.py index fa4e0bd1d4..ff0584ad2b 100755 --- a/contributing/samples/hello_world_ollama_native/agent.py +++ b/contributing/samples/hello_world_ollama_native/agent.py @@ -1,4 +1,4 @@ -# Copyright 2025 Ayman Hamed +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/google/adk/models/ollama_llm.py b/src/google/adk/models/ollama_llm.py index 8bae0aa7ff..ea46e2bd69 100644 --- a/src/google/adk/models/ollama_llm.py +++ b/src/google/adk/models/ollama_llm.py @@ -1,4 +1,4 @@ -# Copyright 2025 Ayman Hamed +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -128,10 +128,11 @@ def _convert_messages(self, llm_request: LlmRequest) -> list[dict[str, str]]: messages: list[dict[str, str]] = [] # System instruction → first system message. 
- if llm_request.config.system_instruction: + system_instruction = llm_request.config.system_instruction + if system_instruction: messages.append({ "role": "system", - "content": llm_request.config.system_instruction, + "content": self._system_instruction_to_text(system_instruction), }) # User / assistant / tool messages. @@ -144,6 +145,31 @@ def _convert_messages(self, llm_request: LlmRequest) -> list[dict[str, str]]: return messages + def _system_instruction_to_text(self, system_instruction: Any) -> str: + """Normalize `system_instruction` into plain text. + + It may be: + * a plain string + * a types.Content object + * a list/tuple of types.Content and/or strings + """ + # Single Content object + if isinstance(system_instruction, types.Content): + return self._content_to_text(system_instruction) + + # Sequence of items (e.g. list[Content]) + if isinstance(system_instruction, (list, tuple)): + pieces: list[str] = [] + for item in system_instruction: + if isinstance(item, types.Content): + pieces.append(self._content_to_text(item)) + elif item is not None: + pieces.append(str(item)) + return "\n".join(pieces) + + # Fallback: assume it's already string-like + return str(system_instruction) + def _content_to_text(self, content: types.Content) -> str: """Flatten a `Content` into plain text for Ollama. diff --git a/tests/unittests/models/test_ollama.py b/tests/unittests/models/test_ollama.py index 70a7e2dae4..a2c032fba4 100644 --- a/tests/unittests/models/test_ollama.py +++ b/tests/unittests/models/test_ollama.py @@ -1,5 +1,16 @@ -# Copyright 2025 -# Tests for native Ollama integration for Google ADK. +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
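The `_system_instruction_to_text` helper added to `ollama_llm.py` above accepts a plain string, a single `types.Content`, or a list mixing both. A minimal sketch of the normalization, assuming the newline-joining behaviour shown in that hunk:

```python
from google.adk.models.ollama_llm import Ollama
from google.genai import types

o = Ollama()
system_instruction = [
    types.Content(
        role="system", parts=[types.Part.from_text(text="You are terse.")]
    ),
    "Answer in English.",
]

# Content items are flattened via _content_to_text and plain strings are kept
# as-is, so this prints "You are terse.\nAnswer in English.".
print(o._system_instruction_to_text(system_instruction))
```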
import json from unittest.mock import AsyncMock From f0b3f98ecfb598715a553d229db3ebd577b7bb7f Mon Sep 17 00:00:00 2001 From: Ayman Hamed Date: Fri, 21 Nov 2025 06:57:16 +0200 Subject: [PATCH 5/9] Fix hello_world_ollama_native/agent.py formatting --- contributing/samples/hello_world_ollama_native/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contributing/samples/hello_world_ollama_native/agent.py b/contributing/samples/hello_world_ollama_native/agent.py index ff0584ad2b..75639e0fbd 100755 --- a/contributing/samples/hello_world_ollama_native/agent.py +++ b/contributing/samples/hello_world_ollama_native/agent.py @@ -59,7 +59,7 @@ def check_prime(numbers: list[int]) -> str: root_agent = Agent( - model=Ollama(model='llama3.1'), + model=Ollama(model="llama3.1"), name="dice_roll_agent", description=( "hello world agent that can roll a dice of any number of sides and" From 6ceae6da26a1bb398d7db0418c4abeec51e3eac3 Mon Sep 17 00:00:00 2001 From: Ayman Hamed Date: Wed, 3 Dec 2025 23:04:36 +0200 Subject: [PATCH 6/9] Fix formatting for CI --- src/google/adk/models/ollama_llm.py | 59 +++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/src/google/adk/models/ollama_llm.py b/src/google/adk/models/ollama_llm.py index ea46e2bd69..063e580316 100644 --- a/src/google/adk/models/ollama_llm.py +++ b/src/google/adk/models/ollama_llm.py @@ -76,7 +76,9 @@ async def generate_content_async( self, llm_request: LlmRequest, stream: bool = False ) -> AsyncGenerator[LlmResponse, None]: if stream: - logger.warning("Streaming is not yet supported for Ollama; falling back to unary.") + logger.warning( + "Streaming is not yet supported for Ollama; falling back to unary." + ) # Ensure last user content is appended if needed (BaseLlm helper). self._maybe_append_user_content(llm_request) @@ -89,7 +91,9 @@ async def generate_content_async( yield LlmResponse(error_code="OLLAMA_ERROR", error_message=str(exc)) return - llm_response = self._to_llm_response(response_json, request_model=llm_request.model) + llm_response = self._to_llm_response( + response_json, request_model=llm_request.model + ) yield llm_response # --------------------------------------------------------------------------- @@ -119,7 +123,9 @@ def _extract_model_name(self, request_model: Optional[str]) -> str: * "llama3.2" → "llama3.2" """ model_name = request_model or self.model - if model_name.startswith("ollama/") or model_name.startswith("ollama_chat/"): + if model_name.startswith("ollama/") or model_name.startswith( + "ollama_chat/" + ): return model_name.split("/", 1)[1] return model_name @@ -186,10 +192,15 @@ def _content_to_text(self, content: types.Content) -> str: elif part.function_response: # Tool result from a previous call. try: - response_json = json.dumps(part.function_response.response, ensure_ascii=False) + response_json = json.dumps( + part.function_response.response, ensure_ascii=False + ) except TypeError: response_json = str(part.function_response.response) - text_parts.append(f"[tool_response name={part.function_response.name or ''}] {response_json}") + text_parts.append( + f"[tool_response name={part.function_response.name or ''}]" + f" {response_json}" + ) elif part.function_call: # A model-issued tool call (arguments as JSON). 
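Both `_content_to_text` hunks here concern how tool activity is flattened into bracketed text markers for the Ollama chat payload. A small illustrative sketch of the markers produced, mirroring the expectations asserted in the tests later in this series (exact JSON key order may vary):

```python
from google.adk.models.ollama_llm import Ollama
from google.genai import types

o = Ollama()
call = types.Part.from_function_call(name="add", args={"x": 1, "y": 2})
result = types.Part.from_function_response(name="add", response={"z": 5})

# Prints lines of the form:
#   [tool_call name=add] {"x": 1, "y": 2}
#   [tool_response name=add] {"z": 5}
print(o._content_to_text(types.Content(role="assistant", parts=[call])))
print(o._content_to_text(types.Content(role="tool", parts=[result])))
```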
@@ -197,10 +208,14 @@ def _content_to_text(self, content: types.Content) -> str: args_json = json.dumps(part.function_call.args, ensure_ascii=False) except TypeError: args_json = str(part.function_call.args) - text_parts.append(f"[tool_call name={part.function_call.name}] {args_json}") + text_parts.append( + f"[tool_call name={part.function_call.name}] {args_json}" + ) else: - logger.debug("Skipping unsupported content part for Ollama message: %s", part) + logger.debug( + "Skipping unsupported content part for Ollama message: %s", part + ) return "\n".join(text_parts) @@ -231,13 +246,17 @@ def _convert_tools(self, llm_request: LlmRequest) -> list[dict[str, Any]]: "function": { "name": function_declaration.name, "description": function_declaration.description or "", - "parameters": self._function_parameters_to_json(function_declaration), + "parameters": self._function_parameters_to_json( + function_declaration + ), }, }) return tools_spec - def _function_parameters_to_json(self, function_declaration: types.FunctionDeclaration) -> dict[str, Any]: + def _function_parameters_to_json( + self, function_declaration: types.FunctionDeclaration + ) -> dict[str, Any]: """Convert function parameters Schema → JSON Schema for Ollama.""" if function_declaration.parameters is None: return {"type": "object", "properties": {}} @@ -247,10 +266,13 @@ def _function_parameters_to_json(self, function_declaration: types.FunctionDecla except AttributeError: # model_dump is not guaranteed depending on the genai version. try: - return json.loads(function_declaration.parameters.model_dump_json(exclude_none=True)) + return json.loads( + function_declaration.parameters.model_dump_json(exclude_none=True) + ) except (AttributeError, json.JSONDecodeError, TypeError) as exc: logger.debug( - "Failed to convert function parameters, defaulting to empty schema: %s", + "Failed to convert function parameters, defaulting to empty" + " schema: %s", exc, ) return {"type": "object", "properties": {}} @@ -291,7 +313,9 @@ def _post_chat(self, payload: dict[str, Any]) -> dict[str, Any]: ) try: - with urllib.request.urlopen(request, timeout=self.request_timeout) as response: + with urllib.request.urlopen( + request, timeout=self.request_timeout + ) as response: response_body = response.read().decode("utf-8") except urllib.error.URLError as exc: raise RuntimeError(exc.reason) from exc @@ -329,14 +353,17 @@ def _to_llm_response( for tool_call in message.get("tool_calls", []): function_payload = tool_call.get("function", {}) or {} name = function_payload.get("name") - arguments: Union[str, dict[str, Any], None] = function_payload.get("arguments") + arguments: Union[str, dict[str, Any], None] = function_payload.get( + "arguments" + ) if isinstance(arguments, str): try: arguments = json.loads(arguments) except json.JSONDecodeError: logger.warning( - "Failed to parse tool call arguments as JSON: %s. Defaulting to empty arguments.", + "Failed to parse tool call arguments as JSON: %s. Defaulting to" + " empty arguments.", arguments, ) arguments = {} @@ -383,7 +410,9 @@ def _to_llm_response( ) # 4) Model version: prefer Ollama's `model`, fallback to request model. 
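For readers tracing `_to_llm_response` through these hunks, a raw `/api/chat` payload maps onto `LlmResponse` roughly as sketched below. This is a hedged example built from the field names asserted in the tests (`prompt_eval_count`, `eval_count`) and from the rule that the payload's `model` wins over the requested model name:

```python
from google.adk.models.ollama_llm import Ollama

raw = {
    "model": "llama3.1",
    "message": {"content": "Hi"},
    "prompt_eval_count": 10,  # -> usage_metadata.prompt_token_count
    "eval_count": 5,  # -> usage_metadata.candidates_token_count
}

out = Ollama()._to_llm_response(raw)
assert out.content.parts[0].text == "Hi"
assert out.model_version == "llama3.1"
assert out.usage_metadata.total_token_count == 15  # prompt + candidates
```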
- model_version = response_json.get("model") or self._extract_model_name(request_model or self.model) + model_version = response_json.get("model") or self._extract_model_name( + request_model or self.model + ) return LlmResponse( content=types.Content(role="model", parts=parts), From d78caa70358ec7572bc3adb0661adff76aad60eb Mon Sep 17 00:00:00 2001 From: Ayman Hamed Date: Tue, 9 Dec 2025 06:46:50 +0200 Subject: [PATCH 7/9] refactor(tests): clean and reorganize Ollama test suite for readability, consistency, and ADK compliance --- tests/unittests/models/test_ollama.py | 376 +++++++++++++------------- 1 file changed, 193 insertions(+), 183 deletions(-) diff --git a/tests/unittests/models/test_ollama.py b/tests/unittests/models/test_ollama.py index a2c032fba4..842444c7ef 100644 --- a/tests/unittests/models/test_ollama.py +++ b/tests/unittests/models/test_ollama.py @@ -13,282 +13,292 @@ # limitations under the License. import json -from unittest.mock import AsyncMock -from unittest.mock import Mock +import pytest +from google.genai import types from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.models.ollama_llm import Ollama -from google.genai import types -import pytest -# -# ----------------------------------- -# Helpers -# ----------------------------------- -# +# ===================================================== +# Helpers +# ===================================================== def mock_response_ok(text="Hello world", tool_calls=None): - """Create a typical Ollama /api/chat response.""" - message = {"content": text} - if tool_calls: - message["tool_calls"] = tool_calls - return {"message": message} + """Create a minimal valid Ollama /api/chat response.""" + message = {"content": text} + if tool_calls: + message["tool_calls"] = tool_calls + return {"message": message} -# -# ----------------------------------- -# Test: model extraction -# ----------------------------------- -# - +# ===================================================== +# Model extraction +# ===================================================== def test_extract_model_name_basic(): - o = Ollama(model="ollama/mistral") - assert o._extract_model_name("ollama/mistral") == "mistral" + o = Ollama(model="ollama/mistral") + assert o._extract_model_name("ollama/mistral") == "mistral" def test_extract_model_name_chat_prefix(): - o = Ollama(model="ollama_chat/llama3.1") - assert o._extract_model_name("ollama_chat/llama3.1") == "llama3.1" + o = Ollama(model="ollama_chat/llama3.1") + assert o._extract_model_name("ollama_chat/llama3.1") == "llama3.1" def test_extract_model_name_no_prefix(): - o = Ollama(model="mistral") - assert o._extract_model_name("mistral") == "mistral" + o = Ollama(model="mistral") + assert o._extract_model_name("mistral") == "mistral" -# -# ----------------------------------- -# Test: message conversion -# ----------------------------------- -# - +# ===================================================== +# Message conversion +# ===================================================== def test_convert_messages_basic(): - o = Ollama() - req = LlmRequest(contents=[types.Content(role="user", parts=[types.Part.from_text("Hi")])]) - msgs = o._convert_messages(req) - assert msgs[0]["role"] == "user" - assert msgs[0]["content"] == "Hi" + o = Ollama() + req = LlmRequest( + contents=[types.Content(role="user", + parts=[types.Part.from_text(text="Hi")])] + ) + msgs = o._convert_messages(req) + assert msgs[0]["role"] == "user" + assert msgs[0]["content"] 
== "Hi" def test_convert_messages_with_system(): - o = Ollama() - req = LlmRequest( - contents=[types.Content(role="user", parts=[types.Part.from_text("X")])], - config=types.GenerateContentConfig(system_instruction="SYS"), - ) - msgs = o._convert_messages(req) - assert msgs[0]["role"] == "system" - assert msgs[0]["content"] == "SYS" - assert msgs[1]["content"] == "X" + o = Ollama() + req = LlmRequest( + contents=[types.Content(role="user", + parts=[types.Part.from_text(text="X")])], + config=types.GenerateContentConfig(system_instruction="SYS"), + ) + msgs = o._convert_messages(req) + assert msgs[0]["role"] == "system" + assert msgs[0]["content"] == "SYS" + assert msgs[1]["content"] == "X" -# -# ----------------------------------- -# Test: content → text -# ----------------------------------- -# +# ===================================================== +# Content → text extraction +# ===================================================== def test_content_to_text_basic(): - o = Ollama() - content = types.Content(role="user", parts=[types.Part.from_text("ABC")]) - assert o._content_to_text(content) == "ABC" + o = Ollama() + content = types.Content(role="user", + parts=[types.Part.from_text(text="ABC")]) + assert o._content_to_text(content) == "ABC" def test_content_to_text_function_call(): - o = Ollama() - part = types.Part.from_function_call(name="add", args={"x": 1, "y": 2}) - part.function_call.id = "call123" + o = Ollama() + part = types.Part.from_function_call(name="add", + args={"x": 1, "y": 2}) + part.function_call.id = "call123" - content = types.Content(role="assistant", parts=[part]) - txt = o._content_to_text(content) + content = types.Content(role="assistant", parts=[part]) + txt = o._content_to_text(content) - assert "[tool_call name=add]" in txt - assert '"x": 1' in txt + assert "[tool_call name=add]" in txt + assert '"x": 1' in txt def test_content_to_text_tool_response(): - o = Ollama() - part = types.Part.from_function_response(name="add", response={"z": 5}) - content = types.Content(role="tool", parts=[part]) - txt = o._content_to_text(content) - - assert "[tool_response name=add]" in txt - assert '"z": 5' in txt + o = Ollama() + part = types.Part.from_function_response(name="add", + response={"z": 5}) + content = types.Content(role="tool", parts=[part]) + txt = o._content_to_text(content) + assert "[tool_response name=add]" in txt + assert '"z": 5' in txt -# -# ----------------------------------- -# Test: tool conversion -# ----------------------------------- -# +# ===================================================== +# Tools conversion +# ===================================================== def test_convert_tools_basic(): - o = Ollama() - req = LlmRequest( - config=types.GenerateContentConfig( - tools=[ - types.Tool( - function_declarations=[ - types.FunctionDeclaration( - name="add", - description="Add numbers", - parameters=types.Schema( - type=types.Type.OBJECT, properties={"x": types.Schema(type=types.Type.NUMBER)} - ), - ) - ] - ) - ] - ) - ) - tools = o._convert_tools(req) - assert tools[0]["function"]["name"] == "add" - assert tools[0]["function"]["parameters"]["type"] == "object" - - -# -# ----------------------------------- -# Test: POST wrapper -# ----------------------------------- -# - + o = Ollama() + req = LlmRequest( + config=types.GenerateContentConfig( + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="add", + description="Add numbers", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "x": 
types.Schema(type=types.Type.NUMBER) + }, + ), + ) + ] + ) + ] + ) + ) + + tools = o._convert_tools(req) + + assert tools[0]["function"]["name"] == "add" + assert tools[0]["function"]["parameters"]["type"] == types.Type.OBJECT + + +# ===================================================== +# POST wrapper +# ===================================================== def test_post_chat_success(monkeypatch): - fake_response = {"message": {"content": "OK"}} + fake_response = {"message": {"content": "OK"}} - def fake_urlopen(req, timeout=0): - class Resp: + def fake_urlopen(req, timeout=0): + class Resp: + def __enter__(self): return self + def __exit__(self, *_): return False + def read(self): return json.dumps(fake_response).encode("utf-8") - def read(self): - return json.dumps(fake_response).encode("utf-8") + return Resp() - return Resp() + monkeypatch.setattr("urllib.request.urlopen", fake_urlopen) - monkeypatch.setattr("urllib.request.urlopen", fake_urlopen) + o = Ollama() + resp = o._post_chat({"model": "x"}) - o = Ollama() - resp = o._post_chat({"model": "x"}) - assert resp["message"]["content"] == "OK" + assert resp["message"]["content"] == "OK" -# -# ----------------------------------- -# Test: _to_llm_response -# ----------------------------------- -# - +# ===================================================== +# LlmResponse conversion +# ===================================================== def test_to_llm_response_text(): - o = Ollama() - resp = mock_response_ok("Hi") + o = Ollama() + resp = mock_response_ok("Hi") + + out = o._to_llm_response(resp) - out = o._to_llm_response(resp) - assert isinstance(out, LlmResponse) - assert out.content.parts[0].text == "Hi" + assert isinstance(out, LlmResponse) + assert out.content.parts[0].text == "Hi" def test_to_llm_response_tool_call(): - o = Ollama() - tool_call = {"id": "abc", "function": {"name": "add", "arguments": '{"x": 1}'}} + o = Ollama() + tool_call = { + "id": "abc", + "function": { + "name": "add", + "arguments": '{"x": 1}' + }, + } - resp = mock_response_ok(tool_calls=[tool_call]) - out = o._to_llm_response(resp) + resp = mock_response_ok(tool_calls=[tool_call]) + out = o._to_llm_response(resp) - fc = out.content.parts[0].function_call - assert fc.name == "add" - assert fc.args == {"x": 1} - assert fc.id == "abc" + fc = next(p.function_call for p in out.content.parts if p.function_call) + + assert fc.name == "add" + assert fc.args == {"x": 1} + assert fc.id == "abc" def test_to_llm_response_tool_call_bad_json(): - o = Ollama() - tool_call = {"id": "zzz", "function": {"name": "add", "arguments": "{BAD_JSON"}} + o = Ollama() + tool_call = { + "id": "zzz", + "function": { + "name": "add", + "arguments": "{BAD_JSON" + }, + } - resp = mock_response_ok(tool_calls=[tool_call]) - out = o._to_llm_response(resp) + resp = mock_response_ok(tool_calls=[tool_call]) + out = o._to_llm_response(resp) - fc = out.content.parts[0].function_call - assert fc.args == {} # BAD JSON → fallback to {} + fc = next(p.function_call for p in out.content.parts if p.function_call) + assert fc.args == {} # fallback -def test_to_llm_response_usage_metadata(): - o = Ollama() - resp = mock_response_ok("Hi") - resp["prompt_eval_count"] = 10 - resp["eval_count"] = 5 - out = o._to_llm_response(resp) +def test_to_llm_response_usage_metadata(): + o = Ollama() + resp = mock_response_ok("Hi") + resp["prompt_eval_count"] = 10 + resp["eval_count"] = 5 - assert out.usage_metadata is not None - assert out.usage_metadata.prompt_token_count == 10 - assert 
out.usage_metadata.candidates_token_count == 5 - assert out.usage_metadata.total_token_count == 15 + out = o._to_llm_response(resp) + assert out.usage_metadata.prompt_token_count == 10 + assert out.usage_metadata.candidates_token_count == 5 + assert out.usage_metadata.total_token_count == 15 -# -# ----------------------------------- -# async: generate_content_async -# ----------------------------------- -# +# ===================================================== +# Async generate_content_async() +# ===================================================== @pytest.mark.asyncio async def test_generate_content_async_basic(monkeypatch): - resp = mock_response_ok("Hello!") + resp = mock_response_ok("Hello!") - async def fake_thread(fn, *args): - return resp + async def fake_thread(fn, *args): + return resp - monkeypatch.setattr("asyncio.to_thread", fake_thread) + monkeypatch.setattr("asyncio.to_thread", fake_thread) - o = Ollama(model="ollama/mistral") - req = LlmRequest(contents=[types.Content(role="user", parts=[types.Part.from_text("Hi")])]) + o = Ollama(model="ollama/mistral") + req = LlmRequest( + contents=[types.Content(role="user", + parts=[types.Part.from_text(text="Hi")])] + ) - results = [r async for r in o.generate_content_async(req)] - assert results[0].content.parts[0].text == "Hello!" + results = [r async for r in o.generate_content_async(req)] + + assert results[0].content.parts[0].text == "Hello!" @pytest.mark.asyncio async def test_generate_content_async_error(monkeypatch): - async def fake_thread(fn, *args): - raise RuntimeError("boom") + async def fake_thread(fn, *args): + raise RuntimeError("boom") - monkeypatch.setattr("asyncio.to_thread", fake_thread) + monkeypatch.setattr("asyncio.to_thread", fake_thread) - o = Ollama() - req = LlmRequest(contents=[types.Content(role="user", parts=[])]) + o = Ollama() + req = LlmRequest(contents=[types.Content(role="user", parts=[])]) - results = [r async for r in o.generate_content_async(req)] - assert results[0].error_code == "OLLAMA_ERROR" + results = [r async for r in o.generate_content_async(req)] + assert results[0].error_code == "OLLAMA_ERROR" -# -# ----------------------------------- -# Test: model override -# ----------------------------------- -# +# ===================================================== +# Model override +# ===================================================== @pytest.mark.asyncio async def test_model_override(monkeypatch): - resp = mock_response_ok("Hello") - resp["model"] = "override" + resp = mock_response_ok("Hello") + resp["model"] = "override" + + async def fake_thread(fn, *args): + payload = args[0] + assert payload["model"] == "override" + return resp - async def fake_thread(fn, *args): - payload = args[0] - assert payload["model"] == "override" # important - return resp + monkeypatch.setattr("asyncio.to_thread", fake_thread) - monkeypatch.setattr("asyncio.to_thread", fake_thread) + o = Ollama(model="default") + req = LlmRequest( + model="override", + contents=[types.Content(role="user", + parts=[types.Part.from_text(text="X")])] + ) - o = Ollama(model="default") - req = LlmRequest(model="override", contents=[types.Content(role="user", parts=[types.Part.from_text("X")])]) + out = [r async for r in o.generate_content_async(req)][0] - out = [r async for r in o.generate_content_async(req)][0] - assert out.model_version == "override" + assert out.model_version == "override" \ No newline at end of file From 5693b12cb84b67c0b2b556243955b580080972f9 Mon Sep 17 00:00:00 2001 From: Ayman Hamed Date: Thu, 18 Dec 2025 
06:06:23 +0200 Subject: [PATCH 8/9] fix imports and formatting --- tests/unittests/models/test_ollama.py | 329 +++++++++++++------------- 1 file changed, 169 insertions(+), 160 deletions(-) diff --git a/tests/unittests/models/test_ollama.py b/tests/unittests/models/test_ollama.py index 842444c7ef..537424bbb3 100644 --- a/tests/unittests/models/test_ollama.py +++ b/tests/unittests/models/test_ollama.py @@ -13,292 +13,301 @@ # limitations under the License. import json -import pytest -from google.genai import types from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.models.ollama_llm import Ollama - +from google.genai import types +import pytest # ===================================================== # Helpers # ===================================================== + def mock_response_ok(text="Hello world", tool_calls=None): - """Create a minimal valid Ollama /api/chat response.""" - message = {"content": text} - if tool_calls: - message["tool_calls"] = tool_calls - return {"message": message} + """Create a minimal valid Ollama /api/chat response.""" + message = {"content": text} + if tool_calls: + message["tool_calls"] = tool_calls + return {"message": message} # ===================================================== # Model extraction # ===================================================== + def test_extract_model_name_basic(): - o = Ollama(model="ollama/mistral") - assert o._extract_model_name("ollama/mistral") == "mistral" + o = Ollama(model="ollama/mistral") + assert o._extract_model_name("ollama/mistral") == "mistral" def test_extract_model_name_chat_prefix(): - o = Ollama(model="ollama_chat/llama3.1") - assert o._extract_model_name("ollama_chat/llama3.1") == "llama3.1" + o = Ollama(model="ollama_chat/llama3.1") + assert o._extract_model_name("ollama_chat/llama3.1") == "llama3.1" def test_extract_model_name_no_prefix(): - o = Ollama(model="mistral") - assert o._extract_model_name("mistral") == "mistral" + o = Ollama(model="mistral") + assert o._extract_model_name("mistral") == "mistral" # ===================================================== # Message conversion # ===================================================== + def test_convert_messages_basic(): - o = Ollama() - req = LlmRequest( - contents=[types.Content(role="user", - parts=[types.Part.from_text(text="Hi")])] - ) - msgs = o._convert_messages(req) - assert msgs[0]["role"] == "user" - assert msgs[0]["content"] == "Hi" + o = Ollama() + req = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hi")]) + ] + ) + msgs = o._convert_messages(req) + assert msgs[0]["role"] == "user" + assert msgs[0]["content"] == "Hi" def test_convert_messages_with_system(): - o = Ollama() - req = LlmRequest( - contents=[types.Content(role="user", - parts=[types.Part.from_text(text="X")])], - config=types.GenerateContentConfig(system_instruction="SYS"), - ) - msgs = o._convert_messages(req) + o = Ollama() + req = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="X")]) + ], + config=types.GenerateContentConfig(system_instruction="SYS"), + ) + msgs = o._convert_messages(req) - assert msgs[0]["role"] == "system" - assert msgs[0]["content"] == "SYS" - assert msgs[1]["content"] == "X" + assert msgs[0]["role"] == "system" + assert msgs[0]["content"] == "SYS" + assert msgs[1]["content"] == "X" # ===================================================== # Content → text extraction # 
===================================================== + def test_content_to_text_basic(): - o = Ollama() - content = types.Content(role="user", - parts=[types.Part.from_text(text="ABC")]) - assert o._content_to_text(content) == "ABC" + o = Ollama() + content = types.Content(role="user", parts=[types.Part.from_text(text="ABC")]) + assert o._content_to_text(content) == "ABC" def test_content_to_text_function_call(): - o = Ollama() - part = types.Part.from_function_call(name="add", - args={"x": 1, "y": 2}) - part.function_call.id = "call123" + o = Ollama() + part = types.Part.from_function_call(name="add", args={"x": 1, "y": 2}) + part.function_call.id = "call123" - content = types.Content(role="assistant", parts=[part]) - txt = o._content_to_text(content) + content = types.Content(role="assistant", parts=[part]) + txt = o._content_to_text(content) - assert "[tool_call name=add]" in txt - assert '"x": 1' in txt + assert "[tool_call name=add]" in txt + assert '"x": 1' in txt def test_content_to_text_tool_response(): - o = Ollama() - part = types.Part.from_function_response(name="add", - response={"z": 5}) - content = types.Content(role="tool", parts=[part]) - txt = o._content_to_text(content) + o = Ollama() + part = types.Part.from_function_response(name="add", response={"z": 5}) + content = types.Content(role="tool", parts=[part]) + txt = o._content_to_text(content) - assert "[tool_response name=add]" in txt - assert '"z": 5' in txt + assert "[tool_response name=add]" in txt + assert '"z": 5' in txt # ===================================================== # Tools conversion # ===================================================== + def test_convert_tools_basic(): - o = Ollama() - req = LlmRequest( - config=types.GenerateContentConfig( - tools=[ - types.Tool( - function_declarations=[ - types.FunctionDeclaration( - name="add", - description="Add numbers", - parameters=types.Schema( - type=types.Type.OBJECT, - properties={ - "x": types.Schema(type=types.Type.NUMBER) - }, - ), - ) - ] - ) - ] - ) - ) - - tools = o._convert_tools(req) - - assert tools[0]["function"]["name"] == "add" - assert tools[0]["function"]["parameters"]["type"] == types.Type.OBJECT + o = Ollama() + req = LlmRequest( + config=types.GenerateContentConfig( + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="add", + description="Add numbers", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "x": types.Schema(type=types.Type.NUMBER) + }, + ), + ) + ] + ) + ] + ) + ) + + tools = o._convert_tools(req) + + assert tools[0]["function"]["name"] == "add" + assert tools[0]["function"]["parameters"]["type"] == types.Type.OBJECT # ===================================================== # POST wrapper # ===================================================== + def test_post_chat_success(monkeypatch): - fake_response = {"message": {"content": "OK"}} + fake_response = {"message": {"content": "OK"}} + + def fake_urlopen(req, timeout=0): + class Resp: + + def __enter__(self): + return self - def fake_urlopen(req, timeout=0): - class Resp: - def __enter__(self): return self - def __exit__(self, *_): return False - def read(self): return json.dumps(fake_response).encode("utf-8") + def __exit__(self, *_): + return False - return Resp() + def read(self): + return json.dumps(fake_response).encode("utf-8") - monkeypatch.setattr("urllib.request.urlopen", fake_urlopen) + return Resp() - o = Ollama() - resp = o._post_chat({"model": "x"}) + monkeypatch.setattr("urllib.request.urlopen", fake_urlopen) - 
assert resp["message"]["content"] == "OK" + o = Ollama() + resp = o._post_chat({"model": "x"}) + + assert resp["message"]["content"] == "OK" # ===================================================== # LlmResponse conversion # ===================================================== + def test_to_llm_response_text(): - o = Ollama() - resp = mock_response_ok("Hi") + o = Ollama() + resp = mock_response_ok("Hi") - out = o._to_llm_response(resp) + out = o._to_llm_response(resp) - assert isinstance(out, LlmResponse) - assert out.content.parts[0].text == "Hi" + assert isinstance(out, LlmResponse) + assert out.content.parts[0].text == "Hi" def test_to_llm_response_tool_call(): - o = Ollama() - tool_call = { - "id": "abc", - "function": { - "name": "add", - "arguments": '{"x": 1}' - }, - } + o = Ollama() + tool_call = { + "id": "abc", + "function": {"name": "add", "arguments": '{"x": 1}'}, + } - resp = mock_response_ok(tool_calls=[tool_call]) - out = o._to_llm_response(resp) + resp = mock_response_ok(tool_calls=[tool_call]) + out = o._to_llm_response(resp) - fc = next(p.function_call for p in out.content.parts if p.function_call) + fc = next(p.function_call for p in out.content.parts if p.function_call) - assert fc.name == "add" - assert fc.args == {"x": 1} - assert fc.id == "abc" + assert fc.name == "add" + assert fc.args == {"x": 1} + assert fc.id == "abc" def test_to_llm_response_tool_call_bad_json(): - o = Ollama() - tool_call = { - "id": "zzz", - "function": { - "name": "add", - "arguments": "{BAD_JSON" - }, - } + o = Ollama() + tool_call = { + "id": "zzz", + "function": {"name": "add", "arguments": "{BAD_JSON"}, + } - resp = mock_response_ok(tool_calls=[tool_call]) - out = o._to_llm_response(resp) + resp = mock_response_ok(tool_calls=[tool_call]) + out = o._to_llm_response(resp) - fc = next(p.function_call for p in out.content.parts if p.function_call) + fc = next(p.function_call for p in out.content.parts if p.function_call) - assert fc.args == {} # fallback + assert fc.args == {} # fallback def test_to_llm_response_usage_metadata(): - o = Ollama() - resp = mock_response_ok("Hi") - resp["prompt_eval_count"] = 10 - resp["eval_count"] = 5 + o = Ollama() + resp = mock_response_ok("Hi") + resp["prompt_eval_count"] = 10 + resp["eval_count"] = 5 - out = o._to_llm_response(resp) + out = o._to_llm_response(resp) - assert out.usage_metadata.prompt_token_count == 10 - assert out.usage_metadata.candidates_token_count == 5 - assert out.usage_metadata.total_token_count == 15 + assert out.usage_metadata.prompt_token_count == 10 + assert out.usage_metadata.candidates_token_count == 5 + assert out.usage_metadata.total_token_count == 15 # ===================================================== # Async generate_content_async() # ===================================================== + @pytest.mark.asyncio async def test_generate_content_async_basic(monkeypatch): - resp = mock_response_ok("Hello!") + resp = mock_response_ok("Hello!") - async def fake_thread(fn, *args): - return resp + async def fake_thread(fn, *args): + return resp - monkeypatch.setattr("asyncio.to_thread", fake_thread) + monkeypatch.setattr("asyncio.to_thread", fake_thread) - o = Ollama(model="ollama/mistral") - req = LlmRequest( - contents=[types.Content(role="user", - parts=[types.Part.from_text(text="Hi")])] - ) + o = Ollama(model="ollama/mistral") + req = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hi")]) + ] + ) - results = [r async for r in o.generate_content_async(req)] + results = [r async for r in 
o.generate_content_async(req)] - assert results[0].content.parts[0].text == "Hello!" + assert results[0].content.parts[0].text == "Hello!" @pytest.mark.asyncio async def test_generate_content_async_error(monkeypatch): - async def fake_thread(fn, *args): - raise RuntimeError("boom") + async def fake_thread(fn, *args): + raise RuntimeError("boom") - monkeypatch.setattr("asyncio.to_thread", fake_thread) + monkeypatch.setattr("asyncio.to_thread", fake_thread) - o = Ollama() - req = LlmRequest(contents=[types.Content(role="user", parts=[])]) + o = Ollama() + req = LlmRequest(contents=[types.Content(role="user", parts=[])]) - results = [r async for r in o.generate_content_async(req)] + results = [r async for r in o.generate_content_async(req)] - assert results[0].error_code == "OLLAMA_ERROR" + assert results[0].error_code == "OLLAMA_ERROR" # ===================================================== # Model override # ===================================================== + @pytest.mark.asyncio async def test_model_override(monkeypatch): - resp = mock_response_ok("Hello") - resp["model"] = "override" + resp = mock_response_ok("Hello") + resp["model"] = "override" - async def fake_thread(fn, *args): - payload = args[0] - assert payload["model"] == "override" - return resp + async def fake_thread(fn, *args): + payload = args[0] + assert payload["model"] == "override" + return resp - monkeypatch.setattr("asyncio.to_thread", fake_thread) + monkeypatch.setattr("asyncio.to_thread", fake_thread) - o = Ollama(model="default") - req = LlmRequest( - model="override", - contents=[types.Content(role="user", - parts=[types.Part.from_text(text="X")])] - ) + o = Ollama(model="default") + req = LlmRequest( + model="override", + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="X")]) + ], + ) - out = [r async for r in o.generate_content_async(req)][0] + out = [r async for r in o.generate_content_async(req)][0] - assert out.model_version == "override" \ No newline at end of file + assert out.model_version == "override" From 93395f68a54b2f7493a4c49aa5e51f078ced3aa4 Mon Sep 17 00:00:00 2001 From: Ayman Hamed Date: Fri, 19 Dec 2025 10:15:14 +0200 Subject: [PATCH 9/9] Fix Ollama host configuration, examples, and documentation mismatches --- .../hello_world_ollama_native/README.md | 32 ++++++++--- .../hello_world_ollama_native/agent.py | 55 ++++++++++--------- src/google/adk/models/ollama_llm.py | 11 +++- 3 files changed, 64 insertions(+), 34 deletions(-) diff --git a/contributing/samples/hello_world_ollama_native/README.md b/contributing/samples/hello_world_ollama_native/README.md index dfe6bc4716..c0e2397344 100644 --- a/contributing/samples/hello_world_ollama_native/README.md +++ b/contributing/samples/hello_world_ollama_native/README.md @@ -44,23 +44,41 @@ No LiteLLM provider, API keys, or OpenAI proxy endpoints are needed. ```python import random from google.adk.agents.llm_agent import Agent -from google.adk.models.ollama import Ollama +from google.adk.models.ollama_llm import Ollama + def roll_die(sides: int) -> int: return random.randint(1, sides) -def check_prime(numbers: list[int]) -> str: - primes = [] + +def check_prime(numbers: list) -> str: + """Check if a given list of values contains prime numbers. + + The input may include non-integer values produced by the LLM. 
+ """ + primes = set() + for number in numbers: - number = int(number) + try: + number = int(number) + except (ValueError, TypeError): + continue + if number <= 1: continue - for i in range(2, int(number ** 0.5) + 1): + + for i in range(2, int(number**0.5) + 1): if number % i == 0: break else: - primes.append(number) - return "No prime numbers found." if not primes else f"{', '.join(map(str, primes))} are prime numbers." + primes.add(number) + + return ( + "No prime numbers found." + if not primes + else f"{', '.join(str(n) for n in sorted(primes))} are prime numbers." + ) + root_agent = Agent( model=Ollama(model="llama3.1"), diff --git a/contributing/samples/hello_world_ollama_native/agent.py b/contributing/samples/hello_world_ollama_native/agent.py index 75639e0fbd..c3c97ffcce 100755 --- a/contributing/samples/hello_world_ollama_native/agent.py +++ b/contributing/samples/hello_world_ollama_native/agent.py @@ -16,7 +16,7 @@ from google.adk.agents.llm_agent import Agent from google.adk.models.ollama_llm import Ollama - +from typing import Any def roll_die(sides: int) -> int: """Roll a die and return the rolled result. @@ -30,33 +30,38 @@ def roll_die(sides: int) -> int: return random.randint(1, sides) -def check_prime(numbers: list[int]) -> str: - """Check if a given list of numbers are prime. +def check_prime(numbers: list[Any]) -> str: + """Check which values in a list are prime numbers. - Args: - numbers: The list of numbers to check. + Args: + numbers: The list of values to check. Values may be non-integers + and are safely ignored if they cannot be converted. - Returns: - A str indicating which number is prime. - """ - primes = set() - for number in numbers: - number = int(number) - if number <= 1: - continue - is_prime = True - for i in range(2, int(number**0.5) + 1): - if number % i == 0: - is_prime = False - break - if is_prime: - primes.add(number) - return ( - "No prime numbers found." - if not primes - else f"{', '.join(str(num) for num in primes)} are prime numbers." - ) + Returns: + A string indicating which numbers are prime. + """ + primes = set() + + for number in numbers: + try: + number = int(number) + except (ValueError, TypeError): + continue # Safely skip non-numeric values + + if number <= 1: + continue + + for i in range(2, int(number ** 0.5) + 1): + if number % i == 0: + break + else: + primes.add(number) + return ( + "No prime numbers found." + if not primes + else f"{', '.join(str(num) for num in sorted(primes))} are prime numbers." + ) root_agent = Agent( model=Ollama(model="llama3.1"), diff --git a/src/google/adk/models/ollama_llm.py b/src/google/adk/models/ollama_llm.py index 063e580316..5159caf3e4 100644 --- a/src/google/adk/models/ollama_llm.py +++ b/src/google/adk/models/ollama_llm.py @@ -24,6 +24,7 @@ from typing import Union import urllib.error import urllib.request +import os from google.genai import types from pydantic import Field @@ -57,7 +58,7 @@ class Ollama(BaseLlm): model: str = "ollama/llama3.1" host: str = Field( - default="http://localhost:11434", + default=os.environ.get("OLLAMA_API_BASE", "http://localhost:11434"), description="Base URL of the Ollama server.", ) request_timeout: float = Field( @@ -302,7 +303,13 @@ def _convert_options(self, llm_request: LlmRequest) -> dict[str, Any]: # --------------------------------------------------------------------------- def _post_chat(self, payload: dict[str, Any]) -> dict[str, Any]: - """Perform a blocking POST /api/chat call to Ollama.""" + """Perform a blocking POST /api/chat call to Ollama. 
+ Note: This method is intentionally blocking and is executed via + asyncio.to_thread() to avoid introducing additional async HTTP + dependencies. This keeps the backend consistent with existing ADK + providers. + """ + url = self.host.rstrip("/") + _CHAT_ENDPOINT data = json.dumps(payload).encode("utf-8") request = urllib.request.Request(