From fbabf6cd998e6fde891074957e804bd212eb91ec Mon Sep 17 00:00:00 2001 From: Marisa Senkfor Date: Wed, 17 Dec 2025 16:18:42 -0500 Subject: [PATCH 1/5] feat: add neuron vllm model with basic tests --- src/strands/models/neuronvllm.py | 215 ++++++++ tests/strands/models/test_neuronvllm.py | 540 ++++++++++++++++++++ tests_integ/models/providers.py | 32 ++ tests_integ/models/test_model_neuronvllm.py | 101 ++++ 4 files changed, 888 insertions(+) create mode 100644 src/strands/models/neuronvllm.py create mode 100644 tests/strands/models/test_neuronvllm.py create mode 100644 tests_integ/models/test_model_neuronvllm.py diff --git a/src/strands/models/neuronvllm.py b/src/strands/models/neuronvllm.py new file mode 100644 index 000000000..96274a055 --- /dev/null +++ b/src/strands/models/neuronvllm.py @@ -0,0 +1,215 @@ +import json +import logging +from typing import Any, AsyncGenerator, Dict, List, Optional, Type, TypeVar, Union, cast + +from openai import AsyncOpenAI +from pydantic import BaseModel +from typing_extensions import TypedDict, Unpack, override + +from ..types.content import ContentBlock, Messages, SystemContentBlock +from ..types.streaming import StopReason, StreamEvent +from ..types.tools import ToolChoice, ToolSpec +from ._validation import validate_config_keys, warn_on_tool_choice_not_supported +from .model import Model + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class NeuronVLLMModel(Model): + """Neuron-vLLM model provider implementation.""" + + class NeuronVLLMConfig(TypedDict, total=False): + model_id: str + max_model_len: Optional[int] + max_num_seqs: Optional[int] + tensor_parallel_size: Optional[int] + block_size: Optional[int] + enable_prefix_caching: Optional[bool] + neuron_config: Optional[Dict[str, Any]] + device: Optional[str] + temperature: Optional[float] + top_p: Optional[float] + max_tokens: Optional[int] + stop_sequences: Optional[List[str]] + additional_args: Optional[Dict[str, Any]] + openai_api_key: Optional[str] + openai_api_base: Optional[str] + + def __init__(self, config: NeuronVLLMConfig): + validate_config_keys(config, self.NeuronVLLMConfig) + self.config = config + self.logger = logging.getLogger(__name__) + if not config.get("model_id"): + raise ValueError("model_id is required") + self._validate_hardware() + self.logger.info(f"Initializing NeuronVLLMModel with model: {config['model_id']}") + + def _validate_hardware(self) -> None: + try: + import torch_neuronx # type: ignore + self.logger.info("Neuron hardware validation passed") + except ImportError: + self.logger.warning("Neuron libraries not available - running in compatibility mode") + + @override + def update_config(self, **model_config: Unpack[NeuronVLLMConfig]) -> None: + validate_config_keys(model_config, self.NeuronVLLMConfig) + self.config.update(model_config) + + @override + def get_config(self) -> NeuronVLLMConfig: + return self.config + + def _format_request_message_contents(self, role: str, content: ContentBlock) -> list[dict[str, Any]]: + if "text" in content: + return [{"role": role, "content": content["text"]}] + if "image" in content: + return [{"role": role, "images": [content["image"]["source"]["bytes"]]}] + if "toolUse" in content: + return [{"role": role, "tool_calls": [{"function": {"name": content["toolUse"]["toolUseId"], "arguments": content["toolUse"]["input"]}}]}] + if "toolResult" in content: + return [ + formatted + for tool_result in content["toolResult"]["content"] + for formatted in self._format_request_message_contents( + "tool", + {"text": json.dumps(tool_result["json"])} if "json" in tool_result else cast(ContentBlock, tool_result), + ) + ] + raise TypeError(f"Unsupported content type: {next(iter(content))}") + + def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + system_message = [{"role": "system", "content": system_prompt}] if system_prompt else [] + return system_message + [ + formatted_message + for message in messages + for content in message["content"] + for formatted_message in self._format_request_message_contents(message["role"], content) + ] + + def format_request(self, messages: Messages, tool_specs: Optional[List[ToolSpec]] = None, system_prompt: Optional[str] = None, stream: bool = True) -> dict[str, Any]: + """Return a dictionary suitable for OpenAI Async client.""" + request: dict[str, Any] = { + "messages": self._format_request_messages(messages, system_prompt), + "model": self.config["model_id"], + "temperature": self.config.get("temperature"), + "top_p": self.config.get("top_p"), + "max_tokens": self.config.get("max_tokens"), + "stop": self.config.get("stop_sequences"), + "stream": stream, + } + if tool_specs: + request["functions"] = [ + { + "name": t["name"], + "description": t["description"], + "parameters": t["inputSchema"]["json"], + } + for t in tool_specs + ] + if self.config.get("additional_args"): + request.update(self.config["additional_args"]) + return request + + def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Convert raw events into StreamEvent.""" + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + case "content_start": + if event["data_type"] == "text": + return {"contentBlockStart": {"start": {}}} + tool_name = event["data"].function.name + return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}} + case "content_delta": + if event["data_type"] == "text": + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + tool_arguments = event["data"].function.arguments + return {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(tool_arguments)}}}} + case "content_stop": + return {"contentBlockStop": {}} + case "message_stop": + reason: StopReason = "tool_use" if event["data"] == "tool_use" else "end_turn" + return {"messageStop": {"stopReason": reason}} + case "metadata": + return {"metadata": {"usage": {}, "metrics": {}}} + case _: + raise RuntimeError(f"Unknown chunk_type: {event['chunk_type']}") + + @override + async def stream( + self, + messages: Messages, + tool_specs: Optional[List[ToolSpec]] = None, + system_prompt: Optional[str] = None, + *, + tool_choice: ToolChoice | None = None, + system_prompt_content: Optional[List[SystemContentBlock]] = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + warn_on_tool_choice_not_supported(tool_choice) + + request = self.format_request(messages, tool_specs, system_prompt, stream=True) + client = AsyncOpenAI( + api_key=self.config.get("openai_api_key", "EMPTY"), + base_url=self.config.get("openai_api_base", "http://localhost:8084/v1"), + ) + + tool_requested = False + finish_reason: str | None = None + + yield self.format_chunk({"chunk_type": "message_start"}) + yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) + + stream_response = await client.chat.completions.create(**request) + async for chunk in stream_response: + choice = chunk.choices[0] + delta = choice.delta + + if delta.content: + yield self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": delta.content}) + + if delta.tool_calls: + for tool_call in delta.tool_calls: + yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_call}) + yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_call}) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool", "data": tool_call}) + tool_requested = True + + if choice.finish_reason: + finish_reason = choice.finish_reason + + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + yield self.format_chunk({"chunk_type": "message_stop", "data": "tool_use" if tool_requested else finish_reason}) + + @override + async def structured_output( + self, + output_model: Type[T], + prompt: Messages, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + tool_spec = ToolSpec( + name=output_model.__name__, + description=f"Return a {output_model.__name__}", + input_schema=output_model.model_json_schema(), + ) + request = self.format_request(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, stream=False) + request["tool_choice"] = {"type": "function", "function": {"name": tool_spec.name}} + + client = AsyncOpenAI( + api_key=self.config.get("openai_api_key", "EMPTY"), + base_url=self.config.get("openai_api_base", "http://localhost:8084/v1"), + ) + response = await client.chat.completions.create(**request) + + message = response.choices[0].message + if not message.tool_calls: + raise ValueError("Expected structured output via tool call") + + tool_call = message.tool_calls[0] + output = output_model.model_validate_json(tool_call.function.arguments) + yield {"output": output} diff --git a/tests/strands/models/test_neuronvllm.py b/tests/strands/models/test_neuronvllm.py new file mode 100644 index 000000000..7cd56978f --- /dev/null +++ b/tests/strands/models/test_neuronvllm.py @@ -0,0 +1,540 @@ +import json +import unittest.mock + +import pydantic +import pytest + +from strands.models.neuronvllm import NeuronVLLMModel +from strands.types.content import Messages + + +@pytest.fixture +def neuronvllm_client(monkeypatch: pytest.MonkeyPatch) -> unittest.mock.Mock: + from strands import models + + mock_client_cls = unittest.mock.Mock() + mock_client = unittest.mock.AsyncMock() + mock_client.chat.completions.create = unittest.mock.AsyncMock() + mock_client_cls.return_value = mock_client + + monkeypatch.setattr(models.neuronvllm, "AsyncOpenAI", mock_client_cls) + return mock_client + + +@pytest.fixture +def model_id() -> str: + return "m1" + + +@pytest.fixture +def model(model_id: str) -> NeuronVLLMModel: + return NeuronVLLMModel({"model_id": model_id}) + + +@pytest.fixture +def messages() -> Messages: + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def system_prompt() -> str: + return "s1" + + +@pytest.fixture +def test_output_model_cls() -> type[pydantic.BaseModel]: + class TestOutputModel(pydantic.BaseModel): + name: str + age: int + + return TestOutputModel + + +def test__init__model_configs(model_id: str) -> None: + model = NeuronVLLMModel({"model_id": model_id, "max_tokens": 1}) + + tru_max_tokens = model.get_config().get("max_tokens") + exp_max_tokens = 1 + + assert tru_max_tokens == exp_max_tokens + + +def test_update_config(model: NeuronVLLMModel, model_id: str) -> None: + model.update_config(model_id=model_id) + + tru_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert tru_model_id == exp_model_id + + +def test_format_request_default(model: NeuronVLLMModel, messages: Messages, model_id: str) -> None: + tru_request = model.format_request(messages) + exp_request = { + "messages": [{"role": "user", "content": "test"}], + "model": model_id, + "temperature": None, + "top_p": None, + "max_tokens": None, + "stop": None, + "stream": True, + } + + assert tru_request == exp_request + + +def test_format_request_with_override(model: NeuronVLLMModel, messages: Messages, model_id: str) -> None: + model.update_config(model_id=model_id) + tru_request = model.format_request(messages, tool_specs=None) + exp_request = { + "messages": [{"role": "user", "content": "test"}], + "model": model_id, + "temperature": None, + "top_p": None, + "max_tokens": None, + "stop": None, + "stream": True, + } + + assert tru_request == exp_request + + +def test_format_request_with_system_prompt( + model: NeuronVLLMModel, messages: Messages, model_id: str, system_prompt: str +) -> None: + tru_request = model.format_request(messages, system_prompt=system_prompt) + exp_request = { + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": "test"}, + ], + "model": model_id, + "temperature": None, + "top_p": None, + "max_tokens": None, + "stop": None, + "stream": True, + } + + assert tru_request == exp_request + + +def test_format_request_with_image(model: NeuronVLLMModel, model_id: str) -> None: + messages: Messages = [{"role": "user", "content": [{"image": {"source": {"bytes": "base64encodedimage"}}}]}] + + tru_request = model.format_request(messages) + exp_request = { + "messages": [{"role": "user", "images": ["base64encodedimage"]}], + "model": model_id, + "temperature": None, + "top_p": None, + "max_tokens": None, + "stop": None, + "stream": True, + } + + assert tru_request == exp_request + + +def test_format_request_with_tool_use(model: NeuronVLLMModel, model_id: str) -> None: + messages: Messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "calculator", "input": '{"expression": "2+2"}'}}]} + ] + + tru_request = model.format_request(messages) + exp_request = { + "messages": [ + { + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "calculator", + "arguments": '{"expression": "2+2"}', + } + } + ], + } + ], + "model": model_id, + "temperature": None, + "top_p": None, + "max_tokens": None, + "stop": None, + "stream": True, + } + + assert tru_request == exp_request + + +def test_format_request_with_tool_result(model: NeuronVLLMModel, model_id: str) -> None: + messages: Messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "calculator", + "status": "success", + "content": [ + {"text": "4"}, + {"image": {"source": {"bytes": b"image"}}}, + {"json": ["4"]}, + ], + }, + }, + { + "text": "see results", + }, + ], + }, + ] + + tru_request = model.format_request(messages) + exp_request = { + "messages": [ + { + "role": "tool", + "content": "4", + }, + { + "role": "tool", + "images": [b"image"], + }, + { + "role": "tool", + "content": '["4"]', + }, + { + "role": "user", + "content": "see results", + }, + ], + "model": model_id, + "temperature": None, + "top_p": None, + "max_tokens": None, + "stop": None, + "stream": True, + } + + assert tru_request == exp_request + + +def test_format_request_with_unsupported_type(model: NeuronVLLMModel) -> None: + messages: Messages = [ + { + "role": "user", + "content": [{"unsupported": {}}], + }, + ] + + with pytest.raises(TypeError, match="Unsupported content type: unsupported"): + model.format_request(messages) + + +def test_format_request_with_tool_specs(model: NeuronVLLMModel, messages: Messages, model_id: str) -> None: + tool_specs = [ + { + "name": "calculator", + "description": "Calculate mathematical expressions", + "inputSchema": { + "json": {"type": "object", "properties": {"expression": {"type": "string"}}, "required": ["expression"]} + }, + } + ] + + tru_request = model.format_request(messages, tool_specs) + exp_request = { + "messages": [{"role": "user", "content": "test"}], + "model": model_id, + "temperature": None, + "top_p": None, + "max_tokens": None, + "stop": None, + "stream": True, + "functions": [ + { + "name": "calculator", + "description": "Calculate mathematical expressions", + "parameters": { + "type": "object", + "properties": {"expression": {"type": "string"}}, + "required": ["expression"], + }, + } + ], + } + + assert tru_request == exp_request + + +def test_format_request_with_inference_config(model: NeuronVLLMModel, messages: Messages, model_id: str) -> None: + inference_config = { + "max_tokens": 1, + "stop_sequences": ["stop"], + "temperature": 1.0, + "top_p": 1.0, + } + + model.update_config(**inference_config) + tru_request = model.format_request(messages) + exp_request = { + "messages": [{"role": "user", "content": "test"}], + "model": model_id, + "temperature": inference_config["temperature"], + "top_p": inference_config["top_p"], + "max_tokens": inference_config["max_tokens"], + "stop": inference_config["stop_sequences"], + "stream": True, + } + + assert tru_request == exp_request + + +def test_format_request_with_additional_args(model: NeuronVLLMModel, messages: Messages, model_id: str) -> None: + additional_args = {"o1": 1} + + model.update_config(additional_args=additional_args) + tru_request = model.format_request(messages) + exp_request = { + "messages": [{"role": "user", "content": "test"}], + "model": model_id, + "temperature": None, + "top_p": None, + "max_tokens": None, + "stop": None, + "stream": True, + "o1": 1, + } + + assert tru_request == exp_request + + +def test_format_chunk_message_start(model: NeuronVLLMModel) -> None: + event = {"chunk_type": "message_start"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"messageStart": {"role": "assistant"}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_start_text(model: NeuronVLLMModel) -> None: + event = {"chunk_type": "content_start", "data_type": "text"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"contentBlockStart": {"start": {}}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_start_tool(model: NeuronVLLMModel) -> None: + mock_function = unittest.mock.Mock() + mock_function.function.name = "calculator" + + event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_function} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calculator"}}}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_delta_text(model: NeuronVLLMModel) -> None: + event = {"chunk_type": "content_delta", "data_type": "text", "data": "Hello"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_delta_tool(model: NeuronVLLMModel) -> None: + event = { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments={"expression": "2+2"})), + } + + tru_chunk = model.format_chunk(event) + exp_chunk = {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps({"expression": "2+2"})}}}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_stop(model: NeuronVLLMModel) -> None: + event = {"chunk_type": "content_stop"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"contentBlockStop": {}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_message_stop_end_turn(model: NeuronVLLMModel) -> None: + event = {"chunk_type": "message_stop", "data": "stop"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"messageStop": {"stopReason": "end_turn"}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_message_stop_tool_use(model: NeuronVLLMModel) -> None: + event = {"chunk_type": "message_stop", "data": "tool_use"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"messageStop": {"stopReason": "tool_use"}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_metadata(model: NeuronVLLMModel) -> None: + event = { + "chunk_type": "metadata", + "data": unittest.mock.Mock(), + } + + tru_chunk = model.format_chunk(event) + exp_chunk = { + "metadata": { + "usage": {}, + "metrics": {}, + }, + } + + assert tru_chunk == exp_chunk + + +def test_format_chunk_other(model: NeuronVLLMModel) -> None: + event = {"chunk_type": "other"} + + with pytest.raises(RuntimeError, match="Unknown chunk_type: other"): + model.format_chunk(event) + + +@pytest.mark.asyncio +async def test_stream( + neuronvllm_client: unittest.mock.Mock, + model: NeuronVLLMModel, + agenerator, + alist, +) -> None: + mock_chunk = unittest.mock.Mock() + mock_choice = unittest.mock.Mock() + mock_delta = unittest.mock.Mock() + mock_delta.content = "Hello" + mock_delta.tool_calls = None + mock_choice.delta = mock_delta + mock_choice.finish_reason = "stop" + mock_chunk.choices = [mock_choice] + + neuronvllm_client.chat.completions.create.return_value = agenerator([mock_chunk]) + + messages: Messages = [{"role": "user", "content": [{"text": "Hello"}]}] + response = model.stream(messages) + + tru_events = await alist(response) + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "Hello"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + + assert tru_events == exp_events + + expected_request = { + "messages": [{"role": "user", "content": "Hello"}], + "model": "m1", + "temperature": None, + "top_p": None, + "max_tokens": None, + "stop": None, + "stream": True, + } + neuronvllm_client.chat.completions.create.assert_awaited_once_with(**expected_request) + + +@pytest.mark.asyncio +async def test_stream_with_tool_calls( + neuronvllm_client: unittest.mock.Mock, + model: NeuronVLLMModel, + agenerator, + alist, +) -> None: + mock_chunk = unittest.mock.Mock() + mock_choice = unittest.mock.Mock() + mock_delta = unittest.mock.Mock() + + mock_tool_call = unittest.mock.Mock() + mock_tool_call.function.name = "calculator" + mock_tool_call.function.arguments = {"expression": "2+2"} + + mock_delta.content = "I'll calculate that for you" + mock_delta.tool_calls = [mock_tool_call] + mock_choice.delta = mock_delta + mock_choice.finish_reason = "stop" + mock_chunk.choices = [mock_choice] + + neuronvllm_client.chat.completions.create.return_value = agenerator([mock_chunk]) + + messages: Messages = [{"role": "user", "content": [{"text": "Calculate 2+2"}]}] + response = model.stream(messages) + + tru_events = await alist(response) + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calculator"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, + {"contentBlockStop": {}}, + {"contentBlockDelta": {"delta": {"text": "I'll calculate that for you"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + ] + + assert tru_events == exp_events + + expected_request = { + "messages": [{"role": "user", "content": "Calculate 2+2"}], + "model": "m1", + "temperature": None, + "top_p": None, + "max_tokens": None, + "stop": None, + "stream": True, + } + neuronvllm_client.chat.completions.create.assert_awaited_once_with(**expected_request) + + +@pytest.mark.asyncio +async def test_structured_output( + neuronvllm_client: unittest.mock.Mock, + model: NeuronVLLMModel, + test_output_model_cls: type[pydantic.BaseModel], + alist, +) -> None: + messages: Messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + mock_response = unittest.mock.Mock() + mock_tool_call = unittest.mock.Mock() + mock_tool_call.function.arguments = '{"name": "John", "age": 30}' + mock_message = unittest.mock.Mock() + mock_message.tool_calls = [mock_tool_call] + mock_choice = unittest.mock.Mock() + mock_choice.message = mock_message + mock_response.choices = [mock_choice] + + neuronvllm_client.chat.completions.create = unittest.mock.AsyncMock(return_value=mock_response) + + stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) + + tru_result = events[-1] + exp_result = {"output": test_output_model_cls(name="John", age=30)} + assert tru_result == exp_result + + diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 75cc58f74..11011b4ef 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -14,6 +14,7 @@ from strands.models.litellm import LiteLLMModel from strands.models.llamaapi import LlamaAPIModel from strands.models.mistral import MistralModel +from strands.models.neuronvllm import NeuronVLLMModel from strands.models.ollama import OllamaModel from strands.models.openai import OpenAIModel from strands.models.writer import WriterModel @@ -59,6 +60,36 @@ def __init__(self): ) +class NeuronVLLMProviderInfo(ProviderInfo): + """Special case Neuron vLLM as it's dependent on the server being available.""" + + def __init__(self) -> None: + base_url = os.getenv("NEURON_VLLM_BASE_URL", "http://localhost:8084/v1") + model_id = os.getenv("NEURON_VLLM_MODEL_ID", "neuron-llama-3.1-70b-instruct") + + super().__init__( + id="neuronvllm", + factory=lambda: NeuronVLLMModel( + { + "model_id": model_id, + "openai_api_base": base_url, + "openai_api_key": os.getenv("NEURON_VLLM_API_KEY", "EMPTY"), + } + ), + ) + + is_server_available = False + try: + is_server_available = requests.get(base_url).ok + except requests.exceptions.ConnectionError: + pass + + self.mark = mark.skipif( + not is_server_available, + reason=f"Local Neuron vLLM endpoint not available at {base_url}", + ) + + anthropic = ProviderInfo( id="anthropic", environment_variable="ANTHROPIC_API_KEY", @@ -138,6 +169,7 @@ def __init__(self): ) ollama = OllamaProviderInfo() +neuronvllm = NeuronVLLMProviderInfo() all_providers = [ diff --git a/tests_integ/models/test_model_neuronvllm.py b/tests_integ/models/test_model_neuronvllm.py new file mode 100644 index 000000000..ffaa1dc46 --- /dev/null +++ b/tests_integ/models/test_model_neuronvllm.py @@ -0,0 +1,101 @@ +import os + +import pytest +from pydantic import BaseModel + +import strands +from strands import Agent +from strands.models.neuronvllm import NeuronVLLMModel +from tests_integ.models import providers + +# these tests only run if we have the neuron-vLLM server running +pytestmark = providers.neuronvllm.mark + + +@pytest.fixture +def model() -> NeuronVLLMModel: + base_url = os.getenv("NEURON_VLLM_BASE_URL", "http://localhost:8084/v1") + model_id = os.getenv("NEURON_VLLM_MODEL_ID", "neuron-llama-3.1-70b-instruct") + api_key = os.getenv("NEURON_VLLM_API_KEY", "EMPTY") + + return NeuronVLLMModel( + { + "model_id": model_id, + "openai_api_base": base_url, + "openai_api_key": api_key, + } + ) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def agent(model: NeuronVLLMModel, tools): + return Agent(model=model, tools=tools) + + +@pytest.fixture +def weather(): + class Weather(BaseModel): + """Extracts the time and weather from the user's message with the exact strings.""" + + time: str + weather: str + + return Weather(time="12:00", weather="sunny") + + +def test_agent_invoke(agent): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.asyncio +async def test_agent_invoke_async(agent): + result = await agent.invoke_async("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.asyncio +async def test_agent_stream_async(agent): + stream = agent.stream_async("What is the time and weather in New York?") + async for event in stream: + _ = event + + result = event["result"] + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +def test_agent_structured_output(agent, weather): + tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") + exp_weather = weather + assert tru_weather == exp_weather + + +@pytest.mark.asyncio +async def test_agent_structured_output_async(agent, weather): + tru_weather = await agent.structured_output_async( + type(weather), + "The time is 12:00 and the weather is sunny", + ) + exp_weather = weather + assert tru_weather == exp_weather + + From becfd0150ef50e767821a2abfaf0abbc6d21a01a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 18 Dec 2025 16:41:19 +0000 Subject: [PATCH 2/5] fix: format request --- src/strands/models/neuronvllm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/strands/models/neuronvllm.py b/src/strands/models/neuronvllm.py index 96274a055..756a1af8c 100644 --- a/src/strands/models/neuronvllm.py +++ b/src/strands/models/neuronvllm.py @@ -67,6 +67,12 @@ def _format_request_message_contents(self, role: str, content: ContentBlock) -> return [{"role": role, "content": content["text"]}] if "image" in content: return [{"role": role, "images": [content["image"]["source"]["bytes"]]}] + if "document" in content: + doc = content["document"] + name = doc.get("name", "document") + fmt = doc.get("format", "unknown") + text = f"[Attached document: {name} ({fmt})]" + return [{"role": role, "content": text}] if "toolUse" in content: return [{"role": role, "tool_calls": [{"function": {"name": content["toolUse"]["toolUseId"], "arguments": content["toolUse"]["input"]}}]}] if "toolResult" in content: From e90ee4b852cfaea6f0aff8f919ce53afef28fbe6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 18 Dec 2025 17:01:47 +0000 Subject: [PATCH 3/5] fix: failing unit test --- src/strands/models/neuronvllm.py | 12 +++---- tests/strands/models/test_neuronvllm.py | 46 +++++++++++++++++++------ 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/src/strands/models/neuronvllm.py b/src/strands/models/neuronvllm.py index 756a1af8c..823d2d426 100644 --- a/src/strands/models/neuronvllm.py +++ b/src/strands/models/neuronvllm.py @@ -198,13 +198,13 @@ async def structured_output( system_prompt: Optional[str] = None, **kwargs: Any, ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: - tool_spec = ToolSpec( - name=output_model.__name__, - description=f"Return a {output_model.__name__}", - input_schema=output_model.model_json_schema(), - ) + tool_spec: ToolSpec = { + "name": output_model.__name__, + "description": f"Return a {output_model.__name__}", + "inputSchema": {"json": output_model.model_json_schema()}, + } request = self.format_request(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, stream=False) - request["tool_choice"] = {"type": "function", "function": {"name": tool_spec.name}} + request["tool_choice"] = {"type": "function", "function": {"name": tool_spec["name"]}} client = AsyncOpenAI( api_key=self.config.get("openai_api_key", "EMPTY"), diff --git a/tests/strands/models/test_neuronvllm.py b/tests/strands/models/test_neuronvllm.py index 7cd56978f..4b2499996 100644 --- a/tests/strands/models/test_neuronvllm.py +++ b/tests/strands/models/test_neuronvllm.py @@ -485,18 +485,42 @@ async def test_stream_with_tool_calls( response = model.stream(messages) tru_events = await alist(response) - exp_events = [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockStart": {"start": {}}}, - {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calculator"}}}}, - {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, - {"contentBlockStop": {}}, - {"contentBlockDelta": {"delta": {"text": "I'll calculate that for you"}}}, - {"contentBlockStop": {}}, - {"messageStop": {"stopReason": "tool_use"}}, - ] - assert tru_events == exp_events + # Basic structural checks: first start, last stop + assert tru_events[0] == {"messageStart": {"role": "assistant"}} + assert tru_events[1] == {"contentBlockStart": {"start": {}}} + assert tru_events[-1] == {"messageStop": {"stopReason": "tool_use"}} + + # One toolUse start with expected name/id + tool_starts = [ + e + for e in tru_events + if e.get("contentBlockStart", {}).get("start", {}).get("toolUse") is not None + ] + assert len(tool_starts) == 1 + tool_use = tool_starts[0]["contentBlockStart"]["start"]["toolUse"] + assert tool_use["name"] == "calculator" + assert tool_use["toolUseId"] == "calculator" + + # One toolUse delta with expected input + tool_deltas = [ + e + for e in tru_events + if "contentBlockDelta" in e + and "toolUse" in e["contentBlockDelta"]["delta"] + ] + assert len(tool_deltas) == 1 + assert tool_deltas[0]["contentBlockDelta"]["delta"]["toolUse"]["input"] == '{"expression": "2+2"}' + + # One text delta with the assistant message + text_deltas = [ + e + for e in tru_events + if "contentBlockDelta" in e + and "text" in e["contentBlockDelta"]["delta"] + ] + assert len(text_deltas) == 1 + assert text_deltas[0]["contentBlockDelta"]["delta"]["text"] == "I'll calculate that for you" expected_request = { "messages": [{"role": "user", "content": "Calculate 2+2"}], From 386ec11ca38c6a8993192ce0c14610a9a8366afe Mon Sep 17 00:00:00 2001 From: Marisa Senkfor Date: Thu, 18 Dec 2025 17:34:10 +0000 Subject: [PATCH 4/5] fix: integ test --- tests_integ/models/providers.py | 1 + tests_integ/models/test_model_neuronvllm.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 11011b4ef..080e09c94 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -182,4 +182,5 @@ def __init__(self) -> None: mistral, openai, writer, + neuronvllm, ] diff --git a/tests_integ/models/test_model_neuronvllm.py b/tests_integ/models/test_model_neuronvllm.py index ffaa1dc46..3beca3baa 100644 --- a/tests_integ/models/test_model_neuronvllm.py +++ b/tests_integ/models/test_model_neuronvllm.py @@ -14,9 +14,9 @@ @pytest.fixture def model() -> NeuronVLLMModel: - base_url = os.getenv("NEURON_VLLM_BASE_URL", "http://localhost:8084/v1") + base_url = os.getenv("OPENAI_API_BASE_URL", "http://localhost:8084/v1") model_id = os.getenv("NEURON_VLLM_MODEL_ID", "neuron-llama-3.1-70b-instruct") - api_key = os.getenv("NEURON_VLLM_API_KEY", "EMPTY") + api_key = os.getenv("OPENAI_API_KEY", "EMPTY") return NeuronVLLMModel( { From aa4a0aba06f27fc5501e7bc826cab3d1fae179e7 Mon Sep 17 00:00:00 2001 From: Marisa Senkfor Date: Thu, 18 Dec 2025 18:00:00 +0000 Subject: [PATCH 5/5] fix: linting --- src/strands/models/neuronvllm.py | 63 +++++++++++++++++---- tests/strands/models/test_neuronvllm.py | 32 ++++------- tests_integ/models/test_model_neuronvllm.py | 2 - 3 files changed, 62 insertions(+), 35 deletions(-) diff --git a/src/strands/models/neuronvllm.py b/src/strands/models/neuronvllm.py index 823d2d426..9023d0b1a 100644 --- a/src/strands/models/neuronvllm.py +++ b/src/strands/models/neuronvllm.py @@ -1,3 +1,6 @@ +"""Neuron-vLLM model provider implementation.""" + +import importlib.util import json import logging from typing import Any, AsyncGenerator, Dict, List, Optional, Type, TypeVar, Union, cast @@ -21,6 +24,8 @@ class NeuronVLLMModel(Model): """Neuron-vLLM model provider implementation.""" class NeuronVLLMConfig(TypedDict, total=False): + """Configuration for NeuronVLLMModel.""" + model_id: str max_model_len: Optional[int] max_num_seqs: Optional[int] @@ -38,23 +43,23 @@ class NeuronVLLMConfig(TypedDict, total=False): openai_api_base: Optional[str] def __init__(self, config: NeuronVLLMConfig): + """Initialize the NeuronVLLMModel with the given configuration.""" validate_config_keys(config, self.NeuronVLLMConfig) self.config = config self.logger = logging.getLogger(__name__) if not config.get("model_id"): raise ValueError("model_id is required") self._validate_hardware() - self.logger.info(f"Initializing NeuronVLLMModel with model: {config['model_id']}") + self.logger.info("Initializing NeuronVLLMModel with model: %s", config["model_id"]) def _validate_hardware(self) -> None: - try: - import torch_neuronx # type: ignore + if importlib.util.find_spec("torch_neuronx") is not None: self.logger.info("Neuron hardware validation passed") - except ImportError: + else: self.logger.warning("Neuron libraries not available - running in compatibility mode") @override - def update_config(self, **model_config: Unpack[NeuronVLLMConfig]) -> None: + def update_config(self, **model_config: Unpack[NeuronVLLMConfig]) -> None: # type: ignore[override] validate_config_keys(model_config, self.NeuronVLLMConfig) self.config.update(model_config) @@ -74,14 +79,28 @@ def _format_request_message_contents(self, role: str, content: ContentBlock) -> text = f"[Attached document: {name} ({fmt})]" return [{"role": role, "content": text}] if "toolUse" in content: - return [{"role": role, "tool_calls": [{"function": {"name": content["toolUse"]["toolUseId"], "arguments": content["toolUse"]["input"]}}]}] + return [ + { + "role": role, + "tool_calls": [ + { + "function": { + "name": content["toolUse"]["toolUseId"], + "arguments": content["toolUse"]["input"], + } + } + ], + } + ] if "toolResult" in content: return [ formatted for tool_result in content["toolResult"]["content"] for formatted in self._format_request_message_contents( "tool", - {"text": json.dumps(tool_result["json"])} if "json" in tool_result else cast(ContentBlock, tool_result), + {"text": json.dumps(tool_result["json"])} + if "json" in tool_result + else cast(ContentBlock, tool_result), ) ] raise TypeError(f"Unsupported content type: {next(iter(content))}") @@ -95,7 +114,13 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s for formatted_message in self._format_request_message_contents(message["role"], content) ] - def format_request(self, messages: Messages, tool_specs: Optional[List[ToolSpec]] = None, system_prompt: Optional[str] = None, stream: bool = True) -> dict[str, Any]: + def format_request( + self, + messages: Messages, + tool_specs: Optional[List[ToolSpec]] = None, + system_prompt: Optional[str] = None, + stream: bool = True, + ) -> dict[str, Any]: """Return a dictionary suitable for OpenAI Async client.""" request: dict[str, Any] = { "messages": self._format_request_messages(messages, system_prompt), @@ -115,8 +140,9 @@ def format_request(self, messages: Messages, tool_specs: Optional[List[ToolSpec] } for t in tool_specs ] - if self.config.get("additional_args"): - request.update(self.config["additional_args"]) + additional_args = self.config.get("additional_args") + if additional_args: + request.update(additional_args) return request def format_chunk(self, event: dict[str, Any]) -> StreamEvent: @@ -140,7 +166,18 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: reason: StopReason = "tool_use" if event["data"] == "tool_use" else "end_turn" return {"messageStop": {"stopReason": reason}} case "metadata": - return {"metadata": {"usage": {}, "metrics": {}}} + return { + "metadata": { + "usage": { + "inputTokens": 0, + "outputTokens": 0, + "totalTokens": 0, + }, + "metrics": { + "latencyMs": 0, + }, + }, + } case _: raise RuntimeError(f"Unknown chunk_type: {event['chunk_type']}") @@ -203,7 +240,9 @@ async def structured_output( "description": f"Return a {output_model.__name__}", "inputSchema": {"json": output_model.model_json_schema()}, } - request = self.format_request(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, stream=False) + request = self.format_request( + messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, stream=False + ) request["tool_choice"] = {"type": "function", "function": {"name": tool_spec["name"]}} client = AsyncOpenAI( diff --git a/tests/strands/models/test_neuronvllm.py b/tests/strands/models/test_neuronvllm.py index 4b2499996..33ec321cc 100644 --- a/tests/strands/models/test_neuronvllm.py +++ b/tests/strands/models/test_neuronvllm.py @@ -399,8 +399,14 @@ def test_format_chunk_metadata(model: NeuronVLLMModel) -> None: tru_chunk = model.format_chunk(event) exp_chunk = { "metadata": { - "usage": {}, - "metrics": {}, + "usage": { + "inputTokens": 0, + "outputTokens": 0, + "totalTokens": 0, + }, + "metrics": { + "latencyMs": 0, + }, }, } @@ -492,33 +498,19 @@ async def test_stream_with_tool_calls( assert tru_events[-1] == {"messageStop": {"stopReason": "tool_use"}} # One toolUse start with expected name/id - tool_starts = [ - e - for e in tru_events - if e.get("contentBlockStart", {}).get("start", {}).get("toolUse") is not None - ] + tool_starts = [e for e in tru_events if e.get("contentBlockStart", {}).get("start", {}).get("toolUse") is not None] assert len(tool_starts) == 1 tool_use = tool_starts[0]["contentBlockStart"]["start"]["toolUse"] assert tool_use["name"] == "calculator" assert tool_use["toolUseId"] == "calculator" # One toolUse delta with expected input - tool_deltas = [ - e - for e in tru_events - if "contentBlockDelta" in e - and "toolUse" in e["contentBlockDelta"]["delta"] - ] + tool_deltas = [e for e in tru_events if "contentBlockDelta" in e and "toolUse" in e["contentBlockDelta"]["delta"]] assert len(tool_deltas) == 1 assert tool_deltas[0]["contentBlockDelta"]["delta"]["toolUse"]["input"] == '{"expression": "2+2"}' # One text delta with the assistant message - text_deltas = [ - e - for e in tru_events - if "contentBlockDelta" in e - and "text" in e["contentBlockDelta"]["delta"] - ] + text_deltas = [e for e in tru_events if "contentBlockDelta" in e and "text" in e["contentBlockDelta"]["delta"]] assert len(text_deltas) == 1 assert text_deltas[0]["contentBlockDelta"]["delta"]["text"] == "I'll calculate that for you" @@ -560,5 +552,3 @@ async def test_structured_output( tru_result = events[-1] exp_result = {"output": test_output_model_cls(name="John", age=30)} assert tru_result == exp_result - - diff --git a/tests_integ/models/test_model_neuronvllm.py b/tests_integ/models/test_model_neuronvllm.py index 3beca3baa..c1fd3d7ee 100644 --- a/tests_integ/models/test_model_neuronvllm.py +++ b/tests_integ/models/test_model_neuronvllm.py @@ -97,5 +97,3 @@ async def test_agent_structured_output_async(agent, weather): ) exp_weather = weather assert tru_weather == exp_weather - -