diff --git a/src/openai/__init__.py b/src/openai/__init__.py index e7411b3886..d73b31b9d7 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -7,7 +7,7 @@ from typing_extensions import override from . import types -from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given +from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given, TimeoutTypes from ._utils import file_from_path from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions from ._models import BaseModel @@ -129,7 +129,7 @@ base_url: str | _httpx.URL | None = None -timeout: float | Timeout | None = DEFAULT_TIMEOUT +timeout: TimeoutTypes = DEFAULT_TIMEOUT max_retries: int = DEFAULT_MAX_RETRIES @@ -214,11 +214,11 @@ def base_url(self, url: _httpx.URL | str) -> None: @property # type: ignore @override - def timeout(self) -> float | Timeout | None: + def timeout(self) -> TimeoutTypes: return timeout @timeout.setter # type: ignore - def timeout(self, value: float | Timeout | None) -> None: # type: ignore + def timeout(self, value: TimeoutTypes) -> None: # type: ignore global timeout timeout = value diff --git a/src/openai/_base_client.py b/src/openai/_base_client.py index 9e536410d6..e16820b674 100644 --- a/src/openai/_base_client.py +++ b/src/openai/_base_client.py @@ -51,6 +51,7 @@ ResponseT, AnyMapping, PostParser, + TimeoutTypes, RequestFiles, HttpxSendArgs, RequestOptions, @@ -363,7 +364,7 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]): _version: str _base_url: URL max_retries: int - timeout: Union[float, Timeout, None] + timeout: TimeoutTypes _strict_response_validation: bool _idempotency_header: str | None _default_stream_cls: type[_DefaultStreamT] | None = None @@ -375,7 +376,7 @@ def __init__( base_url: str | URL, _strict_response_validation: bool, max_retries: int = DEFAULT_MAX_RETRIES, - timeout: float | Timeout | None = DEFAULT_TIMEOUT, + timeout: TimeoutTypes = DEFAULT_TIMEOUT, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, ) -> None: @@ -546,7 +547,7 @@ def _build_request( # TODO: report this error to httpx return self._client.build_request( # pyright: ignore[reportUnknownMemberType] headers=headers, - timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout, + timeout=cast(Any, self.timeout if isinstance(options.timeout, NotGiven) else options.timeout), method=options.method, url=prepared_url, # the `Query` type that we use is incompatible with qs' @@ -827,7 +828,7 @@ def __init__( version: str, base_url: str | URL, max_retries: int = DEFAULT_MAX_RETRIES, - timeout: float | Timeout | None | NotGiven = not_given, + timeout: TimeoutTypes | NotGiven = not_given, http_client: httpx.Client | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, @@ -1376,7 +1377,7 @@ def __init__( base_url: str | URL, _strict_response_validation: bool, max_retries: int = DEFAULT_MAX_RETRIES, - timeout: float | Timeout | None | NotGiven = not_given, + timeout: TimeoutTypes | NotGiven = not_given, http_client: httpx.AsyncClient | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, @@ -1856,7 +1857,7 @@ def make_request_options( extra_query: Query | None = None, extra_body: Body | None = None, idempotency_key: str | None = None, - timeout: float | httpx.Timeout | None | NotGiven = not_given, + timeout: TimeoutTypes | NotGiven = not_given, post_parser: PostParser | NotGiven = not_given, ) -> RequestOptions: """Create a dict of type RequestOptions without keys of NotGiven values.""" diff --git a/src/openai/_client.py b/src/openai/_client.py index a3b01b2ce6..226556033d 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -15,6 +15,7 @@ Timeout, NotGiven, Transport, + TimeoutTypes, ProxiesTypes, RequestOptions, not_given, @@ -105,7 +106,7 @@ def __init__( webhook_secret: str | None = None, base_url: str | httpx.URL | None = None, websocket_base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = not_given, + timeout: TimeoutTypes | NotGiven = not_given, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -347,7 +348,7 @@ def copy( webhook_secret: str | None = None, websocket_base_url: str | httpx.URL | None = None, base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = not_given, + timeout: TimeoutTypes | NotGiven = not_given, http_client: httpx.Client | None = None, max_retries: int | NotGiven = not_given, default_headers: Mapping[str, str] | None = None, @@ -456,7 +457,7 @@ def __init__( webhook_secret: str | None = None, base_url: str | httpx.URL | None = None, websocket_base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = not_given, + timeout: TimeoutTypes | NotGiven = not_given, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -698,7 +699,7 @@ def copy( webhook_secret: str | None = None, websocket_base_url: str | httpx.URL | None = None, base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = not_given, + timeout: TimeoutTypes | NotGiven = not_given, http_client: httpx.AsyncClient | None = None, max_retries: int | NotGiven = not_given, default_headers: Mapping[str, str] | None = None, diff --git a/src/openai/_models.py b/src/openai/_models.py index fac59c2cb8..d87f609cb6 100644 --- a/src/openai/_models.py +++ b/src/openai/_models.py @@ -33,6 +33,7 @@ Timeout, NotGiven, AnyMapping, + TimeoutTypes, HttpxRequestFiles, ) from ._utils import ( @@ -824,7 +825,7 @@ class FinalRequestOptionsInput(TypedDict, total=False): params: Query headers: Headers max_retries: int - timeout: float | Timeout | None + timeout: TimeoutTypes files: HttpxRequestFiles | None idempotency_key: str json_data: Body @@ -839,7 +840,7 @@ class FinalRequestOptions(pydantic.BaseModel): params: Query = {} headers: Union[Headers, NotGiven] = NotGiven() max_retries: Union[int, NotGiven] = NotGiven() - timeout: Union[float, Timeout, None, NotGiven] = NotGiven() + timeout: Union[TimeoutTypes, NotGiven] = NotGiven() files: Union[HttpxRequestFiles, None] = None idempotency_key: Union[str, None] = None post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven() diff --git a/src/openai/_types.py b/src/openai/_types.py index d7e2eaac5f..752500205e 100644 --- a/src/openai/_types.py +++ b/src/openai/_types.py @@ -105,10 +105,20 @@ NoneType = type(None) +TimeoutTypes = Union[ + Optional[float], + Timeout, + None, + Tuple[Optional[float], Optional[float]], + Tuple[Optional[float], Optional[float], Optional[float]], + Tuple[Optional[float], Optional[float], Optional[float], Optional[float]], +] + + class RequestOptions(TypedDict, total=False): headers: Headers max_retries: int - timeout: float | Timeout | None + timeout: TimeoutTypes params: Query extra_json: AnyMapping idempotency_key: str diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index ad64707261..e4e41b6219 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -7,7 +7,7 @@ import httpx -from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven +from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven, TimeoutTypes from .._utils import is_given, is_mapping from .._client import OpenAI, AsyncOpenAI from .._compat import model_copy @@ -100,7 +100,7 @@ def __init__( organization: str | None = None, webhook_secret: str | None = None, websocket_base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: TimeoutTypes | NotGiven = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -120,7 +120,7 @@ def __init__( organization: str | None = None, webhook_secret: str | None = None, websocket_base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: TimeoutTypes | NotGiven = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -140,7 +140,7 @@ def __init__( organization: str | None = None, webhook_secret: str | None = None, websocket_base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: TimeoutTypes | NotGiven = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -162,7 +162,7 @@ def __init__( webhook_secret: str | None = None, websocket_base_url: str | httpx.URL | None = None, base_url: str | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: TimeoutTypes | NotGiven = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -267,7 +267,7 @@ def copy( azure_ad_token: str | None = None, azure_ad_token_provider: AzureADTokenProvider | None = None, base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: TimeoutTypes | NotGiven = NOT_GIVEN, http_client: httpx.Client | None = None, max_retries: int | NotGiven = NOT_GIVEN, default_headers: Mapping[str, str] | None = None, @@ -379,7 +379,7 @@ def __init__( project: str | None = None, webhook_secret: str | None = None, websocket_base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: TimeoutTypes | NotGiven = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -400,7 +400,7 @@ def __init__( project: str | None = None, webhook_secret: str | None = None, websocket_base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: TimeoutTypes | NotGiven = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -421,7 +421,7 @@ def __init__( project: str | None = None, webhook_secret: str | None = None, websocket_base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: TimeoutTypes | NotGiven = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -443,7 +443,7 @@ def __init__( webhook_secret: str | None = None, base_url: str | None = None, websocket_base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: TimeoutTypes | NotGiven = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, @@ -548,7 +548,7 @@ def copy( azure_ad_token: str | None = None, azure_ad_token_provider: AsyncAzureADTokenProvider | None = None, base_url: str | httpx.URL | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + timeout: TimeoutTypes | NotGiven = NOT_GIVEN, http_client: httpx.AsyncClient | None = None, max_retries: int | NotGiven = NOT_GIVEN, default_headers: Mapping[str, str] | None = None, diff --git a/tests/test_client.py b/tests/test_client.py index e8d62f17f7..a7f4a6dbdb 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -260,6 +260,51 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic print(frame) raise AssertionError() + def test_request_timeout_tuple(self, client: OpenAI) -> None: + # 2-tuple + request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=(5.0, 10.0))) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout.connect == 5.0 + assert timeout.read == 10.0 + assert timeout.write is None + assert timeout.pool is None + + # 4-tuple + request = client._build_request( + FinalRequestOptions(method="get", url="/foo", timeout=(5.0, 10.0, 15.0, 20.0)) + ) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout.connect == 5.0 + assert timeout.read == 10.0 + assert timeout.write == 15.0 + assert timeout.pool == 20.0 + + def test_client_timeout_tuple(self) -> None: + # 2-tuple + client = OpenAI( + base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=(5.0, 10.0) + ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout.connect == 5.0 + assert timeout.read == 10.0 + client.close() + + # 4-tuple + client = OpenAI( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + timeout=(5.0, 10.0, 15.0, 20.0), + ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout.connect == 5.0 + assert timeout.read == 10.0 + assert timeout.write == 15.0 + assert timeout.pool == 20.0 + client.close() + def test_request_timeout(self, client: OpenAI) -> None: request = client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore @@ -1217,6 +1262,51 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic print(frame) raise AssertionError() + async def test_request_timeout_tuple(self, async_client: AsyncOpenAI) -> None: + # 2-tuple + request = async_client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=(5.0, 10.0))) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout.connect == 5.0 + assert timeout.read == 10.0 + assert timeout.write is None + assert timeout.pool is None + + # 4-tuple + request = async_client._build_request( + FinalRequestOptions(method="get", url="/foo", timeout=(5.0, 10.0, 15.0, 20.0)) + ) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout.connect == 5.0 + assert timeout.read == 10.0 + assert timeout.write == 15.0 + assert timeout.pool == 20.0 + + async def test_client_timeout_tuple(self) -> None: + # 2-tuple + client = AsyncOpenAI( + base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=(5.0, 10.0) + ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout.connect == 5.0 + assert timeout.read == 10.0 + await client.close() + + # 4-tuple + client = AsyncOpenAI( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + timeout=(5.0, 10.0, 15.0, 20.0), + ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout.connect == 5.0 + assert timeout.read == 10.0 + assert timeout.write == 15.0 + assert timeout.pool == 20.0 + await client.close() + async def test_request_timeout(self, async_client: AsyncOpenAI) -> None: request = async_client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore