From 2b533760841771f13cbdcf761d739b7271e50f68 Mon Sep 17 00:00:00 2001 From: notgitika Date: Fri, 3 Oct 2025 17:17:32 +0000 Subject: [PATCH 1/3] feat: add OpenAI Responses API model implementation --- README.md | 1 + src/strands/models/openai_responses.py | 529 +++++++++++++++++ tests/strands/models/test_openai_responses.py | 538 ++++++++++++++++++ tests_integ/models/providers.py | 12 + tests_integ/models/test_model_openai.py | 49 +- 5 files changed, 1118 insertions(+), 11 deletions(-) create mode 100644 src/strands/models/openai_responses.py create mode 100644 tests/strands/models/test_openai_responses.py diff --git a/README.md b/README.md index e7d1b2a7e..b17412e20 100644 --- a/README.md +++ b/README.md @@ -179,6 +179,7 @@ Built-in providers: - [MistralAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/mistral/) - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) + - [OpenAI Responses API](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) - [SageMaker](https://strandsagents.com/latest/user-guide/concepts/model-providers/sagemaker/) - [Writer](https://strandsagents.com/latest/user-guide/concepts/model-providers/writer/) diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py new file mode 100644 index 000000000..fae95833c --- /dev/null +++ b/src/strands/models/openai_responses.py @@ -0,0 +1,529 @@ +"""OpenAI model provider using the Responses API. + +- Docs: https://platform.openai.com/docs/overview +""" + +import base64 +import json +import logging +import mimetypes +from typing import Any, AsyncGenerator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast + +import openai +from pydantic import BaseModel +from typing_extensions import Unpack, override + +from ..types.content import ContentBlock, Messages +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException +from ..types.streaming import StreamEvent +from ..types.tools import ToolResult, ToolSpec, ToolUse +from ._validation import validate_config_keys +from .model import Model + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class Client(Protocol): + """Protocol defining the OpenAI Responses API interface for the underlying provider client.""" + + @property + # pragma: no cover + def responses(self) -> Any: + """Responses interface.""" + ... + + +class OpenAIResponsesModel(Model): + """OpenAI Responses API model provider implementation.""" + + client: Client + client_args: dict[str, Any] + + class OpenAIResponsesConfig(TypedDict, total=False): + """Configuration options for OpenAI Responses API models. + + Attributes: + model_id: Model ID (e.g., "gpt-4o"). + For a complete list of supported models, see https://platform.openai.com/docs/models. + params: Model parameters (e.g., max_output_tokens, temperature, etc.). + For a complete list of supported parameters, see + https://platform.openai.com/docs/api-reference/responses/create. + """ + + model_id: str + params: Optional[dict[str, Any]] + + def __init__( + self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIResponsesConfig] + ) -> None: + """Initialize provider instance. + + Args: + client_args: Arguments for the OpenAI client. + For a complete list of supported arguments, see https://pypi.org/project/openai/. + **model_config: Configuration options for the OpenAI Responses API model. + """ + validate_config_keys(model_config, self.OpenAIResponsesConfig) + self.config = dict(model_config) + self.client_args = client_args or {} + + logger.debug("config=<%s> | initializing", self.config) + + @override + def update_config(self, **model_config: Unpack[OpenAIResponsesConfig]) -> None: # type: ignore[override] + """Update the OpenAI Responses API model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + validate_config_keys(model_config, self.OpenAIResponsesConfig) + self.config.update(model_config) + + @override + def get_config(self) -> OpenAIResponsesConfig: + """Get the OpenAI Responses API model configuration. + + Returns: + The OpenAI Responses API model configuration. + """ + return cast(OpenAIResponsesModel.OpenAIResponsesConfig, self.config) + + @classmethod + def _format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + """Format an OpenAI compatible content block. + + Args: + content: Message content. + + Returns: + OpenAI compatible content block. + + Raises: + TypeError: If the content block type cannot be converted to an OpenAI-compatible format. + """ + if "document" in content: + # only PDF type supported + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") + file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") + return { + "type": "input_file", + "file_url": f"data:{mime_type};base64,{file_data}", + } + + if "image" in content: + mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") + image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") + + return { + "type": "input_image", + "image_url": f"data:{mime_type};base64,{image_data}", + } + + if "text" in content: + return {"type": "input_text", "text": content["text"]} + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + @classmethod + def _format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: + """Format an OpenAI compatible tool call. + + Args: + tool_use: Tool use requested by the model. + + Returns: + OpenAI compatible tool call. + """ + return { + "type": "function_call", + "call_id": tool_use["toolUseId"], + "name": tool_use["name"], + "arguments": json.dumps(tool_use["input"]), + } + + @classmethod + def _format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + """Format an OpenAI compatible tool message. + + Args: + tool_result: Tool result collected from a tool execution. + + Returns: + OpenAI compatible tool message. + """ + output_parts = [] + + for content in tool_result["content"]: + if "json" in content: + output_parts.append(json.dumps(content["json"])) + elif "text" in content: + output_parts.append(content["text"]) + + return { + "type": "function_call_output", + "call_id": tool_result["toolUseId"], + "output": "\n".join(output_parts) if output_parts else "", + } + + @classmethod + def _format_request_messages(cls, messages: Messages) -> list[dict[str, Any]]: + """Format an OpenAI compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + + Returns: + An OpenAI compatible messages array. + """ + formatted_messages: list[dict[str, Any]] = [] + + for message in messages: + role = message["role"] + if role == "system": + continue # type: ignore[unreachable] + + contents = message["content"] + + formatted_contents = [ + cls._format_request_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + + formatted_tool_calls = [ + cls._format_request_message_tool_call(content["toolUse"]) + for content in contents + if "toolUse" in content + ] + + formatted_tool_messages = [ + cls._format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + if formatted_contents: + formatted_messages.append( + { + "role": role, # "user" | "assistant" + "content": formatted_contents, + } + ) + + formatted_messages.extend(formatted_tool_calls) + formatted_messages.extend(formatted_tool_messages) + + return [ + message + for message in formatted_messages + if message.get("content") or message.get("type") in ["function_call", "function_call_output"] + ] + + def _format_request( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + ) -> dict[str, Any]: + """Format an OpenAI Responses API compatible response streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + An OpenAI Responses API compatible response streaming request. + + Raises: + TypeError: If a message contains a content block type that cannot be converted to an OpenAI-compatible + format. + """ + input_items = self._format_request_messages(messages) + request = { + "model": self.config["model_id"], + "input": input_items, + "stream": True, + **cast(dict[str, Any], self.config.get("params", {})), + } + + if system_prompt: + request["instructions"] = system_prompt + + # Add tools if provided + if tool_specs: + request["tools"] = [ + { + "type": "function", + "name": tool_spec["name"], + "description": tool_spec.get("description", ""), + "parameters": tool_spec["inputSchema"]["json"], + } + for tool_spec in tool_specs + ] + + return request + + def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format an OpenAI response event into a standardized message chunk. + + Args: + event: A response event from the OpenAI compatible model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + This error should never be encountered as chunk_type is controlled in the stream method. + """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "tool": + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": event["data"].function.name, + "toolUseId": event["data"].id, + } + } + } + } + + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event["data_type"] == "tool": + return { + "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}} + } + + if event["data_type"] == "reasoning_content": + return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}} + + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event["data"]: + case "tool_calls": + return {"messageStop": {"stopReason": "tool_use"}} + case "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + return { + "metadata": { + "usage": { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + }, + "metrics": { + "latencyMs": 0, # TODO + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type") + + @override + async def stream( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Stream conversation with the OpenAI Responses API model. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Formatted message chunks from the model. + + Raises: + ContextWindowOverflowException: If the input exceeds the model's context window. + ModelThrottledException: If the request is throttled by OpenAI (rate limits). + """ + logger.debug("formatting request for OpenAI Responses API") + request = self._format_request(messages, tool_specs, system_prompt) + logger.debug("formatted request=<%s>", request) + + logger.debug("invoking OpenAI Responses API model") + + async with openai.AsyncOpenAI(**self.client_args) as client: + try: + response = await client.responses.create(**request) + except openai.APIError as e: + if hasattr(e, "code") and e.code == "context_length_exceeded": + logger.warning("OpenAI Responses API threw context window overflow error") + raise ContextWindowOverflowException(str(e)) from e + elif hasattr(e, "code") and e.code == "rate_limit_exceeded": + logger.warning("OpenAI Responses API threw rate limit error") + raise ModelThrottledException(str(e)) from e + else: + raise + + logger.debug("got response from OpenAI Responses API model") + + yield self._format_chunk({"chunk_type": "message_start"}) + + tool_calls: dict[str, dict[str, Any]] = {} + final_usage = None + has_text_content = False + + try: + async for event in response: + if hasattr(event, "type"): + if event.type == "response.output_text.delta": + # Text content streaming + if not has_text_content: + yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) + has_text_content = True + if hasattr(event, "delta") and isinstance(event.delta, str): + has_text_content = True + yield self._format_chunk( + {"chunk_type": "content_delta", "data_type": "text", "data": event.delta} + ) + + elif event.type == "response.output_item.added": + # Tool call started + if ( + hasattr(event, "item") + and hasattr(event.item, "type") + and event.item.type == "function_call" + ): + call_id = getattr(event.item, "call_id", "unknown") + tool_calls[call_id] = { + "name": getattr(event.item, "name", ""), + "arguments": "", + "call_id": call_id, + "item_id": getattr(event.item, "id", ""), + } + + elif event.type == "response.function_call_arguments.delta": + # Tool arguments streaming - match by item_id + if hasattr(event, "delta") and hasattr(event, "item_id"): + for _call_id, call_info in tool_calls.items(): + if call_info["item_id"] == event.item_id: + call_info["arguments"] += event.delta + break + + elif event.type == "response.completed": + # Response complete + if hasattr(event, "response") and hasattr(event.response, "usage"): + final_usage = event.response.usage + break + except openai.APIError as e: + if hasattr(e, "code") and e.code == "context_length_exceeded": + logger.warning("OpenAI Responses API threw context window overflow error") + raise ContextWindowOverflowException(str(e)) from e + elif hasattr(e, "code") and e.code == "rate_limit_exceeded": + logger.warning("OpenAI Responses API threw rate limit error") + raise ModelThrottledException(str(e)) from e + else: + raise + + # Close text content if we had any + if has_text_content: + yield self._format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + + # Yield tool calls if any + for call_info in tool_calls.values(): + mock_tool_call = type( + "MockToolCall", + (), + { + "function": type( + "MockFunction", (), {"name": call_info["name"], "arguments": call_info["arguments"]} + )(), + "id": call_info["call_id"], + }, + )() + + yield self._format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call}) + yield self._format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call}) + yield self._format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + + finish_reason = "tool_calls" if tool_calls else "stop" + yield self._format_chunk({"chunk_type": "message_stop", "data": finish_reason}) + + if final_usage: + usage_data = type( + "Usage", + (), + { + "prompt_tokens": getattr(final_usage, "input_tokens", 0), + "completion_tokens": getattr(final_usage, "output_tokens", 0), + "total_tokens": getattr(final_usage, "total_tokens", 0), + }, + )() + yield self._format_chunk({"chunk_type": "metadata", "data": usage_data}) + + logger.debug("finished streaming response from OpenAI Responses API model") + + @override + async def structured_output( + self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + """Get structured output from the OpenAI Responses API model. + + Args: + output_model: The output model to use for the agent. + prompt: The prompt messages to use for the agent. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Model events with the last being the structured output. + + Raises: + ContextWindowOverflowException: If the input exceeds the model's context window. + ModelThrottledException: If the request is throttled by OpenAI (rate limits). + """ + async with openai.AsyncOpenAI(**self.client_args) as client: + try: + response = await client.responses.parse( + model=self.get_config()["model_id"], + input=self._format_request(prompt, system_prompt=system_prompt)["input"], + text_format=output_model, + ) + except openai.BadRequestError as e: + if hasattr(e, "code") and e.code == "context_length_exceeded": + logger.warning("OpenAI Responses API threw context window overflow error") + raise ContextWindowOverflowException(str(e)) from e + raise + except openai.RateLimitError as e: + logger.warning("OpenAI Responses API threw rate limit error") + raise ModelThrottledException(str(e)) from e + except openai.APIError as e: + # Handle streaming API errors that come as APIError + error_message = str(e).lower() + if "context window" in error_message or "exceeds the context" in error_message: + logger.warning("OpenAI Responses API threw context window overflow error") + raise ContextWindowOverflowException(str(e)) from e + elif "rate limit" in error_message or "tokens per min" in error_message: + logger.warning("OpenAI Responses API threw rate limit error") + raise ModelThrottledException(str(e)) from e + raise + + if response.output_parsed: + yield {"output": response.output_parsed} + else: + raise ValueError("No valid parsed output found in the OpenAI Responses API response.") diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py new file mode 100644 index 000000000..eb78217d8 --- /dev/null +++ b/tests/strands/models/test_openai_responses.py @@ -0,0 +1,538 @@ +import unittest.mock + +import openai +import pydantic +import pytest + +import strands +from strands.models.openai_responses import OpenAIResponsesModel +from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException + + +@pytest.fixture +def openai_client(): + with unittest.mock.patch.object(strands.models.openai_responses.openai, "AsyncOpenAI") as mock_client_cls: + mock_client = unittest.mock.AsyncMock() + mock_client_cls.return_value.__aenter__.return_value = mock_client + yield mock_client + + +@pytest.fixture +def model_id(): + return "gpt-4o" + + +@pytest.fixture +def model(openai_client, model_id): + _ = openai_client + return OpenAIResponsesModel(model_id=model_id, params={"max_output_tokens": 100}) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def tool_specs(): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + }, + ] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +@pytest.fixture +def test_output_model_cls(): + class TestOutputModel(pydantic.BaseModel): + name: str + age: int + + return TestOutputModel + + +def test__init__(model_id): + model = OpenAIResponsesModel(model_id=model_id, params={"max_output_tokens": 100}) + + tru_config = model.get_config() + exp_config = {"model_id": "gpt-4o", "params": {"max_output_tokens": 100}} + + assert tru_config == exp_config + + +def test_update_config(model, model_id): + model.update_config(model_id=model_id) + + tru_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert tru_model_id == exp_model_id + + +@pytest.mark.parametrize( + "content, exp_result", + [ + # Document + ( + { + "document": { + "format": "pdf", + "name": "test doc", + "source": {"bytes": b"document"}, + }, + }, + { + "type": "input_file", + "file_url": "data:application/pdf;base64,ZG9jdW1lbnQ=", + }, + ), + # Image + ( + { + "image": { + "format": "jpg", + "source": {"bytes": b"image"}, + }, + }, + { + "type": "input_image", + "image_url": "", + }, + ), + # Text + ( + {"text": "hello"}, + {"type": "input_text", "text": "hello"}, + ), + ], +) +def test_format_request_message_content(content, exp_result): + tru_result = OpenAIResponsesModel._format_request_message_content(content) + assert tru_result == exp_result + + +def test_format_request_message_content_unsupported_type(): + content = {"unsupported": {}} + + with pytest.raises(TypeError, match="content_type= | unsupported type"): + OpenAIResponsesModel._format_request_message_content(content) + + +def test_format_request_message_tool_call(): + tool_use = { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + } + + tru_result = OpenAIResponsesModel._format_request_message_tool_call(tool_use) + exp_result = { + "type": "function_call", + "call_id": "c1", + "name": "calculator", + "arguments": '{"expression": "2+2"}', + } + assert tru_result == exp_result + + +def test_format_request_tool_message(): + tool_result = { + "content": [{"text": "4"}, {"json": ["4"]}], + "status": "success", + "toolUseId": "c1", + } + + tru_result = OpenAIResponsesModel._format_request_tool_message(tool_result) + exp_result = { + "type": "function_call_output", + "call_id": "c1", + "output": '4\n["4"]', + } + assert tru_result == exp_result + + +def test_format_request_messages(system_prompt): + messages = [ + { + "content": [], + "role": "user", + }, + { + "content": [{"text": "hello"}], + "role": "user", + }, + { + "content": [ + {"text": "call tool"}, + { + "toolUse": { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + }, + }, + ], + "role": "assistant", + }, + { + "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"text": "4"}]}}], + "role": "user", + }, + ] + + tru_result = OpenAIResponsesModel._format_request_messages(messages) + exp_result = [ + { + "role": "user", + "content": [{"type": "input_text", "text": "hello"}], + }, + { + "role": "assistant", + "content": [{"type": "input_text", "text": "call tool"}], + }, + { + "type": "function_call", + "call_id": "c1", + "name": "calculator", + "arguments": '{"expression": "2+2"}', + }, + { + "type": "function_call_output", + "call_id": "c1", + "output": "4", + }, + ] + assert tru_result == exp_result + + +def test_format_request(model, messages, tool_specs, system_prompt): + tru_request = model._format_request(messages, tool_specs, system_prompt) + exp_request = { + "model": "gpt-4o", + "input": [ + { + "role": "user", + "content": [{"type": "input_text", "text": "test"}], + } + ], + "stream": True, + "instructions": system_prompt, + "tools": [ + { + "type": "function", + "name": "test_tool", + "description": "A test tool", + "parameters": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + ], + "max_output_tokens": 100, + } + assert tru_request == exp_request + + +@pytest.mark.parametrize( + ("event", "exp_chunk"), + [ + # Message start + ( + {"chunk_type": "message_start"}, + {"messageStart": {"role": "assistant"}}, + ), + # Content Start - Tool Use + ( + { + "chunk_type": "content_start", + "data_type": "tool", + "data": unittest.mock.Mock(**{"function.name": "calculator", "id": "c1"}), + }, + {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}}, + ), + # Content Start - Text + ( + {"chunk_type": "content_start", "data_type": "text"}, + {"contentBlockStart": {"start": {}}}, + ), + # Content Delta - Tool Use + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments='{"expression": "2+2"}')), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, + ), + # Content Delta - Tool Use - None + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments=None)), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}}}, + ), + # Content Delta - Reasoning Text + ( + {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "I'm thinking"}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "I'm thinking"}}}}, + ), + # Content Delta - Text + ( + {"chunk_type": "content_delta", "data_type": "text", "data": "hello"}, + {"contentBlockDelta": {"delta": {"text": "hello"}}}, + ), + # Content Stop + ( + {"chunk_type": "content_stop"}, + {"contentBlockStop": {}}, + ), + # Message Stop - Tool Use + ( + {"chunk_type": "message_stop", "data": "tool_calls"}, + {"messageStop": {"stopReason": "tool_use"}}, + ), + # Message Stop - Max Tokens + ( + {"chunk_type": "message_stop", "data": "length"}, + {"messageStop": {"stopReason": "max_tokens"}}, + ), + # Message Stop - End Turn + ( + {"chunk_type": "message_stop", "data": "stop"}, + {"messageStop": {"stopReason": "end_turn"}}, + ), + # Metadata + ( + { + "chunk_type": "metadata", + "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150), + }, + { + "metadata": { + "usage": { + "inputTokens": 100, + "outputTokens": 50, + "totalTokens": 150, + }, + "metrics": { + "latencyMs": 0, + }, + }, + }, + ), + ], +) +def test_format_chunk(event, exp_chunk, model): + tru_chunk = model._format_chunk(event) + assert tru_chunk == exp_chunk + + +def test_format_chunk_unknown_type(model): + event = {"chunk_type": "unknown"} + + with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): + model._format_chunk(event) + + +@pytest.mark.asyncio +async def test_stream(openai_client, model_id, model, agenerator, alist): + # Mock response events + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hello") + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_text_event, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages) + tru_events = await alist(response) + + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "Hello"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + "metrics": {"latencyMs": 0}, + } + }, + ] + + assert len(tru_events) == len(exp_events) + expected_request = { + "model": model_id, + "input": [{"role": "user", "content": [{"type": "input_text", "text": "test"}]}], + "stream": True, + "max_output_tokens": 100, + } + openai_client.responses.create.assert_called_once_with(**expected_request) + + +@pytest.mark.asyncio +async def test_stream_with_tool_calls(openai_client, model, agenerator, alist): + # Mock tool call events + mock_tool_event = unittest.mock.Mock( + type="response.output_item.added", + item=unittest.mock.Mock(type="function_call", call_id="call_123", name="calculator", id="item_456"), + ) + mock_args_event = unittest.mock.Mock( + type="response.function_call_arguments.delta", delta='{"expression": "2+2"}', item_id="item_456" + ) + mock_complete_event = unittest.mock.Mock( + type="response.completed", + response=unittest.mock.Mock(usage=unittest.mock.Mock(input_tokens=10, output_tokens=5, total_tokens=15)), + ) + + openai_client.responses.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_tool_event, mock_args_event, mock_complete_event]) + ) + + messages = [{"role": "user", "content": [{"text": "calculate 2+2"}]}] + response = model.stream(messages) + tru_events = await alist(response) + + # Should include tool call events + assert any("toolUse" in str(event) for event in tru_events) + assert {"messageStop": {"stopReason": "tool_use"}} in tru_events + + +@pytest.mark.asyncio +async def test_structured_output(openai_client, model, test_output_model_cls, alist): + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + mock_parsed_instance = test_output_model_cls(name="John", age=30) + mock_response = unittest.mock.Mock(output_parsed=mock_parsed_instance) + + openai_client.responses.parse = unittest.mock.AsyncMock(return_value=mock_response) + + stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) + + tru_result = events[-1] + exp_result = {"output": test_output_model_cls(name="John", age=30)} + assert tru_result == exp_result + + +@pytest.mark.asyncio +async def test_stream_context_overflow_exception(openai_client, model, messages): + """Test that OpenAI context overflow errors are properly converted to ContextWindowOverflowException.""" + mock_error = openai.BadRequestError( + message="This model's maximum context length is 4096 tokens.", + response=unittest.mock.MagicMock(), + body={"error": {"code": "context_length_exceeded"}}, + ) + mock_error.code = "context_length_exceeded" + + openai_client.responses.create.side_effect = mock_error + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "maximum context length" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_stream_rate_limit_as_throttle(openai_client, model, messages): + """Test that rate limit errors are converted to ModelThrottledException.""" + mock_error = openai.RateLimitError( + message="Rate limit exceeded", + response=unittest.mock.MagicMock(), + body={"error": {"code": "rate_limit_exceeded"}}, + ) + mock_error.code = "rate_limit_exceeded" + + openai_client.responses.create.side_effect = mock_error + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "Rate limit exceeded" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_structured_output_context_overflow_exception(openai_client, model, messages, test_output_model_cls): + """Test that structured output handles context overflow properly.""" + mock_error = openai.BadRequestError( + message="This model's maximum context length is 4096 tokens.", + response=unittest.mock.MagicMock(), + body={"error": {"code": "context_length_exceeded"}}, + ) + mock_error.code = "context_length_exceeded" + + openai_client.responses.parse.side_effect = mock_error + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.structured_output(test_output_model_cls, messages): + pass + + assert "maximum context length" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_structured_output_rate_limit_as_throttle(openai_client, model, messages, test_output_model_cls): + """Test that structured output handles rate limit errors properly.""" + mock_error = openai.RateLimitError( + message="Rate limit exceeded", + response=unittest.mock.MagicMock(), + body={"error": {"code": "rate_limit_exceeded"}}, + ) + mock_error.code = "rate_limit_exceeded" + + openai_client.responses.parse.side_effect = mock_error + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.structured_output(test_output_model_cls, messages): + pass + + assert "Rate limit exceeded" in str(exc_info.value) + assert exc_info.value.__cause__ == mock_error + + +def test_config_validation_warns_on_unknown_keys(openai_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + OpenAIResponsesModel({"api_key": "test"}, model_id="test-model", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 57614b97f..7082f4d6e 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -16,6 +16,7 @@ from strands.models.mistral import MistralModel from strands.models.ollama import OllamaModel from strands.models.openai import OpenAIModel +from strands.models.openai_responses import OpenAIResponsesModel from strands.models.writer import WriterModel @@ -118,6 +119,16 @@ def __init__(self): }, ), ) +openai_responses = ProviderInfo( + id="openai_responses", + environment_variable="OPENAI_API_KEY", + factory=lambda: OpenAIResponsesModel( + model_id="gpt-4o", + client_args={ + "api_key": os.getenv("OPENAI_API_KEY"), + }, + ), +) writer = ProviderInfo( id="writer", environment_variable="WRITER_API_KEY", @@ -149,5 +160,6 @@ def __init__(self): litellm, mistral, openai, + openai_responses, writer, ] diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index 503fca898..1b3c35cc6 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -7,6 +7,7 @@ import strands from strands import Agent, tool from strands.models.openai import OpenAIModel +from strands.models.openai_responses import OpenAIResponsesModel from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException from tests_integ.models import providers @@ -14,10 +15,16 @@ pytestmark = providers.openai.mark -@pytest.fixture -def model(): - return OpenAIModel( - model_id="gpt-4o", +@pytest.fixture( + params=[ + ("openai", OpenAIModel, "gpt-4o"), + ("openai_responses", OpenAIResponsesModel, "gpt-4o"), + ] +) +def model(request): + model_name, model_class, model_id = request.param + return model_class( + model_id=model_id, client_args={ "api_key": os.getenv("OPENAI_API_KEY"), }, @@ -73,7 +80,7 @@ def test_image_path(request): return request.config.rootpath / "tests_integ" / "test_image.png" -def test_agent_invoke(agent): +def test_agent_invoke(agent, model): result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() @@ -81,7 +88,7 @@ def test_agent_invoke(agent): @pytest.mark.asyncio -async def test_agent_invoke_async(agent): +async def test_agent_invoke_async(agent, model): result = await agent.invoke_async("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() @@ -89,7 +96,7 @@ async def test_agent_invoke_async(agent): @pytest.mark.asyncio -async def test_agent_stream_async(agent): +async def test_agent_stream_async(agent, model): stream = agent.stream_async("What is the time and weather in New York?") async for event in stream: _ = event @@ -170,15 +177,22 @@ def tool_with_image_return(): agent("Run the the tool and analyze the image") -def test_context_window_overflow_integration(): +@pytest.mark.parametrize( + "model_class,model_id", + [ + (OpenAIModel, "gpt-4o-mini-2024-07-18"), + (OpenAIResponsesModel, "gpt-4o-mini-2024-07-18"), + ], +) +def test_context_window_overflow_integration(model_class, model_id): """Integration test for context window overflow with OpenAI. This test verifies that when a request exceeds the model's context window, the OpenAI model properly raises a ContextWindowOverflowException. """ # Use gpt-4o-mini which has a smaller context window to make this test more reliable - mini_model = OpenAIModel( - model_id="gpt-4o-mini-2024-07-18", + mini_model = model_class( + model_id=model_id, client_args={ "api_key": os.getenv("OPENAI_API_KEY"), }, @@ -198,7 +212,14 @@ def test_context_window_overflow_integration(): agent(long_text) -def test_rate_limit_throttling_integration_no_retries(model): +@pytest.mark.parametrize( + "model_class,model_id", + [ + (OpenAIModel, "gpt-4o"), + (OpenAIResponsesModel, "gpt-4o"), + ], +) +def test_rate_limit_throttling_integration_no_retries(model_class, model_id): """Integration test for rate limit handling with retries disabled. This test verifies that when a request exceeds OpenAI's rate limits, @@ -207,6 +228,12 @@ def test_rate_limit_throttling_integration_no_retries(model): """ # Patch the event loop constants to disable retries for this test with unittest.mock.patch("strands.event_loop.event_loop.MAX_ATTEMPTS", 1): + model = model_class( + model_id=model_id, + client_args={ + "api_key": os.getenv("OPENAI_API_KEY"), + }, + ) agent = Agent(model=model) # Create a message that's very long to trigger token-per-minute rate limits From 8eea5d9beea822c2f475fdab720a0bb34ec52cf4 Mon Sep 17 00:00:00 2001 From: notgitika Date: Wed, 21 Jan 2026 06:14:19 +0000 Subject: [PATCH 2/3] fix: address comments and refactor --- src/strands/models/openai_responses.py | 644 ++++++++++-------- tests/strands/models/test_openai_responses.py | 122 ++++ 2 files changed, 486 insertions(+), 280 deletions(-) diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index fae95833c..e2e4945bd 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -1,13 +1,26 @@ """OpenAI model provider using the Responses API. -- Docs: https://platform.openai.com/docs/overview +The Responses API is OpenAI's newer API that differs from the Chat Completions API in several key ways: + +1. The Responses API can maintain conversation state server-side through "previous_response_id", + while Chat Completions is stateless and requires sending full conversation history each time. + Note: This implementation currently only implements the stateless approach. + +2. Responses API uses "input" (list of items) instead of "messages", and system + prompts are passed as "instructions" rather than a system role message. + +3. Responses API supports built-in tools (web search, code interpreter, file search) + Note: These are not yet implemented in this provider. + +- Docs: https://platform.openai.com/docs/api-reference/responses """ import base64 import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, Protocol, TypedDict, TypeVar, cast import openai from pydantic import BaseModel @@ -16,7 +29,7 @@ from ..types.content import ContentBlock, Messages from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse +from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from ._validation import validate_config_keys from .model import Model @@ -24,6 +37,32 @@ T = TypeVar("T", bound=BaseModel) +# Maximum file size for media content in tool results (20MB) +MAX_MEDIA_SIZE_BYTES = 20 * 1024 * 1024 + + +def _encode_media_to_data_url(data: bytes, format_ext: str, media_type: str = "image") -> str: + """Encode media bytes to a base64 data URL with size validation. + + Args: + data: Raw bytes of the media content. + format_ext: File format extension (e.g., "png", "pdf"). + media_type: Type of media for error messages ("image" or "document"). + + Returns: + Base64-encoded data URL string. + + Raises: + ValueError: If the media size exceeds the maximum allowed size (20MB). + """ + if len(data) > MAX_MEDIA_SIZE_BYTES: + raise ValueError( + f"{media_type.capitalize()} size {len(data)} bytes exceeds maximum of {MAX_MEDIA_SIZE_BYTES} bytes (20MB)" + ) + mime_type = mimetypes.types_map.get(f".{format_ext}", "application/octet-stream") + encoded_data = base64.b64encode(data).decode("utf-8") + return f"data:{mime_type};base64,{encoded_data}" + class Client(Protocol): """Protocol defining the OpenAI Responses API interface for the underlying provider client.""" @@ -53,10 +92,10 @@ class OpenAIResponsesConfig(TypedDict, total=False): """ model_id: str - params: Optional[dict[str, Any]] + params: dict[str, Any] | None def __init__( - self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIResponsesConfig] + self, client_args: dict[str, Any] | None = None, **model_config: Unpack[OpenAIResponsesConfig] ) -> None: """Initialize provider instance. @@ -90,260 +129,14 @@ def get_config(self) -> OpenAIResponsesConfig: """ return cast(OpenAIResponsesModel.OpenAIResponsesConfig, self.config) - @classmethod - def _format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: - """Format an OpenAI compatible content block. - - Args: - content: Message content. - - Returns: - OpenAI compatible content block. - - Raises: - TypeError: If the content block type cannot be converted to an OpenAI-compatible format. - """ - if "document" in content: - # only PDF type supported - mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") - file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") - return { - "type": "input_file", - "file_url": f"data:{mime_type};base64,{file_data}", - } - - if "image" in content: - mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_data = base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8") - - return { - "type": "input_image", - "image_url": f"data:{mime_type};base64,{image_data}", - } - - if "text" in content: - return {"type": "input_text", "text": content["text"]} - - raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") - - @classmethod - def _format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: - """Format an OpenAI compatible tool call. - - Args: - tool_use: Tool use requested by the model. - - Returns: - OpenAI compatible tool call. - """ - return { - "type": "function_call", - "call_id": tool_use["toolUseId"], - "name": tool_use["name"], - "arguments": json.dumps(tool_use["input"]), - } - - @classmethod - def _format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: - """Format an OpenAI compatible tool message. - - Args: - tool_result: Tool result collected from a tool execution. - - Returns: - OpenAI compatible tool message. - """ - output_parts = [] - - for content in tool_result["content"]: - if "json" in content: - output_parts.append(json.dumps(content["json"])) - elif "text" in content: - output_parts.append(content["text"]) - - return { - "type": "function_call_output", - "call_id": tool_result["toolUseId"], - "output": "\n".join(output_parts) if output_parts else "", - } - - @classmethod - def _format_request_messages(cls, messages: Messages) -> list[dict[str, Any]]: - """Format an OpenAI compatible messages array. - - Args: - messages: List of message objects to be processed by the model. - - Returns: - An OpenAI compatible messages array. - """ - formatted_messages: list[dict[str, Any]] = [] - - for message in messages: - role = message["role"] - if role == "system": - continue # type: ignore[unreachable] - - contents = message["content"] - - formatted_contents = [ - cls._format_request_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] - - formatted_tool_calls = [ - cls._format_request_message_tool_call(content["toolUse"]) - for content in contents - if "toolUse" in content - ] - - formatted_tool_messages = [ - cls._format_request_tool_message(content["toolResult"]) - for content in contents - if "toolResult" in content - ] - - if formatted_contents: - formatted_messages.append( - { - "role": role, # "user" | "assistant" - "content": formatted_contents, - } - ) - - formatted_messages.extend(formatted_tool_calls) - formatted_messages.extend(formatted_tool_messages) - - return [ - message - for message in formatted_messages - if message.get("content") or message.get("type") in ["function_call", "function_call_output"] - ] - - def _format_request( - self, - messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, - ) -> dict[str, Any]: - """Format an OpenAI Responses API compatible response streaming request. - - Args: - messages: List of message objects to be processed by the model. - tool_specs: List of tool specifications to make available to the model. - system_prompt: System prompt to provide context to the model. - - Returns: - An OpenAI Responses API compatible response streaming request. - - Raises: - TypeError: If a message contains a content block type that cannot be converted to an OpenAI-compatible - format. - """ - input_items = self._format_request_messages(messages) - request = { - "model": self.config["model_id"], - "input": input_items, - "stream": True, - **cast(dict[str, Any], self.config.get("params", {})), - } - - if system_prompt: - request["instructions"] = system_prompt - - # Add tools if provided - if tool_specs: - request["tools"] = [ - { - "type": "function", - "name": tool_spec["name"], - "description": tool_spec.get("description", ""), - "parameters": tool_spec["inputSchema"]["json"], - } - for tool_spec in tool_specs - ] - - return request - - def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: - """Format an OpenAI response event into a standardized message chunk. - - Args: - event: A response event from the OpenAI compatible model. - - Returns: - The formatted chunk. - - Raises: - RuntimeError: If chunk_type is not recognized. - This error should never be encountered as chunk_type is controlled in the stream method. - """ - match event["chunk_type"]: - case "message_start": - return {"messageStart": {"role": "assistant"}} - - case "content_start": - if event["data_type"] == "tool": - return { - "contentBlockStart": { - "start": { - "toolUse": { - "name": event["data"].function.name, - "toolUseId": event["data"].id, - } - } - } - } - - return {"contentBlockStart": {"start": {}}} - - case "content_delta": - if event["data_type"] == "tool": - return { - "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}} - } - - if event["data_type"] == "reasoning_content": - return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}} - - return {"contentBlockDelta": {"delta": {"text": event["data"]}}} - - case "content_stop": - return {"contentBlockStop": {}} - - case "message_stop": - match event["data"]: - case "tool_calls": - return {"messageStop": {"stopReason": "tool_use"}} - case "length": - return {"messageStop": {"stopReason": "max_tokens"}} - case _: - return {"messageStop": {"stopReason": "end_turn"}} - - case "metadata": - return { - "metadata": { - "usage": { - "inputTokens": event["data"].prompt_tokens, - "outputTokens": event["data"].completion_tokens, - "totalTokens": event["data"].total_tokens, - }, - "metrics": { - "latencyMs": 0, # TODO - }, - }, - } - - case _: - raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type") - @override async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + *, + tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the OpenAI Responses API model. @@ -352,6 +145,7 @@ async def stream( messages: List of message objects to be processed by the model. tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -362,7 +156,7 @@ async def stream( ModelThrottledException: If the request is throttled by OpenAI (rate limits). """ logger.debug("formatting request for OpenAI Responses API") - request = self._format_request(messages, tool_specs, system_prompt) + request = self._format_request(messages, tool_specs, system_prompt, tool_choice) logger.debug("formatted request=<%s>", request) logger.debug("invoking OpenAI Responses API model") @@ -370,15 +164,14 @@ async def stream( async with openai.AsyncOpenAI(**self.client_args) as client: try: response = await client.responses.create(**request) - except openai.APIError as e: + except openai.BadRequestError as e: if hasattr(e, "code") and e.code == "context_length_exceeded": logger.warning("OpenAI Responses API threw context window overflow error") raise ContextWindowOverflowException(str(e)) from e - elif hasattr(e, "code") and e.code == "rate_limit_exceeded": - logger.warning("OpenAI Responses API threw rate limit error") - raise ModelThrottledException(str(e)) from e - else: - raise + raise + except openai.RateLimitError as e: + logger.warning("OpenAI Responses API threw rate limit error") + raise ModelThrottledException(str(e)) from e logger.debug("got response from OpenAI Responses API model") @@ -430,15 +223,14 @@ async def stream( if hasattr(event, "response") and hasattr(event.response, "usage"): final_usage = event.response.usage break - except openai.APIError as e: + except openai.BadRequestError as e: if hasattr(e, "code") and e.code == "context_length_exceeded": logger.warning("OpenAI Responses API threw context window overflow error") raise ContextWindowOverflowException(str(e)) from e - elif hasattr(e, "code") and e.code == "rate_limit_exceeded": - logger.warning("OpenAI Responses API threw rate limit error") - raise ModelThrottledException(str(e)) from e - else: - raise + raise + except openai.RateLimitError as e: + logger.warning("OpenAI Responses API threw rate limit error") + raise ModelThrottledException(str(e)) from e # Close text content if we had any if has_text_content: @@ -480,8 +272,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the OpenAI Responses API model. Args: @@ -512,18 +304,310 @@ async def structured_output( except openai.RateLimitError as e: logger.warning("OpenAI Responses API threw rate limit error") raise ModelThrottledException(str(e)) from e - except openai.APIError as e: - # Handle streaming API errors that come as APIError - error_message = str(e).lower() - if "context window" in error_message or "exceeds the context" in error_message: - logger.warning("OpenAI Responses API threw context window overflow error") - raise ContextWindowOverflowException(str(e)) from e - elif "rate limit" in error_message or "tokens per min" in error_message: - logger.warning("OpenAI Responses API threw rate limit error") - raise ModelThrottledException(str(e)) from e - raise if response.output_parsed: yield {"output": response.output_parsed} else: raise ValueError("No valid parsed output found in the OpenAI Responses API response.") + + def _format_request( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + tool_choice: ToolChoice | None = None, + ) -> dict[str, Any]: + """Format an OpenAI Responses API compatible response streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. + + Returns: + An OpenAI Responses API compatible response streaming request. + + Raises: + TypeError: If a message contains a content block type that cannot be converted to an OpenAI-compatible + format. + """ + input_items = self._format_request_messages(messages) + request = { + "model": self.config["model_id"], + "input": input_items, + "stream": True, + **cast(dict[str, Any], self.config.get("params", {})), + } + + if system_prompt: + request["instructions"] = system_prompt + + # Add tools if provided + if tool_specs: + request["tools"] = [ + { + "type": "function", + "name": tool_spec["name"], + "description": tool_spec.get("description", ""), + "parameters": tool_spec["inputSchema"]["json"], + } + for tool_spec in tool_specs + ] + # Add tool_choice if provided + request.update(self._format_request_tool_choice(tool_choice)) + + return request + + @classmethod + def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str, Any]: + """Format a tool choice for OpenAI Responses API compatibility. + + Args: + tool_choice: Tool choice configuration. + + Returns: + OpenAI Responses API compatible tool choice format. + """ + if not tool_choice: + return {} + + match tool_choice: + case {"auto": _}: + return {"tool_choice": "auto"} + case {"any": _}: + return {"tool_choice": "required"} + case {"tool": {"name": tool_name}}: + return {"tool_choice": {"type": "function", "name": tool_name}} + case _: + # Default to auto for unknown formats + return {"tool_choice": "auto"} + + @classmethod + def _format_request_messages(cls, messages: Messages) -> list[dict[str, Any]]: + """Format an OpenAI compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + + Returns: + An OpenAI compatible messages array. + """ + formatted_messages: list[dict[str, Any]] = [] + + for message in messages: + role = message["role"] + if role == "system": + # Skip system messages - the Responses API uses "instructions" parameter + # for system prompts instead of including them in the input items array. + # This is handled in _format_request() where system_prompt is passed separately. + continue # type: ignore[unreachable] + + contents = message["content"] + + formatted_contents = [ + cls._format_request_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + + formatted_tool_calls = [ + cls._format_request_message_tool_call(content["toolUse"]) + for content in contents + if "toolUse" in content + ] + + formatted_tool_messages = [ + cls._format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + if formatted_contents: + formatted_messages.append( + { + "role": role, # "user" | "assistant" + "content": formatted_contents, + } + ) + + formatted_messages.extend(formatted_tool_calls) + formatted_messages.extend(formatted_tool_messages) + + return [ + message + for message in formatted_messages + if message.get("content") or message.get("type") in ["function_call", "function_call_output"] + ] + + @classmethod + def _format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + """Format an OpenAI compatible content block. + + Args: + content: Message content. + + Returns: + OpenAI compatible content block. + + Raises: + TypeError: If the content block type cannot be converted to an OpenAI-compatible format. + ValueError: If the image or document size exceeds the maximum allowed size (20MB). + """ + if "document" in content: + data_url = _encode_media_to_data_url( + content["document"]["source"]["bytes"], content["document"]["format"], "document" + ) + return {"type": "input_file", "file_url": data_url} + + if "image" in content: + data_url = _encode_media_to_data_url( + content["image"]["source"]["bytes"], content["image"]["format"], "image" + ) + return {"type": "input_image", "image_url": data_url} + + if "text" in content: + return {"type": "input_text", "text": content["text"]} + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + @classmethod + def _format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: + """Format an OpenAI compatible tool call. + + Args: + tool_use: Tool use requested by the model. + + Returns: + OpenAI compatible tool call. + """ + return { + "type": "function_call", + "call_id": tool_use["toolUseId"], + "name": tool_use["name"], + "arguments": json.dumps(tool_use["input"]), + } + + @classmethod + def _format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + """Format an OpenAI compatible tool message. + + Args: + tool_result: Tool result collected from a tool execution. + + Returns: + OpenAI compatible tool message. + + Raises: + ValueError: If the image or document size exceeds the maximum allowed size (20MB). + + Note: + The Responses API's function_call_output can be either a string (typically JSON encoded) + or an array of content objects when returning images/files. + See: https://platform.openai.com/docs/guides/function-calling + """ + output_parts: list[dict[str, Any]] = [] + has_media = False + + for content in tool_result["content"]: + if "json" in content: + output_parts.append({"type": "input_text", "text": json.dumps(content["json"])}) + elif "text" in content: + output_parts.append({"type": "input_text", "text": content["text"]}) + elif "image" in content: + has_media = True + data_url = _encode_media_to_data_url( + content["image"]["source"]["bytes"], content["image"]["format"], "image" + ) + output_parts.append({"type": "input_image", "image_url": data_url}) + elif "document" in content: + has_media = True + data_url = _encode_media_to_data_url( + content["document"]["source"]["bytes"], content["document"]["format"], "document" + ) + output_parts.append({"type": "input_file", "file_url": data_url}) + + # Return array if has media content, otherwise join as string for simpler text-only cases + output: list[dict[str, Any]] | str + if has_media: + output = output_parts + else: + output = "\n".join(part.get("text", "") for part in output_parts) if output_parts else "" + + return { + "type": "function_call_output", + "call_id": tool_result["toolUseId"], + "output": output, + } + + def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format an OpenAI response event into a standardized message chunk. + + Args: + event: A response event from the OpenAI compatible model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + This error should never be encountered as chunk_type is controlled in the stream method. + """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "tool": + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": event["data"].function.name, + "toolUseId": event["data"].id, + } + } + } + } + + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event["data_type"] == "tool": + return { + "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}} + } + + if event["data_type"] == "reasoning_content": + return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}} + + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event["data"]: + case "tool_calls": + return {"messageStop": {"stopReason": "tool_use"}} + case "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + return { + "metadata": { + "usage": { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + }, + "metrics": { + "latencyMs": 0, # TODO + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type") diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index eb78217d8..0ac7d27ad 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -166,6 +166,51 @@ def test_format_request_tool_message(): assert tru_result == exp_result +def test_format_request_tool_message_with_image(): + """Test that tool results with images return an array output.""" + tool_result = { + "content": [ + {"text": "Here is the image:"}, + {"image": {"format": "png", "source": {"bytes": b"fake_image_data"}}}, + ], + "status": "success", + "toolUseId": "c2", + } + + tru_result = OpenAIResponsesModel._format_request_tool_message(tool_result) + + assert tru_result["type"] == "function_call_output" + assert tru_result["call_id"] == "c2" + # When images are present, output should be an array + assert isinstance(tru_result["output"], list) + assert len(tru_result["output"]) == 2 + assert tru_result["output"][0]["type"] == "input_text" + assert tru_result["output"][0]["text"] == "Here is the image:" + assert tru_result["output"][1]["type"] == "input_image" + assert "image_url" in tru_result["output"][1] + + +def test_format_request_tool_message_with_document(): + """Test that tool results with documents return an array output.""" + tool_result = { + "content": [ + {"document": {"format": "pdf", "name": "test.pdf", "source": {"bytes": b"fake_pdf_data"}}}, + ], + "status": "success", + "toolUseId": "c3", + } + + tru_result = OpenAIResponsesModel._format_request_tool_message(tool_result) + + assert tru_result["type"] == "function_call_output" + assert tru_result["call_id"] == "c3" + # When documents are present, output should be an array + assert isinstance(tru_result["output"], list) + assert len(tru_result["output"]) == 1 + assert tru_result["output"][0]["type"] == "input_file" + assert "file_url" in tru_result["output"][0] + + def test_format_request_messages(system_prompt): messages = [ { @@ -536,3 +581,80 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings assert len(captured_warnings) == 1 assert "Invalid configuration parameters" in str(captured_warnings[0].message) assert "wrong_param" in str(captured_warnings[0].message) + + +@pytest.mark.parametrize( + ("tool_choice", "expected"), + [ + (None, {}), + ({"auto": {}}, {"tool_choice": "auto"}), + ({"any": {}}, {"tool_choice": "required"}), + ({"tool": {"name": "calculator"}}, {"tool_choice": {"type": "function", "name": "calculator"}}), + ({"unknown": {}}, {"tool_choice": "auto"}), # Test default fallback + ], +) +def test_format_request_tool_choice(tool_choice, expected): + """Test that tool_choice is properly formatted for the Responses API.""" + result = OpenAIResponsesModel._format_request_tool_choice(tool_choice) + assert result == expected + + +def test_format_request_with_tool_choice(model, messages, tool_specs): + """Test that tool_choice is properly included in the request.""" + tool_choice = {"tool": {"name": "test_tool"}} + request = model._format_request(messages, tool_specs, tool_choice=tool_choice) + + assert "tool_choice" in request + assert request["tool_choice"] == {"type": "function", "name": "test_tool"} + + +def test_format_request_message_content_image_size_limit(): + """Test that oversized images raise ValueError.""" + from strands.models.openai_responses import MAX_MEDIA_SIZE_BYTES + + oversized_data = b"x" * (MAX_MEDIA_SIZE_BYTES + 1) + content = {"image": {"format": "png", "source": {"bytes": oversized_data}}} + + with pytest.raises(ValueError, match="Image size .* exceeds maximum"): + OpenAIResponsesModel._format_request_message_content(content) + + +def test_format_request_message_content_document_size_limit(): + """Test that oversized documents raise ValueError.""" + from strands.models.openai_responses import MAX_MEDIA_SIZE_BYTES + + oversized_data = b"x" * (MAX_MEDIA_SIZE_BYTES + 1) + content = {"document": {"format": "pdf", "name": "large.pdf", "source": {"bytes": oversized_data}}} + + with pytest.raises(ValueError, match="Document size .* exceeds maximum"): + OpenAIResponsesModel._format_request_message_content(content) + + +def test_format_request_tool_message_image_size_limit(): + """Test that oversized images in tool results raise ValueError.""" + from strands.models.openai_responses import MAX_MEDIA_SIZE_BYTES + + oversized_data = b"x" * (MAX_MEDIA_SIZE_BYTES + 1) + tool_result = { + "content": [{"image": {"format": "png", "source": {"bytes": oversized_data}}}], + "status": "success", + "toolUseId": "c1", + } + + with pytest.raises(ValueError, match="Image size .* exceeds maximum"): + OpenAIResponsesModel._format_request_tool_message(tool_result) + + +def test_format_request_tool_message_document_size_limit(): + """Test that oversized documents in tool results raise ValueError.""" + from strands.models.openai_responses import MAX_MEDIA_SIZE_BYTES + + oversized_data = b"x" * (MAX_MEDIA_SIZE_BYTES + 1) + tool_result = { + "content": [{"document": {"format": "pdf", "name": "large.pdf", "source": {"bytes": oversized_data}}}], + "status": "success", + "toolUseId": "c1", + } + + with pytest.raises(ValueError, match="Document size .* exceeds maximum"): + OpenAIResponsesModel._format_request_tool_message(tool_result) From e09f874417112f8195c2153fa8d6de21b65147b5 Mon Sep 17 00:00:00 2001 From: notgitika Date: Wed, 21 Jan 2026 06:40:36 +0000 Subject: [PATCH 3/3] feat: add conditional to check v1 vs v2 ; add tests to increase coverage --- src/strands/models/openai_responses.py | 17 +++ tests/strands/models/test_openai_responses.py | 102 +++++++++++++++--- 2 files changed, 106 insertions(+), 13 deletions(-) diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index e2e4945bd..ad05ee3da 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -20,9 +20,11 @@ import logging import mimetypes from collections.abc import AsyncGenerator +from importlib.metadata import version as get_package_version from typing import Any, Protocol, TypedDict, TypeVar, cast import openai +from packaging.version import Version from pydantic import BaseModel from typing_extensions import Unpack, override @@ -37,6 +39,9 @@ T = TypeVar("T", bound=BaseModel) +# Minimum OpenAI SDK version required for Responses API +_MIN_OPENAI_VERSION = Version("2.0.0") + # Maximum file size for media content in tool results (20MB) MAX_MEDIA_SIZE_BYTES = 20 * 1024 * 1024 @@ -103,7 +108,19 @@ def __init__( client_args: Arguments for the OpenAI client. For a complete list of supported arguments, see https://pypi.org/project/openai/. **model_config: Configuration options for the OpenAI Responses API model. + + Raises: + ImportError: If the installed OpenAI SDK version is less than 2.0.0. """ + # Validate OpenAI SDK version - Responses API requires v2.0.0+ + openai_version = Version(get_package_version("openai")) + if openai_version < _MIN_OPENAI_VERSION: + raise ImportError( + f"OpenAIResponsesModel requires openai>={_MIN_OPENAI_VERSION} (found {openai_version}). " + "Install/upgrade with: pip install -U openai. " + "For older SDKs, use OpenAIModel (Chat Completions)." + ) + validate_config_keys(model_config, self.OpenAIResponsesConfig) self.config = dict(model_config) self.client_args = client_args or {} diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index 0ac7d27ad..a3d857ca2 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -5,7 +5,7 @@ import pytest import strands -from strands.models.openai_responses import OpenAIResponsesModel +from strands.models.openai_responses import MAX_MEDIA_SIZE_BYTES, OpenAIResponsesModel from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException @@ -23,8 +23,17 @@ def model_id(): @pytest.fixture -def model(openai_client, model_id): +def mock_openai_version(): + """Mock the OpenAI version check to allow testing with v1.x SDK.""" + with unittest.mock.patch("strands.models.openai_responses.get_package_version") as mock_version: + mock_version.return_value = "2.0.0" + yield mock_version + + +@pytest.fixture +def model(openai_client, model_id, mock_openai_version): _ = openai_client + _ = mock_openai_version return OpenAIResponsesModel(model_id=model_id, params={"max_output_tokens": 100}) @@ -66,7 +75,8 @@ class TestOutputModel(pydantic.BaseModel): return TestOutputModel -def test__init__(model_id): +def test__init__(model_id, mock_openai_version): + _ = mock_openai_version model = OpenAIResponsesModel(model_id=model_id, params={"max_output_tokens": 100}) tru_config = model.get_config() @@ -522,7 +532,71 @@ async def test_stream_rate_limit_as_throttle(openai_client, model, messages): pass assert "Rate limit exceeded" in str(exc_info.value) - assert exc_info.value.__cause__ == mock_error + + +@pytest.mark.asyncio +async def test_stream_bad_request_non_context_overflow(openai_client, model, messages): + """Test that non-context-overflow BadRequestErrors are re-raised.""" + mock_error = openai.BadRequestError( + message="Invalid request format", + response=unittest.mock.MagicMock(), + body={"error": {"code": "invalid_request"}}, + ) + mock_error.code = "invalid_request" + + openai_client.responses.create.side_effect = mock_error + + with pytest.raises(openai.BadRequestError) as exc_info: + async for _ in model.stream(messages): + pass + + assert exc_info.value == mock_error + + +@pytest.mark.asyncio +async def test_stream_error_during_iteration(openai_client, model, messages, agenerator): + """Test that errors during streaming iteration are properly handled.""" + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hello") + + async def error_generator(): + yield mock_text_event + raise openai.RateLimitError( + message="Rate limit during stream", + response=unittest.mock.MagicMock(), + body={"error": {"code": "rate_limit_exceeded"}}, + ) + + openai_client.responses.create = unittest.mock.AsyncMock(return_value=error_generator()) + + with pytest.raises(ModelThrottledException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "Rate limit during stream" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_stream_context_overflow_during_iteration(openai_client, model, messages): + """Test that context overflow during streaming iteration is properly handled.""" + mock_text_event = unittest.mock.Mock(type="response.output_text.delta", delta="Hello") + + async def error_generator(): + yield mock_text_event + error = openai.BadRequestError( + message="Context length exceeded during stream", + response=unittest.mock.MagicMock(), + body={"error": {"code": "context_length_exceeded"}}, + ) + error.code = "context_length_exceeded" + raise error + + openai_client.responses.create = unittest.mock.AsyncMock(return_value=error_generator()) + + with pytest.raises(ContextWindowOverflowException) as exc_info: + async for _ in model.stream(messages): + pass + + assert "Context length exceeded" in str(exc_info.value) @pytest.mark.asyncio @@ -565,8 +639,9 @@ async def test_structured_output_rate_limit_as_throttle(openai_client, model, me assert exc_info.value.__cause__ == mock_error -def test_config_validation_warns_on_unknown_keys(openai_client, captured_warnings): +def test_config_validation_warns_on_unknown_keys(openai_client, captured_warnings, mock_openai_version): """Test that unknown config keys emit a warning.""" + _ = mock_openai_version OpenAIResponsesModel({"api_key": "test"}, model_id="test-model", invalid_param="test") assert len(captured_warnings) == 1 @@ -610,8 +685,6 @@ def test_format_request_with_tool_choice(model, messages, tool_specs): def test_format_request_message_content_image_size_limit(): """Test that oversized images raise ValueError.""" - from strands.models.openai_responses import MAX_MEDIA_SIZE_BYTES - oversized_data = b"x" * (MAX_MEDIA_SIZE_BYTES + 1) content = {"image": {"format": "png", "source": {"bytes": oversized_data}}} @@ -621,8 +694,6 @@ def test_format_request_message_content_image_size_limit(): def test_format_request_message_content_document_size_limit(): """Test that oversized documents raise ValueError.""" - from strands.models.openai_responses import MAX_MEDIA_SIZE_BYTES - oversized_data = b"x" * (MAX_MEDIA_SIZE_BYTES + 1) content = {"document": {"format": "pdf", "name": "large.pdf", "source": {"bytes": oversized_data}}} @@ -632,8 +703,6 @@ def test_format_request_message_content_document_size_limit(): def test_format_request_tool_message_image_size_limit(): """Test that oversized images in tool results raise ValueError.""" - from strands.models.openai_responses import MAX_MEDIA_SIZE_BYTES - oversized_data = b"x" * (MAX_MEDIA_SIZE_BYTES + 1) tool_result = { "content": [{"image": {"format": "png", "source": {"bytes": oversized_data}}}], @@ -647,8 +716,6 @@ def test_format_request_tool_message_image_size_limit(): def test_format_request_tool_message_document_size_limit(): """Test that oversized documents in tool results raise ValueError.""" - from strands.models.openai_responses import MAX_MEDIA_SIZE_BYTES - oversized_data = b"x" * (MAX_MEDIA_SIZE_BYTES + 1) tool_result = { "content": [{"document": {"format": "pdf", "name": "large.pdf", "source": {"bytes": oversized_data}}}], @@ -658,3 +725,12 @@ def test_format_request_tool_message_document_size_limit(): with pytest.raises(ValueError, match="Document size .* exceeds maximum"): OpenAIResponsesModel._format_request_tool_message(tool_result) + + +def test_openai_version_check(): + """Test that initialization fails with old OpenAI SDK version.""" + with unittest.mock.patch("strands.models.openai_responses.get_package_version") as mock_version: + mock_version.return_value = "1.99.0" + + with pytest.raises(ImportError, match="OpenAIResponsesModel requires openai>=2.0.0"): + OpenAIResponsesModel(model_id="gpt-4o")