Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing_extensions import override

from . import types
from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given
from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given, TimeoutTypes
from ._utils import file_from_path
from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions
from ._models import BaseModel
Expand Down Expand Up @@ -129,7 +129,7 @@

base_url: str | _httpx.URL | None = None

timeout: float | Timeout | None = DEFAULT_TIMEOUT
timeout: TimeoutTypes = DEFAULT_TIMEOUT

max_retries: int = DEFAULT_MAX_RETRIES

Expand Down Expand Up @@ -214,11 +214,11 @@ def base_url(self, url: _httpx.URL | str) -> None:

@property # type: ignore
@override
def timeout(self) -> float | Timeout | None:
def timeout(self) -> TimeoutTypes:
return timeout

@timeout.setter # type: ignore
def timeout(self, value: float | Timeout | None) -> None: # type: ignore
def timeout(self, value: TimeoutTypes) -> None: # type: ignore
global timeout

timeout = value
Expand Down
13 changes: 7 additions & 6 deletions src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
ResponseT,
AnyMapping,
PostParser,
TimeoutTypes,
RequestFiles,
HttpxSendArgs,
RequestOptions,
Expand Down Expand Up @@ -363,7 +364,7 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
_version: str
_base_url: URL
max_retries: int
timeout: Union[float, Timeout, None]
timeout: TimeoutTypes
_strict_response_validation: bool
_idempotency_header: str | None
_default_stream_cls: type[_DefaultStreamT] | None = None
Expand All @@ -375,7 +376,7 @@ def __init__(
base_url: str | URL,
_strict_response_validation: bool,
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: float | Timeout | None = DEFAULT_TIMEOUT,
timeout: TimeoutTypes = DEFAULT_TIMEOUT,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
) -> None:
Expand Down Expand Up @@ -546,7 +547,7 @@ def _build_request(
# TODO: report this error to httpx
return self._client.build_request( # pyright: ignore[reportUnknownMemberType]
headers=headers,
timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout,
timeout=cast(Any, self.timeout if isinstance(options.timeout, NotGiven) else options.timeout),
method=options.method,
url=prepared_url,
# the `Query` type that we use is incompatible with qs'
Expand Down Expand Up @@ -827,7 +828,7 @@ def __init__(
version: str,
base_url: str | URL,
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: float | Timeout | None | NotGiven = not_given,
timeout: TimeoutTypes | NotGiven = not_given,
http_client: httpx.Client | None = None,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
Expand Down Expand Up @@ -1376,7 +1377,7 @@ def __init__(
base_url: str | URL,
_strict_response_validation: bool,
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: float | Timeout | None | NotGiven = not_given,
timeout: TimeoutTypes | NotGiven = not_given,
http_client: httpx.AsyncClient | None = None,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
Expand Down Expand Up @@ -1856,7 +1857,7 @@ def make_request_options(
extra_query: Query | None = None,
extra_body: Body | None = None,
idempotency_key: str | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
timeout: TimeoutTypes | NotGiven = not_given,
post_parser: PostParser | NotGiven = not_given,
) -> RequestOptions:
"""Create a dict of type RequestOptions without keys of NotGiven values."""
Expand Down
9 changes: 5 additions & 4 deletions src/openai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Timeout,
NotGiven,
Transport,
TimeoutTypes,
ProxiesTypes,
RequestOptions,
not_given,
Expand Down Expand Up @@ -105,7 +106,7 @@ def __init__(
webhook_secret: str | None = None,
base_url: str | httpx.URL | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = not_given,
timeout: TimeoutTypes | NotGiven = not_given,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
Expand Down Expand Up @@ -347,7 +348,7 @@ def copy(
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = not_given,
timeout: TimeoutTypes | NotGiven = not_given,
http_client: httpx.Client | None = None,
max_retries: int | NotGiven = not_given,
default_headers: Mapping[str, str] | None = None,
Expand Down Expand Up @@ -456,7 +457,7 @@ def __init__(
webhook_secret: str | None = None,
base_url: str | httpx.URL | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = not_given,
timeout: TimeoutTypes | NotGiven = not_given,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
Expand Down Expand Up @@ -698,7 +699,7 @@ def copy(
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = not_given,
timeout: TimeoutTypes | NotGiven = not_given,
http_client: httpx.AsyncClient | None = None,
max_retries: int | NotGiven = not_given,
default_headers: Mapping[str, str] | None = None,
Expand Down
5 changes: 3 additions & 2 deletions src/openai/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
Timeout,
NotGiven,
AnyMapping,
TimeoutTypes,
HttpxRequestFiles,
)
from ._utils import (
Expand Down Expand Up @@ -824,7 +825,7 @@ class FinalRequestOptionsInput(TypedDict, total=False):
params: Query
headers: Headers
max_retries: int
timeout: float | Timeout | None
timeout: TimeoutTypes
files: HttpxRequestFiles | None
idempotency_key: str
json_data: Body
Expand All @@ -839,7 +840,7 @@ class FinalRequestOptions(pydantic.BaseModel):
params: Query = {}
headers: Union[Headers, NotGiven] = NotGiven()
max_retries: Union[int, NotGiven] = NotGiven()
timeout: Union[float, Timeout, None, NotGiven] = NotGiven()
timeout: Union[TimeoutTypes, NotGiven] = NotGiven()
files: Union[HttpxRequestFiles, None] = None
idempotency_key: Union[str, None] = None
post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
Expand Down
12 changes: 11 additions & 1 deletion src/openai/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,20 @@
NoneType = type(None)


TimeoutTypes = Union[
Optional[float],
Timeout,
None,
Tuple[Optional[float], Optional[float]],
Tuple[Optional[float], Optional[float], Optional[float]],
Tuple[Optional[float], Optional[float], Optional[float], Optional[float]],
]


class RequestOptions(TypedDict, total=False):
headers: Headers
max_retries: int
timeout: float | Timeout | None
timeout: TimeoutTypes
params: Query
extra_json: AnyMapping
idempotency_key: str
Expand Down
22 changes: 11 additions & 11 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import httpx

from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven
from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven, TimeoutTypes
from .._utils import is_given, is_mapping
from .._client import OpenAI, AsyncOpenAI
from .._compat import model_copy
Expand Down Expand Up @@ -100,7 +100,7 @@ def __init__(
organization: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
timeout: TimeoutTypes | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
Expand All @@ -120,7 +120,7 @@ def __init__(
organization: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
timeout: TimeoutTypes | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
Expand All @@ -140,7 +140,7 @@ def __init__(
organization: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
timeout: TimeoutTypes | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
Expand All @@ -162,7 +162,7 @@ def __init__(
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
base_url: str | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
timeout: TimeoutTypes | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
Expand Down Expand Up @@ -267,7 +267,7 @@ def copy(
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
timeout: TimeoutTypes | NotGiven = NOT_GIVEN,
http_client: httpx.Client | None = None,
max_retries: int | NotGiven = NOT_GIVEN,
default_headers: Mapping[str, str] | None = None,
Expand Down Expand Up @@ -379,7 +379,7 @@ def __init__(
project: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
timeout: TimeoutTypes | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
Expand All @@ -400,7 +400,7 @@ def __init__(
project: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
timeout: TimeoutTypes | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
Expand All @@ -421,7 +421,7 @@ def __init__(
project: str | None = None,
webhook_secret: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
timeout: TimeoutTypes | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
Expand All @@ -443,7 +443,7 @@ def __init__(
webhook_secret: str | None = None,
base_url: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
timeout: TimeoutTypes | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
Expand Down Expand Up @@ -548,7 +548,7 @@ def copy(
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
timeout: TimeoutTypes | NotGiven = NOT_GIVEN,
http_client: httpx.AsyncClient | None = None,
max_retries: int | NotGiven = NOT_GIVEN,
default_headers: Mapping[str, str] | None = None,
Expand Down
90 changes: 90 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,51 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic
print(frame)
raise AssertionError()

def test_request_timeout_tuple(self, client: OpenAI) -> None:
# 2-tuple
request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=(5.0, 10.0)))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout.connect == 5.0
assert timeout.read == 10.0
assert timeout.write is None
assert timeout.pool is None

# 4-tuple
request = client._build_request(
FinalRequestOptions(method="get", url="/foo", timeout=(5.0, 10.0, 15.0, 20.0))
)
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout.connect == 5.0
assert timeout.read == 10.0
assert timeout.write == 15.0
assert timeout.pool == 20.0

def test_client_timeout_tuple(self) -> None:
# 2-tuple
client = OpenAI(
base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=(5.0, 10.0)
)
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout.connect == 5.0
assert timeout.read == 10.0
client.close()

# 4-tuple
client = OpenAI(
base_url=base_url,
api_key=api_key,
_strict_response_validation=True,
timeout=(5.0, 10.0, 15.0, 20.0),
)
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout.connect == 5.0
assert timeout.read == 10.0
assert timeout.write == 15.0
assert timeout.pool == 20.0
client.close()

def test_request_timeout(self, client: OpenAI) -> None:
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
Expand Down Expand Up @@ -1217,6 +1262,51 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic
print(frame)
raise AssertionError()

async def test_request_timeout_tuple(self, async_client: AsyncOpenAI) -> None:
# 2-tuple
request = async_client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=(5.0, 10.0)))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout.connect == 5.0
assert timeout.read == 10.0
assert timeout.write is None
assert timeout.pool is None

# 4-tuple
request = async_client._build_request(
FinalRequestOptions(method="get", url="/foo", timeout=(5.0, 10.0, 15.0, 20.0))
)
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout.connect == 5.0
assert timeout.read == 10.0
assert timeout.write == 15.0
assert timeout.pool == 20.0

async def test_client_timeout_tuple(self) -> None:
# 2-tuple
client = AsyncOpenAI(
base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=(5.0, 10.0)
)
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout.connect == 5.0
assert timeout.read == 10.0
await client.close()

# 4-tuple
client = AsyncOpenAI(
base_url=base_url,
api_key=api_key,
_strict_response_validation=True,
timeout=(5.0, 10.0, 15.0, 20.0),
)
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout.connect == 5.0
assert timeout.read == 10.0
assert timeout.write == 15.0
assert timeout.pool == 20.0
await client.close()

async def test_request_timeout(self, async_client: AsyncOpenAI) -> None:
request = async_client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
Expand Down