diff --git a/src/strands/models/neuronvllm.py b/src/strands/models/neuronvllm.py
new file mode 100644
index 000000000..9023d0b1a
--- /dev/null
+++ b/src/strands/models/neuronvllm.py
@@ -0,0 +1,260 @@
+"""Neuron-vLLM model provider implementation."""
+
+import importlib.util
+import json
+import logging
+from typing import Any, AsyncGenerator, Dict, List, Optional, Type, TypeVar, Union, cast
+
+from openai import AsyncOpenAI
+from pydantic import BaseModel
+from typing_extensions import TypedDict, Unpack, override
+
+from ..types.content import ContentBlock, Messages, SystemContentBlock
+from ..types.streaming import StopReason, StreamEvent
+from ..types.tools import ToolChoice, ToolSpec
+from ._validation import validate_config_keys, warn_on_tool_choice_not_supported
+from .model import Model
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T", bound=BaseModel)
+
+
+class NeuronVLLMModel(Model):
+    """Neuron-vLLM model provider implementation."""
+
+    class NeuronVLLMConfig(TypedDict, total=False):
+        """Configuration for NeuronVLLMModel."""
+
+        model_id: str
+        max_model_len: Optional[int]
+        max_num_seqs: Optional[int]
+        tensor_parallel_size: Optional[int]
+        block_size: Optional[int]
+        enable_prefix_caching: Optional[bool]
+        neuron_config: Optional[Dict[str, Any]]
+        device: Optional[str]
+        temperature: Optional[float]
+        top_p: Optional[float]
+        max_tokens: Optional[int]
+        stop_sequences: Optional[List[str]]
+        additional_args: Optional[Dict[str, Any]]
+        openai_api_key: Optional[str]
+        openai_api_base: Optional[str]
+
+    def __init__(self, config: NeuronVLLMConfig):
+        """Initialize the NeuronVLLMModel with the given configuration."""
+        validate_config_keys(config, self.NeuronVLLMConfig)
+        self.config = config
+        self.logger = logging.getLogger(__name__)
+        if not config.get("model_id"):
+            raise ValueError("model_id is required")
+        self._validate_hardware()
+        self.logger.info("Initializing NeuronVLLMModel with model: %s", config["model_id"])
+
+    def _validate_hardware(self) -> None:
+        if importlib.util.find_spec("torch_neuronx") is not None:
+            self.logger.info("Neuron hardware validation passed")
+        else:
+            self.logger.warning("Neuron libraries not available - running in compatibility mode")
+
+    @override
+    def update_config(self, **model_config: Unpack[NeuronVLLMConfig]) -> None:  # type: ignore[override]
+        validate_config_keys(model_config, self.NeuronVLLMConfig)
+        self.config.update(model_config)
+
+    @override
+    def get_config(self) -> NeuronVLLMConfig:
+        return self.config
+
+    def _format_request_message_contents(self, role: str, content: ContentBlock) -> list[dict[str, Any]]:
+        if "text" in content:
+            return [{"role": role, "content": content["text"]}]
+        if "image" in content:
+            return [{"role": role, "images": [content["image"]["source"]["bytes"]]}]
+        if "document" in content:
+            doc = content["document"]
+            name = doc.get("name", "document")
+            fmt = doc.get("format", "unknown")
+            text = f"[Attached document: {name} ({fmt})]"
+            return [{"role": role, "content": text}]
+        if "toolUse" in content:
+            return [
+                {
+                    "role": role,
+                    "tool_calls": [
+                        {
+                            "function": {
+                                "name": content["toolUse"]["toolUseId"],
+                                "arguments": content["toolUse"]["input"],
+                            }
+                        }
+                    ],
+                }
+            ]
+        if "toolResult" in content:
+            return [
+                formatted
+                for tool_result in content["toolResult"]["content"]
+                for formatted in self._format_request_message_contents(
+                    "tool",
+                    {"text": json.dumps(tool_result["json"])}
+                    if "json" in tool_result
+                    else cast(ContentBlock, tool_result),
+                )
+            ]
+        raise TypeError(f"Unsupported content type: {next(iter(content))}")
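+
+    # Each Strands content block is flattened into its own OpenAI-style message dict;
+    # an optional system prompt is prepended as a separate "system" message.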
+    def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
+        system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []
+        return system_message + [
+            formatted_message
+            for message in messages
+            for content in message["content"]
+            for formatted_message in self._format_request_message_contents(message["role"], content)
+        ]
+
+    def format_request(
+        self,
+        messages: Messages,
+        tool_specs: Optional[List[ToolSpec]] = None,
+        system_prompt: Optional[str] = None,
+        stream: bool = True,
+    ) -> dict[str, Any]:
+        """Return a request dictionary suitable for the OpenAI-compatible async client."""
+        request: dict[str, Any] = {
+            "messages": self._format_request_messages(messages, system_prompt),
+            "model": self.config["model_id"],
+            "temperature": self.config.get("temperature"),
+            "top_p": self.config.get("top_p"),
+            "max_tokens": self.config.get("max_tokens"),
+            "stop": self.config.get("stop_sequences"),
+            "stream": stream,
+        }
+        if tool_specs:
+            request["functions"] = [
+                {
+                    "name": t["name"],
+                    "description": t["description"],
+                    "parameters": t["inputSchema"]["json"],
+                }
+                for t in tool_specs
+            ]
+        additional_args = self.config.get("additional_args")
+        if additional_args:
+            request.update(additional_args)
+        return request
+
+    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
+        """Convert raw events into StreamEvent."""
+        match event["chunk_type"]:
+            case "message_start":
+                return {"messageStart": {"role": "assistant"}}
+            case "content_start":
+                if event["data_type"] == "text":
+                    return {"contentBlockStart": {"start": {}}}
+                tool_name = event["data"].function.name
+                return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}}
+            case "content_delta":
+                if event["data_type"] == "text":
+                    return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
+                tool_arguments = event["data"].function.arguments
+                return {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(tool_arguments)}}}}
+            case "content_stop":
+                return {"contentBlockStop": {}}
+            case "message_stop":
+                reason: StopReason = "tool_use" if event["data"] == "tool_use" else "end_turn"
+                return {"messageStop": {"stopReason": reason}}
+            case "metadata":
+                return {
+                    "metadata": {
+                        "usage": {
+                            "inputTokens": 0,
+                            "outputTokens": 0,
+                            "totalTokens": 0,
+                        },
+                        "metrics": {
+                            "latencyMs": 0,
+                        },
+                    },
+                }
+            case _:
+                raise RuntimeError(f"Unknown chunk_type: {event['chunk_type']}")
+
+    @override
+    async def stream(
+        self,
+        messages: Messages,
+        tool_specs: Optional[List[ToolSpec]] = None,
+        system_prompt: Optional[str] = None,
+        *,
+        tool_choice: ToolChoice | None = None,
+        system_prompt_content: Optional[List[SystemContentBlock]] = None,
+        **kwargs: Any,
+    ) -> AsyncGenerator[StreamEvent, None]:
+        warn_on_tool_choice_not_supported(tool_choice)
+
+        request = self.format_request(messages, tool_specs, system_prompt, stream=True)
+        client = AsyncOpenAI(
+            api_key=self.config.get("openai_api_key", "EMPTY"),
+            base_url=self.config.get("openai_api_base", "http://localhost:8084/v1"),
+        )
+
+        tool_requested = False
+        finish_reason: str | None = None
+
+        yield self.format_chunk({"chunk_type": "message_start"})
+        yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
+
+        stream_response = await client.chat.completions.create(**request)
+        async for chunk in stream_response:
+            choice = chunk.choices[0]
+            delta = choice.delta
+
+            if delta.content:
+                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": delta.content})
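+
+            # Each streamed tool call is surfaced as a full start/delta/stop triple so
+            # downstream consumers always receive complete toolUse blocks.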
+            if delta.tool_calls:
+                for tool_call in delta.tool_calls:
+                    yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_call})
+                    yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_call})
+                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool", "data": tool_call})
+                    tool_requested = True
+
+            if choice.finish_reason:
+                finish_reason = choice.finish_reason
+
+        yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
+        yield self.format_chunk({"chunk_type": "message_stop", "data": "tool_use" if tool_requested else finish_reason})
+
+    @override
+    async def structured_output(
+        self,
+        output_model: Type[T],
+        prompt: Messages,
+        system_prompt: Optional[str] = None,
+        **kwargs: Any,
+    ) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
+        tool_spec: ToolSpec = {
+            "name": output_model.__name__,
+            "description": f"Return a {output_model.__name__}",
+            "inputSchema": {"json": output_model.model_json_schema()},
+        }
+        request = self.format_request(
+            messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, stream=False
+        )
+        request["tool_choice"] = {"type": "function", "function": {"name": tool_spec["name"]}}
+
+        client = AsyncOpenAI(
+            api_key=self.config.get("openai_api_key", "EMPTY"),
+            base_url=self.config.get("openai_api_base", "http://localhost:8084/v1"),
+        )
+        response = await client.chat.completions.create(**request)
+
+        message = response.choices[0].message
+        if not message.tool_calls:
+            raise ValueError("Expected structured output via tool call")
+
+        tool_call = message.tool_calls[0]
+        output = output_model.model_validate_json(tool_call.function.arguments)
+        yield {"output": output}
diff --git a/tests/strands/models/test_neuronvllm.py b/tests/strands/models/test_neuronvllm.py
new file mode 100644
index 000000000..33ec321cc
--- /dev/null
+++ b/tests/strands/models/test_neuronvllm.py
@@ -0,0 +1,554 @@
+import json
+import unittest.mock
+
+import pydantic
+import pytest
+
+from strands.models.neuronvllm import NeuronVLLMModel
+from strands.types.content import Messages
+
+
+@pytest.fixture
+def neuronvllm_client(monkeypatch: pytest.MonkeyPatch) -> unittest.mock.Mock:
+    from strands import models
+
+    mock_client_cls = unittest.mock.Mock()
+    mock_client = unittest.mock.AsyncMock()
+    mock_client.chat.completions.create = unittest.mock.AsyncMock()
+    mock_client_cls.return_value = mock_client
+
+    monkeypatch.setattr(models.neuronvllm, "AsyncOpenAI", mock_client_cls)
+    return mock_client
+
+
+@pytest.fixture
+def model_id() -> str:
+    return "m1"
+
+
+@pytest.fixture
+def model(model_id: str) -> NeuronVLLMModel:
+    return NeuronVLLMModel({"model_id": model_id})
+
+
+@pytest.fixture
+def messages() -> Messages:
+    return [{"role": "user", "content": [{"text": "test"}]}]
+
+
+@pytest.fixture
+def system_prompt() -> str:
+    return "s1"
+
+
+@pytest.fixture
+def test_output_model_cls() -> type[pydantic.BaseModel]:
+    class TestOutputModel(pydantic.BaseModel):
+        name: str
+        age: int
+
+    return TestOutputModel
+
+
+def test__init__model_configs(model_id: str) -> None:
+    model = NeuronVLLMModel({"model_id": model_id, "max_tokens": 1})
+
+    tru_max_tokens = model.get_config().get("max_tokens")
+    exp_max_tokens = 1
+
+    assert tru_max_tokens == exp_max_tokens
+
+
+def test_update_config(model: NeuronVLLMModel, model_id: str) -> None:
+    model.update_config(model_id=model_id)
+
+    tru_model_id = model.get_config().get("model_id")
+    exp_model_id = model_id
+
+    assert tru_model_id == exp_model_id
+
+
+def test_format_request_default(model: NeuronVLLMModel, messages: Messages, model_id: str) -> None:
+    tru_request = model.format_request(messages)
+    exp_request = {
+        "messages": [{"role": "user", "content": "test"}],
+        "model": model_id,
+        "temperature": None,
+        "top_p": None,
+        "max_tokens": None,
+        "stop": None,
+        "stream": True,
+    }
+
+    assert tru_request == exp_request
+
+
+def test_format_request_with_override(model: NeuronVLLMModel, messages: Messages, model_id: str) -> None:
+    model.update_config(model_id=model_id)
+    tru_request = model.format_request(messages, tool_specs=None)
+    exp_request = {
+        "messages": [{"role": "user", "content": "test"}],
+        "model": model_id,
+        "temperature": None,
+        "top_p": None,
+        "max_tokens": None,
+        "stop": None,
+        "stream": True,
+    }
+
+    assert tru_request == exp_request
+
+
+def test_format_request_with_system_prompt(
+    model: NeuronVLLMModel, messages: Messages, model_id: str, system_prompt: str
+) -> None:
+    tru_request = model.format_request(messages, system_prompt=system_prompt)
+    exp_request = {
+        "messages": [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": "test"},
+        ],
+        "model": model_id,
+        "temperature": None,
+        "top_p": None,
+        "max_tokens": None,
+        "stop": None,
+        "stream": True,
+    }
+
+    assert tru_request == exp_request
+
+
+def test_format_request_with_image(model: NeuronVLLMModel, model_id: str) -> None:
+    messages: Messages = [{"role": "user", "content": [{"image": {"source": {"bytes": "base64encodedimage"}}}]}]
+
+    tru_request = model.format_request(messages)
+    exp_request = {
+        "messages": [{"role": "user", "images": ["base64encodedimage"]}],
+        "model": model_id,
+        "temperature": None,
+        "top_p": None,
+        "max_tokens": None,
+        "stop": None,
+        "stream": True,
+    }
+
+    assert tru_request == exp_request
+
+
+def test_format_request_with_tool_use(model: NeuronVLLMModel, model_id: str) -> None:
+    messages: Messages = [
+        {"role": "assistant", "content": [{"toolUse": {"toolUseId": "calculator", "input": '{"expression": "2+2"}'}}]}
+    ]
+
+    tru_request = model.format_request(messages)
+    exp_request = {
+        "messages": [
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "function": {
+                            "name": "calculator",
+                            "arguments": '{"expression": "2+2"}',
+                        }
+                    }
+                ],
+            }
+        ],
+        "model": model_id,
+        "temperature": None,
+        "top_p": None,
+        "max_tokens": None,
+        "stop": None,
+        "stream": True,
+    }
+
+    assert tru_request == exp_request
+
+
+def test_format_request_with_tool_result(model: NeuronVLLMModel, model_id: str) -> None:
+    messages: Messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "toolResult": {
+                        "toolUseId": "calculator",
+                        "status": "success",
+                        "content": [
+                            {"text": "4"},
+                            {"image": {"source": {"bytes": b"image"}}},
+                            {"json": ["4"]},
+                        ],
+                    },
+                },
+                {
+                    "text": "see results",
+                },
+            ],
+        },
+    ]
+
+    tru_request = model.format_request(messages)
+    exp_request = {
+        "messages": [
+            {
+                "role": "tool",
+                "content": "4",
+            },
+            {
+                "role": "tool",
+                "images": [b"image"],
+            },
+            {
+                "role": "tool",
+                "content": '["4"]',
+            },
+            {
+                "role": "user",
+                "content": "see results",
+            },
+        ],
+        "model": model_id,
+        "temperature": None,
+        "top_p": None,
+        "max_tokens": None,
+        "stop": None,
+        "stream": True,
+    }
+
+    assert tru_request == exp_request
+
+
+def test_format_request_with_unsupported_type(model: NeuronVLLMModel) -> None:
+    messages: Messages = [
+        {
+            "role": "user",
+            "content": [{"unsupported": {}}],
+        },
+    ]
+
+    with pytest.raises(TypeError, match="Unsupported content type: unsupported"):
+        model.format_request(messages)
+
+
+def test_format_request_with_tool_specs(model: NeuronVLLMModel, messages: Messages, model_id: str) -> None:
+    tool_specs = [
+        {
+            "name": "calculator",
+            "description": "Calculate mathematical expressions",
+            "inputSchema": {
+                "json": {"type": "object", "properties": {"expression": {"type": "string"}}, "required": ["expression"]}
+            },
+        }
+    ]
+
+    tru_request = model.format_request(messages, tool_specs)
+    exp_request = {
+        "messages": [{"role": "user", "content": "test"}],
+        "model": model_id,
+        "temperature": None,
+        "top_p": None,
+        "max_tokens": None,
+        "stop": None,
+        "stream": True,
+        "functions": [
+            {
+                "name": "calculator",
+                "description": "Calculate mathematical expressions",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"expression": {"type": "string"}},
+                    "required": ["expression"],
+                },
+            }
+        ],
+    }
+
+    assert tru_request == exp_request
+
+
+def test_format_request_with_inference_config(model: NeuronVLLMModel, messages: Messages, model_id: str) -> None:
+    inference_config = {
+        "max_tokens": 1,
+        "stop_sequences": ["stop"],
+        "temperature": 1.0,
+        "top_p": 1.0,
+    }
+
+    model.update_config(**inference_config)
+    tru_request = model.format_request(messages)
+    exp_request = {
+        "messages": [{"role": "user", "content": "test"}],
+        "model": model_id,
+        "temperature": inference_config["temperature"],
+        "top_p": inference_config["top_p"],
+        "max_tokens": inference_config["max_tokens"],
+        "stop": inference_config["stop_sequences"],
+        "stream": True,
+    }
+
+    assert tru_request == exp_request
+
+
+def test_format_request_with_additional_args(model: NeuronVLLMModel, messages: Messages, model_id: str) -> None:
+    additional_args = {"o1": 1}
+
+    model.update_config(additional_args=additional_args)
+    tru_request = model.format_request(messages)
+    exp_request = {
+        "messages": [{"role": "user", "content": "test"}],
+        "model": model_id,
+        "temperature": None,
+        "top_p": None,
+        "max_tokens": None,
+        "stop": None,
+        "stream": True,
+        "o1": 1,
+    }
+
+    assert tru_request == exp_request
+
+
+def test_format_chunk_message_start(model: NeuronVLLMModel) -> None:
+    event = {"chunk_type": "message_start"}
+
+    tru_chunk = model.format_chunk(event)
+    exp_chunk = {"messageStart": {"role": "assistant"}}
+
+    assert tru_chunk == exp_chunk
+
+
+def test_format_chunk_content_start_text(model: NeuronVLLMModel) -> None:
+    event = {"chunk_type": "content_start", "data_type": "text"}
+
+    tru_chunk = model.format_chunk(event)
+    exp_chunk = {"contentBlockStart": {"start": {}}}
+
+    assert tru_chunk == exp_chunk
+
+
+def test_format_chunk_content_start_tool(model: NeuronVLLMModel) -> None:
+    mock_function = unittest.mock.Mock()
+    mock_function.function.name = "calculator"
+
+    event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_function}
+
+    tru_chunk = model.format_chunk(event)
+    exp_chunk = {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calculator"}}}}
+
+    assert tru_chunk == exp_chunk
+
+
+def test_format_chunk_content_delta_text(model: NeuronVLLMModel) -> None:
+    event = {"chunk_type": "content_delta", "data_type": "text", "data": "Hello"}
+
+    tru_chunk = model.format_chunk(event)
+    exp_chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}}
+
+    assert tru_chunk == exp_chunk
+
+
+def test_format_chunk_content_delta_tool(model: NeuronVLLMModel) -> None:
+    event = {
+        "chunk_type": "content_delta",
+        "data_type": "tool",
+        "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments={"expression": "2+2"})),
"data": unittest.mock.Mock(function=unittest.mock.Mock(arguments={"expression": "2+2"})), + } + + tru_chunk = model.format_chunk(event) + exp_chunk = {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps({"expression": "2+2"})}}}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_content_stop(model: NeuronVLLMModel) -> None: + event = {"chunk_type": "content_stop"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"contentBlockStop": {}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_message_stop_end_turn(model: NeuronVLLMModel) -> None: + event = {"chunk_type": "message_stop", "data": "stop"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"messageStop": {"stopReason": "end_turn"}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_message_stop_tool_use(model: NeuronVLLMModel) -> None: + event = {"chunk_type": "message_stop", "data": "tool_use"} + + tru_chunk = model.format_chunk(event) + exp_chunk = {"messageStop": {"stopReason": "tool_use"}} + + assert tru_chunk == exp_chunk + + +def test_format_chunk_metadata(model: NeuronVLLMModel) -> None: + event = { + "chunk_type": "metadata", + "data": unittest.mock.Mock(), + } + + tru_chunk = model.format_chunk(event) + exp_chunk = { + "metadata": { + "usage": { + "inputTokens": 0, + "outputTokens": 0, + "totalTokens": 0, + }, + "metrics": { + "latencyMs": 0, + }, + }, + } + + assert tru_chunk == exp_chunk + + +def test_format_chunk_other(model: NeuronVLLMModel) -> None: + event = {"chunk_type": "other"} + + with pytest.raises(RuntimeError, match="Unknown chunk_type: other"): + model.format_chunk(event) + + +@pytest.mark.asyncio +async def test_stream( + neuronvllm_client: unittest.mock.Mock, + model: NeuronVLLMModel, + agenerator, + alist, +) -> None: + mock_chunk = unittest.mock.Mock() + mock_choice = unittest.mock.Mock() + mock_delta = unittest.mock.Mock() + mock_delta.content = "Hello" + mock_delta.tool_calls = None + mock_choice.delta = mock_delta + mock_choice.finish_reason = "stop" + mock_chunk.choices = [mock_choice] + + neuronvllm_client.chat.completions.create.return_value = agenerator([mock_chunk]) + + messages: Messages = [{"role": "user", "content": [{"text": "Hello"}]}] + response = model.stream(messages) + + tru_events = await alist(response) + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "Hello"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + + assert tru_events == exp_events + + expected_request = { + "messages": [{"role": "user", "content": "Hello"}], + "model": "m1", + "temperature": None, + "top_p": None, + "max_tokens": None, + "stop": None, + "stream": True, + } + neuronvllm_client.chat.completions.create.assert_awaited_once_with(**expected_request) + + +@pytest.mark.asyncio +async def test_stream_with_tool_calls( + neuronvllm_client: unittest.mock.Mock, + model: NeuronVLLMModel, + agenerator, + alist, +) -> None: + mock_chunk = unittest.mock.Mock() + mock_choice = unittest.mock.Mock() + mock_delta = unittest.mock.Mock() + + mock_tool_call = unittest.mock.Mock() + mock_tool_call.function.name = "calculator" + mock_tool_call.function.arguments = {"expression": "2+2"} + + mock_delta.content = "I'll calculate that for you" + mock_delta.tool_calls = [mock_tool_call] + mock_choice.delta = mock_delta + mock_choice.finish_reason = "stop" + mock_chunk.choices = [mock_choice] + + neuronvllm_client.chat.completions.create.return_value = 
+
+    messages: Messages = [{"role": "user", "content": [{"text": "Calculate 2+2"}]}]
+    response = model.stream(messages)
+
+    tru_events = await alist(response)
+
+    # Basic structural checks: first start, last stop
+    assert tru_events[0] == {"messageStart": {"role": "assistant"}}
+    assert tru_events[1] == {"contentBlockStart": {"start": {}}}
+    assert tru_events[-1] == {"messageStop": {"stopReason": "tool_use"}}
+
+    # One toolUse start with expected name/id
+    tool_starts = [e for e in tru_events if e.get("contentBlockStart", {}).get("start", {}).get("toolUse") is not None]
+    assert len(tool_starts) == 1
+    tool_use = tool_starts[0]["contentBlockStart"]["start"]["toolUse"]
+    assert tool_use["name"] == "calculator"
+    assert tool_use["toolUseId"] == "calculator"
+
+    # One toolUse delta with expected input
+    tool_deltas = [e for e in tru_events if "contentBlockDelta" in e and "toolUse" in e["contentBlockDelta"]["delta"]]
+    assert len(tool_deltas) == 1
+    assert tool_deltas[0]["contentBlockDelta"]["delta"]["toolUse"]["input"] == '{"expression": "2+2"}'
+
+    # One text delta with the assistant message
+    text_deltas = [e for e in tru_events if "contentBlockDelta" in e and "text" in e["contentBlockDelta"]["delta"]]
+    assert len(text_deltas) == 1
+    assert text_deltas[0]["contentBlockDelta"]["delta"]["text"] == "I'll calculate that for you"
+
+    expected_request = {
+        "messages": [{"role": "user", "content": "Calculate 2+2"}],
+        "model": "m1",
+        "temperature": None,
+        "top_p": None,
+        "max_tokens": None,
+        "stop": None,
+        "stream": True,
+    }
+    neuronvllm_client.chat.completions.create.assert_awaited_once_with(**expected_request)
+
+
+@pytest.mark.asyncio
+async def test_structured_output(
+    neuronvllm_client: unittest.mock.Mock,
+    model: NeuronVLLMModel,
+    test_output_model_cls: type[pydantic.BaseModel],
+    alist,
+) -> None:
+    messages: Messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
+
+    mock_response = unittest.mock.Mock()
+    mock_tool_call = unittest.mock.Mock()
+    mock_tool_call.function.arguments = '{"name": "John", "age": 30}'
+    mock_message = unittest.mock.Mock()
+    mock_message.tool_calls = [mock_tool_call]
+    mock_choice = unittest.mock.Mock()
+    mock_choice.message = mock_message
+    mock_response.choices = [mock_choice]
+
+    neuronvllm_client.chat.completions.create = unittest.mock.AsyncMock(return_value=mock_response)
+
+    stream = model.structured_output(test_output_model_cls, messages)
+    events = await alist(stream)
+
+    tru_result = events[-1]
+    exp_result = {"output": test_output_model_cls(name="John", age=30)}
+    assert tru_result == exp_result
diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py
index 75cc58f74..080e09c94 100644
--- a/tests_integ/models/providers.py
+++ b/tests_integ/models/providers.py
@@ -14,6 +14,7 @@
 from strands.models.litellm import LiteLLMModel
 from strands.models.llamaapi import LlamaAPIModel
 from strands.models.mistral import MistralModel
+from strands.models.neuronvllm import NeuronVLLMModel
 from strands.models.ollama import OllamaModel
 from strands.models.openai import OpenAIModel
 from strands.models.writer import WriterModel
@@ -59,6 +60,36 @@ def __init__(self):
         )
 
 
+class NeuronVLLMProviderInfo(ProviderInfo):
+    """Special case Neuron vLLM as it's dependent on the server being available."""
+
+    def __init__(self) -> None:
+        base_url = os.getenv("NEURON_VLLM_BASE_URL", "http://localhost:8084/v1")
+        model_id = os.getenv("NEURON_VLLM_MODEL_ID", "neuron-llama-3.1-70b-instruct")
+
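+        # Register the provider unconditionally; the skipif mark computed below is what
+        # gates the integration tests on the local server actually responding.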
id="neuronvllm", + factory=lambda: NeuronVLLMModel( + { + "model_id": model_id, + "openai_api_base": base_url, + "openai_api_key": os.getenv("NEURON_VLLM_API_KEY", "EMPTY"), + } + ), + ) + + is_server_available = False + try: + is_server_available = requests.get(base_url).ok + except requests.exceptions.ConnectionError: + pass + + self.mark = mark.skipif( + not is_server_available, + reason=f"Local Neuron vLLM endpoint not available at {base_url}", + ) + + anthropic = ProviderInfo( id="anthropic", environment_variable="ANTHROPIC_API_KEY", @@ -138,6 +169,7 @@ def __init__(self): ) ollama = OllamaProviderInfo() +neuronvllm = NeuronVLLMProviderInfo() all_providers = [ @@ -150,4 +182,5 @@ def __init__(self): mistral, openai, writer, + neuronvllm, ] diff --git a/tests_integ/models/test_model_neuronvllm.py b/tests_integ/models/test_model_neuronvllm.py new file mode 100644 index 000000000..c1fd3d7ee --- /dev/null +++ b/tests_integ/models/test_model_neuronvllm.py @@ -0,0 +1,99 @@ +import os + +import pytest +from pydantic import BaseModel + +import strands +from strands import Agent +from strands.models.neuronvllm import NeuronVLLMModel +from tests_integ.models import providers + +# these tests only run if we have the neuron-vLLM server running +pytestmark = providers.neuronvllm.mark + + +@pytest.fixture +def model() -> NeuronVLLMModel: + base_url = os.getenv("OPENAI_API_BASE_URL", "http://localhost:8084/v1") + model_id = os.getenv("NEURON_VLLM_MODEL_ID", "neuron-llama-3.1-70b-instruct") + api_key = os.getenv("OPENAI_API_KEY", "EMPTY") + + return NeuronVLLMModel( + { + "model_id": model_id, + "openai_api_base": base_url, + "openai_api_key": api_key, + } + ) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def agent(model: NeuronVLLMModel, tools): + return Agent(model=model, tools=tools) + + +@pytest.fixture +def weather(): + class Weather(BaseModel): + """Extracts the time and weather from the user's message with the exact strings.""" + + time: str + weather: str + + return Weather(time="12:00", weather="sunny") + + +def test_agent_invoke(agent): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.asyncio +async def test_agent_invoke_async(agent): + result = await agent.invoke_async("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.asyncio +async def test_agent_stream_async(agent): + stream = agent.stream_async("What is the time and weather in New York?") + async for event in stream: + _ = event + + result = event["result"] + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +def test_agent_structured_output(agent, weather): + tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") + exp_weather = weather + assert tru_weather == exp_weather + + +@pytest.mark.asyncio +async def test_agent_structured_output_async(agent, weather): + tru_weather = await agent.structured_output_async( + type(weather), + "The time is 12:00 and the weather is sunny", + ) + exp_weather = weather + assert tru_weather == exp_weather