diff --git a/contributing/samples/hello_world_ollama_native/README.md b/contributing/samples/hello_world_ollama_native/README.md
new file mode 100644
index 0000000000..c0e2397344
--- /dev/null
+++ b/contributing/samples/hello_world_ollama_native/README.md
@@ -0,0 +1,117 @@
+# Using Ollama Models with ADK (Native Integration)
+
+## Model Choice
+
+If your agent uses tools, choose an Ollama model that supports **function calling**.
+Tool support can be verified with:
+
+```bash
+ollama show mistral-small3.1
+```
+
+Example output:
+
+```
+  Model
+    architecture        mistral3
+    parameters          24.0B
+    context length      131072
+    embedding length    5120
+    quantization        Q4_K_M
+
+  Capabilities
+    completion
+    vision
+    tools
+```
+
+The model must list `tools` under **Capabilities**; models without tool support
+will not execute ADK functions correctly.
+
+To inspect or customize a model template:
+
+```bash
+ollama show --modelfile llama3.1 > model_file_to_modify
+```
+
+Then create a modified model:
+
+```bash
+ollama create llama3.1-modified -f model_file_to_modify
+```
+
+## Native Ollama Provider in ADK
+
+ADK includes a native Ollama model class that communicates directly with the Ollama server at:
+
+```
+http://localhost:11434/api/chat
+```
+
+No LiteLLM provider, API keys, or OpenAI proxy endpoints are needed.
+
+### Example agent
+
+```python
+import random
+
+from google.adk.agents.llm_agent import Agent
+from google.adk.models.ollama_llm import Ollama
+
+
+def roll_die(sides: int) -> int:
+  """Roll a die with the given number of sides and return the result."""
+  return random.randint(1, sides)
+
+
+def check_prime(numbers: list) -> str:
+  """Check if a given list of values contains prime numbers.
+
+  The input may include non-integer values produced by the LLM.
+  """
+  primes = set()
+
+  for number in numbers:
+    try:
+      number = int(number)
+    except (ValueError, TypeError):
+      continue
+
+    if number <= 1:
+      continue
+
+    for i in range(2, int(number**0.5) + 1):
+      if number % i == 0:
+        break
+    else:
+      primes.add(number)
+
+  return (
+      "No prime numbers found."
+      if not primes
+      else f"{', '.join(str(n) for n in sorted(primes))} are prime numbers."
+  )
+
+
+root_agent = Agent(
+    model=Ollama(model="llama3.1"),
+    name="dice_agent",
+    description="Agent that rolls dice and checks primes using native Ollama.",
+    instruction="Always use the provided tools.",
+    tools=[roll_die, check_prime],
+)
+```
+
+## Connecting to a remote Ollama server
+
+The default Ollama endpoint is:
+
+```
+http://localhost:11434
+```
+
+Override it with an environment variable:
+
+```bash
+export OLLAMA_API_BASE="http://192.168.1.20:11434"
+```
+
+Or pass the host explicitly in code:
+
+```python
+Ollama(model="llama3.1", host="http://192.168.1.20:11434")
+```
+
+## Running the Example with ADK Web
+
+Start the ADK Web UI from the repository root:
+
+```bash
+adk web contributing/samples
+```
+
+Then open the printed URL and select `hello_world_ollama_native` to interactively test the agent's tool calls.
diff --git a/contributing/samples/hello_world_ollama_native/__init__.py b/contributing/samples/hello_world_ollama_native/__init__.py
new file mode 100755
index 0000000000..c48963cdc7
--- /dev/null
+++ b/contributing/samples/hello_world_ollama_native/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/hello_world_ollama_native/agent.py b/contributing/samples/hello_world_ollama_native/agent.py new file mode 100755 index 0000000000..c3c97ffcce --- /dev/null +++ b/contributing/samples/hello_world_ollama_native/agent.py @@ -0,0 +1,94 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +from google.adk.agents.llm_agent import Agent +from google.adk.models.ollama_llm import Ollama +from typing import Any + +def roll_die(sides: int) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. + """ + return random.randint(1, sides) + + +def check_prime(numbers: list[Any]) -> str: + """Check which values in a list are prime numbers. + + Args: + numbers: The list of values to check. Values may be non-integers + and are safely ignored if they cannot be converted. + + Returns: + A string indicating which numbers are prime. + """ + primes = set() + + for number in numbers: + try: + number = int(number) + except (ValueError, TypeError): + continue # Safely skip non-numeric values + + if number <= 1: + continue + + for i in range(2, int(number ** 0.5) + 1): + if number % i == 0: + break + else: + primes.add(number) + + return ( + "No prime numbers found." + if not primes + else f"{', '.join(str(num) for num in sorted(primes))} are prime numbers." + ) + +root_agent = Agent( + model=Ollama(model="llama3.1"), + name="dice_roll_agent", + description=( + "hello world agent that can roll a dice of any number of sides and" + " check prime numbers." + ), + instruction=""" + You roll dice and answer questions about the outcome of the dice rolls. + You can roll dice of different sizes. + You can use multiple tools in parallel by calling functions in parallel(in one request and in one round). + It is ok to discuss previous dice roles, and comment on the dice rolls. + When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string. + You should never roll a die on your own. + When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string. + You should not check prime numbers before calling the tool. + When you are asked to roll a die and check prime numbers, you should always make the following two function calls: + 1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool. + 2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result. + 2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list. + 3. 
When you respond, you must include the roll_die result from step 1. + You should always perform the previous 3 steps when asking for a roll and checking prime numbers. + You should not rely on the previous history on prime results. + """, + tools=[ + roll_die, + check_prime, + ], +) diff --git a/contributing/samples/hello_world_ollama_native/main.py b/contributing/samples/hello_world_ollama_native/main.py new file mode 100755 index 0000000000..28fdbbbc92 --- /dev/null +++ b/contributing/samples/hello_world_ollama_native/main.py @@ -0,0 +1,77 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import time +import warnings + +import agent +from dotenv import load_dotenv +from google.adk import Runner +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.cli.utils import logs +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.session import Session +from google.genai import types + +load_dotenv(override=True) +warnings.filterwarnings('ignore', category=UserWarning) +logs.log_to_tmp_folder() + + +async def main(): + app_name = 'my_app' + user_id_1 = 'user1' + session_service = InMemorySessionService() + artifact_service = InMemoryArtifactService() + runner = Runner( + app_name=app_name, + agent=agent.root_agent, + artifact_service=artifact_service, + session_service=session_service, + ) + session_11 = await session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + + async def run_prompt(session: Session, new_message: str): + content = types.Content( + role='user', parts=[types.Part.from_text(text=new_message)] + ) + print('** User says:', content.model_dump(exclude_none=True)) + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ): + if event.content.parts and event.content.parts[0].text: + print(f'** {event.author}: {event.content.parts[0].text}') + + start_time = time.time() + print('Start time:', start_time) + print('------------------------------------') + await run_prompt(session_11, 'Hi, introduce yourself.') + await run_prompt( + session_11, 'Roll a die with 100 sides and check if it is prime' + ) + await run_prompt(session_11, 'Roll it again.') + await run_prompt(session_11, 'What numbers did I get?') + end_time = time.time() + print('------------------------------------') + print('End time:', end_time) + print('Total time:', end_time - start_time) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/google/adk/models/ollama_llm.py b/src/google/adk/models/ollama_llm.py new file mode 100644 index 0000000000..5159caf3e4 --- /dev/null +++ b/src/google/adk/models/ollama_llm.py @@ -0,0 +1,428 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import json +import logging +from typing import Any +from typing import AsyncGenerator +from typing import Optional +from typing import Sequence +from typing import Union +import urllib.error +import urllib.request +import os + +from google.genai import types +from pydantic import Field +from typing_extensions import override + +from .base_llm import BaseLlm +from .llm_request import LlmRequest +from .llm_response import LlmResponse + +logger = logging.getLogger("google_adk." + __name__) + +_CHAT_ENDPOINT = "/api/chat" + + +class Ollama(BaseLlm): + """Native integration for Ollama-hosted models. + + This backend talks directly to the Ollama HTTP API: + + POST /api/chat + + It supports: + * `ollama/` names (e.g. `ollama/llama3.2`) + * `ollama_chat/` names for LiteLlm compatibility + * System / user / assistant messages + * Unary generation + * Tool-calling via Ollama `tools` schema + """ + + # Default model name is compatible with Agent(model="ollama/llama3.1") + model: str = "ollama/llama3.1" + + host: str = Field( + default=os.environ.get("OLLAMA_API_BASE", "http://localhost:11434"), + description="Base URL of the Ollama server.", + ) + request_timeout: float = Field( + default=120.0, + description="Timeout in seconds for Ollama requests.", + ) + + @classmethod + @override + def supported_models(cls) -> list[str]: + # Allow any `ollama/...` style name. + return [r"ollama\/.+"] + + @override + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + if stream: + logger.warning( + "Streaming is not yet supported for Ollama; falling back to unary." + ) + + # Ensure last user content is appended if needed (BaseLlm helper). + self._maybe_append_user_content(llm_request) + + payload = self._build_payload(llm_request) + try: + response_json = await asyncio.to_thread(self._post_chat, payload) + except RuntimeError as exc: + logger.error("Failed to call Ollama: %s", exc) + yield LlmResponse(error_code="OLLAMA_ERROR", error_message=str(exc)) + return + + llm_response = self._to_llm_response( + response_json, request_model=llm_request.model + ) + yield llm_response + + # --------------------------------------------------------------------------- + # Payload construction + # --------------------------------------------------------------------------- + + def _build_payload(self, llm_request: LlmRequest) -> dict[str, Any]: + payload: dict[str, Any] = { + "model": self._extract_model_name(llm_request.model), + "messages": self._convert_messages(llm_request), + "stream": False, + } + + if tools := self._convert_tools(llm_request): + payload["tools"] = tools + if options := self._convert_options(llm_request): + payload["options"] = options + + return payload + + def _extract_model_name(self, request_model: Optional[str]) -> str: + """Normalize model name for the Ollama API. 
+ + Supports: + * "ollama/llama3.2" → "llama3.2" + * "ollama_chat/llama3.2" → "llama3.2" + * "llama3.2" → "llama3.2" + """ + model_name = request_model or self.model + if model_name.startswith("ollama/") or model_name.startswith( + "ollama_chat/" + ): + return model_name.split("/", 1)[1] + return model_name + + def _convert_messages(self, llm_request: LlmRequest) -> list[dict[str, str]]: + """Convert ADK Contents into Ollama chat messages.""" + messages: list[dict[str, str]] = [] + + # System instruction → first system message. + system_instruction = llm_request.config.system_instruction + if system_instruction: + messages.append({ + "role": "system", + "content": self._system_instruction_to_text(system_instruction), + }) + + # User / assistant / tool messages. + for content in llm_request.contents: + message_text = self._content_to_text(content) + if not message_text: + continue + role = self._map_role(content.role) + messages.append({"role": role, "content": message_text}) + + return messages + + def _system_instruction_to_text(self, system_instruction: Any) -> str: + """Normalize `system_instruction` into plain text. + + It may be: + * a plain string + * a types.Content object + * a list/tuple of types.Content and/or strings + """ + # Single Content object + if isinstance(system_instruction, types.Content): + return self._content_to_text(system_instruction) + + # Sequence of items (e.g. list[Content]) + if isinstance(system_instruction, (list, tuple)): + pieces: list[str] = [] + for item in system_instruction: + if isinstance(item, types.Content): + pieces.append(self._content_to_text(item)) + elif item is not None: + pieces.append(str(item)) + return "\n".join(pieces) + + # Fallback: assume it's already string-like + return str(system_instruction) + + def _content_to_text(self, content: types.Content) -> str: + """Flatten a `Content` into plain text for Ollama. + + Encodes tool calls and tool responses as tagged lines so that the model + can reason about them and generate new tool calls. + """ + parts = content.parts or [] + text_parts: list[str] = [] + + for part in parts: + if part.text: + text_parts.append(part.text) + + elif part.function_response: + # Tool result from a previous call. + try: + response_json = json.dumps( + part.function_response.response, ensure_ascii=False + ) + except TypeError: + response_json = str(part.function_response.response) + text_parts.append( + f"[tool_response name={part.function_response.name or ''}]" + f" {response_json}" + ) + + elif part.function_call: + # A model-issued tool call (arguments as JSON). + try: + args_json = json.dumps(part.function_call.args, ensure_ascii=False) + except TypeError: + args_json = str(part.function_call.args) + text_parts.append( + f"[tool_call name={part.function_call.name}] {args_json}" + ) + + else: + logger.debug( + "Skipping unsupported content part for Ollama message: %s", part + ) + + return "\n".join(text_parts) + + def _map_role(self, role: Optional[str]) -> str: + if role in ("model", "assistant"): + return "assistant" + if role == "system": + return "system" + # "user", "tool", or anything else defaults to "user". 
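+    # Note: tool results are not forwarded with a dedicated "tool" role;
+    # _content_to_text() flattens them into tagged text, so they reach the
+    # Ollama server as ordinary user-role messages.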
+ return "user" + + def _convert_tools(self, llm_request: LlmRequest) -> list[dict[str, Any]]: + """Convert ADK tool declarations into Ollama tool schema.""" + tools_spec: list[dict[str, Any]] = [] + if not llm_request.config.tools: + return tools_spec + + for tool in llm_request.config.tools: + function_declarations: Optional[Sequence[types.FunctionDeclaration]] = ( + tool.function_declarations if isinstance(tool, types.Tool) else None + ) + if not function_declarations: + continue + + for function_declaration in function_declarations: + tools_spec.append({ + "type": "function", + "function": { + "name": function_declaration.name, + "description": function_declaration.description or "", + "parameters": self._function_parameters_to_json( + function_declaration + ), + }, + }) + + return tools_spec + + def _function_parameters_to_json( + self, function_declaration: types.FunctionDeclaration + ) -> dict[str, Any]: + """Convert function parameters Schema → JSON Schema for Ollama.""" + if function_declaration.parameters is None: + return {"type": "object", "properties": {}} + + try: + return function_declaration.parameters.model_dump(exclude_none=True) + except AttributeError: + # model_dump is not guaranteed depending on the genai version. + try: + return json.loads( + function_declaration.parameters.model_dump_json(exclude_none=True) + ) + except (AttributeError, json.JSONDecodeError, TypeError) as exc: + logger.debug( + "Failed to convert function parameters, defaulting to empty" + " schema: %s", + exc, + ) + return {"type": "object", "properties": {}} + + def _convert_options(self, llm_request: LlmRequest) -> dict[str, Any]: + """Map ADK generation config fields to Ollama options.""" + options: dict[str, Any] = {} + config = llm_request.config + + temperature = getattr(config, "temperature", None) + if temperature is not None: + options["temperature"] = temperature + + top_p = getattr(config, "top_p", None) + if top_p is not None: + options["top_p"] = top_p + + max_output_tokens = getattr(config, "max_output_tokens", None) + if max_output_tokens is not None: + # Ollama uses `num_predict` to limit generated tokens. + options["num_predict"] = max_output_tokens + + return options + + # --------------------------------------------------------------------------- + # HTTP call + # --------------------------------------------------------------------------- + + def _post_chat(self, payload: dict[str, Any]) -> dict[str, Any]: + """Perform a blocking POST /api/chat call to Ollama. + Note: This method is intentionally blocking and is executed via + asyncio.to_thread() to avoid introducing additional async HTTP + dependencies. This keeps the backend consistent with existing ADK + providers. 
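+
+    The request body follows Ollama's /api/chat schema, e.g.:
+
+        {"model": "llama3.1", "messages": [...], "stream": false,
+         "tools": [...], "options": {...}}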
+ """ + + url = self.host.rstrip("/") + _CHAT_ENDPOINT + data = json.dumps(payload).encode("utf-8") + request = urllib.request.Request( + url, + data=data, + headers={"Content-Type": "application/json"}, + method="POST", + ) + + try: + with urllib.request.urlopen( + request, timeout=self.request_timeout + ) as response: + response_body = response.read().decode("utf-8") + except urllib.error.URLError as exc: + raise RuntimeError(exc.reason) from exc + except urllib.error.HTTPError as exc: + message = exc.read().decode("utf-8", errors="ignore") + raise RuntimeError(f"{exc.code}: {message}") from exc + + return json.loads(response_body) + + # --------------------------------------------------------------------------- + # Response mapping + # --------------------------------------------------------------------------- + + def _to_llm_response( + self, + response_json: dict[str, Any], + request_model: Optional[str] = None, + ) -> LlmResponse: + """Convert Ollama JSON response → ADK `LlmResponse`.""" + if error := response_json.get("error"): + return LlmResponse( + error_code="OLLAMA_ERROR", + error_message=str(error), + ) + + message = response_json.get("message", {}) or {} + parts: list[types.Part] = [] + + # 1) Main text content. + content = message.get("content") + if isinstance(content, str) and content.strip(): + parts.append(types.Part.from_text(text=content)) + + # 2) Tool calls (if any). + for tool_call in message.get("tool_calls", []): + function_payload = tool_call.get("function", {}) or {} + name = function_payload.get("name") + arguments: Union[str, dict[str, Any], None] = function_payload.get( + "arguments" + ) + + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + logger.warning( + "Failed to parse tool call arguments as JSON: %s. Defaulting to" + " empty arguments.", + arguments, + ) + arguments = {} + elif arguments is None: + arguments = {} + + function_call = types.FunctionCall(name=name, args=arguments) + if tool_call_id := tool_call.get("id"): + # id is useful for correlating tool_call ↔ tool_response. + setattr(function_call, "id", tool_call_id) + + parts.append(types.Part(function_call=function_call)) + + if not parts: + return LlmResponse( + error_code="NO_CONTENT", + error_message="Ollama response did not contain model output.", + ) + + # 3) Usage mapping (Ollama → GenerateContentResponseUsageMetadata). + # Ollama returns: + # prompt_eval_count: tokens in prompt + # eval_count: tokens in completion + prompt_tokens = response_json.get("prompt_eval_count") + completion_tokens = response_json.get("eval_count") + + # Fallback: if someone wraps usage in a dict (e.g. in tests). + if prompt_tokens is None or completion_tokens is None: + usage = response_json.get("usage") or {} + if prompt_tokens is None: + prompt_tokens = usage.get("prompt_tokens") + if completion_tokens is None: + completion_tokens = usage.get("completion_tokens") + + usage_metadata: Optional[types.GenerateContentResponseUsageMetadata] = None + if prompt_tokens is not None and completion_tokens is not None: + total_tokens = response_json.get("total_tokens") + if total_tokens is None: + total_tokens = prompt_tokens + completion_tokens + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=prompt_tokens, + candidates_token_count=completion_tokens, + total_token_count=total_tokens, + ) + + # 4) Model version: prefer Ollama's `model`, fallback to request model. 
+ model_version = response_json.get("model") or self._extract_model_name( + request_model or self.model + ) + + return LlmResponse( + content=types.Content(role="model", parts=parts), + model_version=model_version, + usage_metadata=usage_metadata, + ) diff --git a/tests/unittests/models/test_ollama.py b/tests/unittests/models/test_ollama.py new file mode 100644 index 0000000000..537424bbb3 --- /dev/null +++ b/tests/unittests/models/test_ollama.py @@ -0,0 +1,313 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.models.ollama_llm import Ollama +from google.genai import types +import pytest + +# ===================================================== +# Helpers +# ===================================================== + + +def mock_response_ok(text="Hello world", tool_calls=None): + """Create a minimal valid Ollama /api/chat response.""" + message = {"content": text} + if tool_calls: + message["tool_calls"] = tool_calls + return {"message": message} + + +# ===================================================== +# Model extraction +# ===================================================== + + +def test_extract_model_name_basic(): + o = Ollama(model="ollama/mistral") + assert o._extract_model_name("ollama/mistral") == "mistral" + + +def test_extract_model_name_chat_prefix(): + o = Ollama(model="ollama_chat/llama3.1") + assert o._extract_model_name("ollama_chat/llama3.1") == "llama3.1" + + +def test_extract_model_name_no_prefix(): + o = Ollama(model="mistral") + assert o._extract_model_name("mistral") == "mistral" + + +# ===================================================== +# Message conversion +# ===================================================== + + +def test_convert_messages_basic(): + o = Ollama() + req = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hi")]) + ] + ) + msgs = o._convert_messages(req) + assert msgs[0]["role"] == "user" + assert msgs[0]["content"] == "Hi" + + +def test_convert_messages_with_system(): + o = Ollama() + req = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="X")]) + ], + config=types.GenerateContentConfig(system_instruction="SYS"), + ) + msgs = o._convert_messages(req) + + assert msgs[0]["role"] == "system" + assert msgs[0]["content"] == "SYS" + assert msgs[1]["content"] == "X" + + +# ===================================================== +# Content → text extraction +# ===================================================== + + +def test_content_to_text_basic(): + o = Ollama() + content = types.Content(role="user", parts=[types.Part.from_text(text="ABC")]) + assert o._content_to_text(content) == "ABC" + + +def test_content_to_text_function_call(): + o = Ollama() + part = types.Part.from_function_call(name="add", args={"x": 1, "y": 2}) + part.function_call.id = "call123" + + content = types.Content(role="assistant", 
parts=[part]) + txt = o._content_to_text(content) + + assert "[tool_call name=add]" in txt + assert '"x": 1' in txt + + +def test_content_to_text_tool_response(): + o = Ollama() + part = types.Part.from_function_response(name="add", response={"z": 5}) + content = types.Content(role="tool", parts=[part]) + txt = o._content_to_text(content) + + assert "[tool_response name=add]" in txt + assert '"z": 5' in txt + + +# ===================================================== +# Tools conversion +# ===================================================== + + +def test_convert_tools_basic(): + o = Ollama() + req = LlmRequest( + config=types.GenerateContentConfig( + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="add", + description="Add numbers", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "x": types.Schema(type=types.Type.NUMBER) + }, + ), + ) + ] + ) + ] + ) + ) + + tools = o._convert_tools(req) + + assert tools[0]["function"]["name"] == "add" + assert tools[0]["function"]["parameters"]["type"] == types.Type.OBJECT + + +# ===================================================== +# POST wrapper +# ===================================================== + + +def test_post_chat_success(monkeypatch): + fake_response = {"message": {"content": "OK"}} + + def fake_urlopen(req, timeout=0): + class Resp: + + def __enter__(self): + return self + + def __exit__(self, *_): + return False + + def read(self): + return json.dumps(fake_response).encode("utf-8") + + return Resp() + + monkeypatch.setattr("urllib.request.urlopen", fake_urlopen) + + o = Ollama() + resp = o._post_chat({"model": "x"}) + + assert resp["message"]["content"] == "OK" + + +# ===================================================== +# LlmResponse conversion +# ===================================================== + + +def test_to_llm_response_text(): + o = Ollama() + resp = mock_response_ok("Hi") + + out = o._to_llm_response(resp) + + assert isinstance(out, LlmResponse) + assert out.content.parts[0].text == "Hi" + + +def test_to_llm_response_tool_call(): + o = Ollama() + tool_call = { + "id": "abc", + "function": {"name": "add", "arguments": '{"x": 1}'}, + } + + resp = mock_response_ok(tool_calls=[tool_call]) + out = o._to_llm_response(resp) + + fc = next(p.function_call for p in out.content.parts if p.function_call) + + assert fc.name == "add" + assert fc.args == {"x": 1} + assert fc.id == "abc" + + +def test_to_llm_response_tool_call_bad_json(): + o = Ollama() + tool_call = { + "id": "zzz", + "function": {"name": "add", "arguments": "{BAD_JSON"}, + } + + resp = mock_response_ok(tool_calls=[tool_call]) + out = o._to_llm_response(resp) + + fc = next(p.function_call for p in out.content.parts if p.function_call) + + assert fc.args == {} # fallback + + +def test_to_llm_response_usage_metadata(): + o = Ollama() + resp = mock_response_ok("Hi") + resp["prompt_eval_count"] = 10 + resp["eval_count"] = 5 + + out = o._to_llm_response(resp) + + assert out.usage_metadata.prompt_token_count == 10 + assert out.usage_metadata.candidates_token_count == 5 + assert out.usage_metadata.total_token_count == 15 + + +# ===================================================== +# Async generate_content_async() +# ===================================================== + + +@pytest.mark.asyncio +async def test_generate_content_async_basic(monkeypatch): + resp = mock_response_ok("Hello!") + + async def fake_thread(fn, *args): + return resp + + monkeypatch.setattr("asyncio.to_thread", fake_thread) + + o = 
Ollama(model="ollama/mistral") + req = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="Hi")]) + ] + ) + + results = [r async for r in o.generate_content_async(req)] + + assert results[0].content.parts[0].text == "Hello!" + + +@pytest.mark.asyncio +async def test_generate_content_async_error(monkeypatch): + async def fake_thread(fn, *args): + raise RuntimeError("boom") + + monkeypatch.setattr("asyncio.to_thread", fake_thread) + + o = Ollama() + req = LlmRequest(contents=[types.Content(role="user", parts=[])]) + + results = [r async for r in o.generate_content_async(req)] + + assert results[0].error_code == "OLLAMA_ERROR" + + +# ===================================================== +# Model override +# ===================================================== + + +@pytest.mark.asyncio +async def test_model_override(monkeypatch): + resp = mock_response_ok("Hello") + resp["model"] = "override" + + async def fake_thread(fn, *args): + payload = args[0] + assert payload["model"] == "override" + return resp + + monkeypatch.setattr("asyncio.to_thread", fake_thread) + + o = Ollama(model="default") + req = LlmRequest( + model="override", + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="X")]) + ], + ) + + out = [r async for r in o.generate_content_async(req)][0] + + assert out.model_version == "override"
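+
+
+# =====================================================
+# Generation config → Ollama options (additional sketch)
+# =====================================================
+# A minimal extra check of the _convert_options mapping defined in
+# ollama_llm.py: temperature and top_p pass through unchanged, and
+# max_output_tokens maps to Ollama's `num_predict`.
+
+
+def test_convert_options_maps_generation_config():
+  o = Ollama()
+  req = LlmRequest(
+      config=types.GenerateContentConfig(
+          temperature=0.2,
+          top_p=0.9,
+          max_output_tokens=64,
+      )
+  )
+
+  options = o._convert_options(req)
+
+  assert options["temperature"] == 0.2
+  assert options["top_p"] == 0.9
+  assert options["num_predict"] == 64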