35 changes: 35 additions & 0 deletions examples/gemini_demo.py
@@ -0,0 +1,35 @@
import asyncio
import os
from ai_sdk import gemini, generate_text, stream_text

from dotenv import load_dotenv

load_dotenv()

if not os.getenv("GEMINI_API_KEY"):
    print("Please set GEMINI_API_KEY environment variable.")
    exit(1)


async def main():
    model = gemini("gemini-2.5-flash")

    print("--- Generating Text ---")
    response = generate_text(
        model=model,
        prompt="Tell me a joke about programming.",
    )
    print(response.text)

    print("\n--- Streaming Text ---")
    result = stream_text(
        model=model,
        prompt="Write a haiku about Python. Make it LONG.",
    )
    async for delta in result.text_stream:
        print(delta, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())
2 changes: 2 additions & 0 deletions pytest.ini
@@ -0,0 +1,2 @@
[pytest]
asyncio_mode = auto
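
With asyncio_mode = auto, pytest-asyncio collects plain async def test functions directly, with no @pytest.mark.asyncio decorator required. A minimal sketch of a test this config enables (file and test names here are illustrative, not part of this PR):

# test_async_demo.py - illustrative only; not part of this PR.
# asyncio_mode = auto lets pytest-asyncio run this coroutine as a test
# without an explicit @pytest.mark.asyncio decorator.
import asyncio


async def test_sleep_completes():
    await asyncio.sleep(0)
    assert True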
2 changes: 2 additions & 0 deletions src/ai_sdk/__init__.py
@@ -5,6 +5,7 @@
from .providers.openai import openai
from .tool import tool, Tool
from .providers.anthropic import anthropic
from .providers.gemini import gemini
from .agent import Agent

"""Public entry-point for the *Python* port of Vercel's AI SDK.
@@ -26,6 +27,7 @@
"cosine_similarity",
"openai",
"anthropic",
"gemini",
"tool",
"Tool",
"Agent",
1 change: 1 addition & 0 deletions src/ai_sdk/embed.py
@@ -93,6 +93,7 @@ def cosine_similarity(vec_a: Sequence[float], vec_b: Sequence[float]) -> float:

    return dot / (norm_a * norm_b)


# ---------------------------------------------------------------------------
# Public helpers
# ---------------------------------------------------------------------------
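
The hunk above touches cosine_similarity, which the package exports at top level. A quick illustrative usage (not part of this PR) showing the value it returns:

from ai_sdk import cosine_similarity

# Identical direction -> 1.0; orthogonal vectors -> 0.0.
print(cosine_similarity([1.0, 0.0], [2.0, 0.0]))  # 1.0
print(cosine_similarity([1.0, 0.0], [0.0, 3.0]))  # 0.0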
141 changes: 141 additions & 0 deletions src/ai_sdk/providers/gemini.py
@@ -0,0 +1,141 @@
from __future__ import annotations

import os
from typing import Any, AsyncIterator, Dict, List, Optional

from openai import OpenAI  # type: ignore[import]

from .language_model import LanguageModel
from .openai import _build_chat_messages


class GeminiModel(LanguageModel):
    """OpenAI SDK compatibility provider for Google Gemini models."""

    def __init__(
        self,
        model: str,
        *,
        api_key: Optional[str] = None,
        base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/",
        **default_kwargs: Any,
    ) -> None:
        # Use the OpenAI SDK to talk to Gemini's OpenAI-compatible endpoint.
        # Resolve the API key from the argument or the environment variable.
        api_key = api_key or os.getenv("GEMINI_API_KEY")
        self._client = OpenAI(api_key=api_key, base_url=base_url)
        self._model = model
        self._default_kwargs = default_kwargs

    def generate_text(
        self,
        *,
        prompt: str | None = None,
        system: str | None = None,
        messages: Optional[List[Dict[str, Any]]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Generate a completion via the OpenAI SDK compatibility layer against Gemini."""
        if prompt is None and not messages:
            raise ValueError("Either 'prompt' or 'messages' must be provided.")

        # Build chat messages using the OpenAI helper.
        chat_messages = _build_chat_messages(
            prompt=prompt, system=system, messages=messages
        )

        # Merge default kwargs with call-site overrides.
        request_kwargs: Dict[str, Any] = {**self._default_kwargs, **kwargs}

        # Call via the OpenAI SDK.
        resp = self._client.chat.completions.create(
            model=self._model,
            messages=chat_messages,
            **request_kwargs,
        )

        choice = resp.choices[0]
        text = choice.message.content or ""
        finish_reason = choice.finish_reason or "unknown"

        # Extract tool_calls if present.
        tool_calls = []
        if getattr(choice.message, "tool_calls", None):
            import json as _json

            for call in choice.message.tool_calls:  # type: ignore[attr-defined]
                try:
                    args = _json.loads(call.function.arguments)
                except Exception:
                    args = {"raw": call.function.arguments}
                tool_calls.append(
                    {
                        "tool_call_id": call.id,
                        "tool_name": call.function.name,
                        "args": args,
                    }
                )
            finish_reason = "tool"

        # Usage if available (resp.usage can be None, so guard with getattr).
        usage = resp.usage.model_dump() if getattr(resp, "usage", None) else None
        return {
            "text": text,
            "finish_reason": finish_reason,
            "usage": usage,
            "raw_response": resp,
            "tool_calls": tool_calls or None,
        }

    def stream_text(
        self,
        *,
        prompt: str | None = None,
        system: str | None = None,
        messages: Optional[List[Dict[str, Any]]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Stream deltas via OpenAI SDK compatibility."""
        if prompt is None and not messages:
            raise ValueError("Either 'prompt' or 'messages' must be provided.")

        # Build messages and merge kwargs.
        chat_messages = _build_chat_messages(
            prompt=prompt, system=system, messages=messages
        )
        request_kwargs: Dict[str, Any] = {**self._default_kwargs, **kwargs}

        # Use AsyncOpenAI for streaming to avoid threading issues.
        from openai import AsyncOpenAI

        async_client = AsyncOpenAI(
            api_key=self._client.api_key, base_url=self._client.base_url
        )

        async def _generator() -> AsyncIterator[str]:
            stream = await async_client.chat.completions.create(
                model=self._model,
                messages=chat_messages,
                stream=True,
                **request_kwargs,
            )
            try:
                async for chunk in stream:
                    delta = chunk.choices[0].delta
                    content = getattr(delta, "content", None)
                    if content:
                        yield content
            finally:
                # Close the client even if the consumer stops iterating early.
                await async_client.close()

        return _generator()


# Public factory helper
def gemini(
    model: str,
    *,
    api_key: Optional[str] = None,
    base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/",
    **default_kwargs: Any,
) -> GeminiModel:
    """Return a configured GeminiModel instance using OpenAI SDK compatibility."""
    return GeminiModel(model, api_key=api_key, base_url=base_url, **default_kwargs)
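
The factory merges default kwargs into every request, with call-site kwargs winning on key collisions ({**self._default_kwargs, **kwargs}). A hedged sketch of that behavior, calling the provider directly (assumes GEMINI_API_KEY is set; the prompt and parameter values are illustrative):

# Illustrative only; assumes GEMINI_API_KEY is set in the environment.
from ai_sdk.providers.gemini import gemini

# Defaults passed here are merged into every request...
model = gemini("gemini-2.5-flash", temperature=0.2)

# ...and call-site kwargs override them on collisions.
result = model.generate_text(
    prompt="Name one Python web framework.",
    temperature=0.9,  # overrides the 0.2 default for this call only
)
print(result["text"], result["finish_reason"])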
70 changes: 41 additions & 29 deletions src/ai_sdk/tool.py
@@ -3,7 +3,7 @@

A *Tool* couples a JSON schema (name, description, parameters) with a Python
handler function. The :func:`tool` decorator behaves similar to the JavaScript
version it takes the manifest as its first call and then expects a function
version - it takes the manifest as its first call and then expects a function
that implements the tool logic::

    @tool({
@@ -20,10 +20,10 @@ def double(x: int) -> int:  # noqa: D401 – simple demo

    # Or using Pydantic models for better type safety:
    from pydantic import BaseModel

    class DoubleParams(BaseModel):
        x: int

    @tool(name="double", description="Double the given integer.", parameters=DoubleParams)
    def double(x: int) -> int:
        return x * 2
@@ -47,13 +47,13 @@ def double(x: int) -> int:
def _pydantic_to_json_schema(model: Type[BaseModel]) -> Dict[str, Any]:
    """Convert a Pydantic model to JSON schema format."""
    schema = model.model_json_schema()

    # Ensure we have the required structure for OpenAI function calling
    if "properties" not in schema:
        schema["properties"] = {}
    if "required" not in schema:
        schema["required"] = []

    return schema


@@ -86,7 +86,7 @@ async def run(self, **kwargs: Any) -> Any:  # noqa: D401 – mirrors JS SDK
        if self._pydantic_model is not None:
            validated_data = self._pydantic_model(**kwargs)
            kwargs = validated_data.model_dump()

        result = self.handler(**kwargs)
        if inspect.isawaitable(result):
            return await result  # type: ignore[return-value]
@@ -99,13 +99,13 @@ async def run(self, **kwargs: Any) -> Any:  # noqa: D401 – mirrors JS SDK


def tool(
    *,
    name: str,
    description: str,
    parameters: Dict[str, Any] | Type[BaseModel],
    execute: HandlerFn
) -> "Tool":  # noqa: D401
    '''Create a :class:`ai_sdk.tool.Tool` from a Python callable.
    *,
    name: str,
    description: str,
    parameters: Dict[str, Any] | Type[BaseModel],
    execute: HandlerFn | None = None,
) -> "Tool" | Callable[[HandlerFn], "Tool"]:  # noqa: D401
    """Create a :class:`ai_sdk.tool.Tool` from a Python callable.

    Parameters
    ----------
@@ -130,7 +130,7 @@ def tool(
    Examples
    --------
    Using JSON schema directly:

    >>> @tool(
    ...     name="double",
    ...     description="Double the given integer.",
@@ -144,25 +144,23 @@
    ...     return x * 2

    Using Pydantic model for better type safety:

    >>> from pydantic import BaseModel
    >>>
    >>>
    >>> class DoubleParams(BaseModel):
    ...     x: int
    ...
    ...
    >>> @tool(
    ...     name="double",
    ...     description="Double the given integer.",
    ...     parameters=DoubleParams
    ... )
    ... def double(x: int) -> int:
    ...     return x * 2
    '''
    """

    if not all([name, description, parameters, execute]):
        raise ValueError(
            "'name', 'description', 'parameters', and 'execute' are required"
        )
    if not all([name, description, parameters]):
        raise ValueError("'name', 'description', and 'parameters' are required")

    # Handle Pydantic model vs JSON schema
    pydantic_model = None
@@ -176,10 +174,24 @@
"parameters must be either a JSON schema dict or a Pydantic model class"
)

    return Tool(
        name=name,
        description=description,
        parameters=parameters_dict,
        handler=execute,
        _pydantic_model=pydantic_model
    )
    # If execute is provided (functional usage), return the Tool immediately
    if execute is not None:
        return Tool(
            name=name,
            description=description,
            parameters=parameters_dict,
            handler=execute,
            _pydantic_model=pydantic_model,
        )

    # Otherwise (decorator usage), return a wrapper that accepts the function
    def wrapper(func: HandlerFn) -> Tool:
        return Tool(
            name=name,
            description=description,
            parameters=parameters_dict,
            handler=func,
            _pydantic_model=pydantic_model,
        )

    return wrapper
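
With this change, tool() supports both call styles shown in the diff: pass execute for the functional form and get a Tool back immediately, or omit it and use tool() as a decorator. A short sketch of both forms (the AddParams model and tool names are illustrative, not from this PR):

from pydantic import BaseModel

from ai_sdk import tool


class AddParams(BaseModel):
    a: int
    b: int


# Decorator form: omit execute and tool() returns a wrapper.
@tool(name="add", description="Add two integers.", parameters=AddParams)
def add(a: int, b: int) -> int:
    return a + b


# Functional form: pass execute and get a Tool back immediately.
add_tool = tool(
    name="add_fn",
    description="Add two integers.",
    parameters=AddParams,
    execute=lambda a, b: a + b,
)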
1 change: 1 addition & 0 deletions src/ai_sdk/types.py
@@ -12,6 +12,7 @@
previous dataclass-based implementation, so no changes are required in
existing downstream code.
"""

from __future__ import annotations

from datetime import datetime
1 change: 0 additions & 1 deletion tests/manual_test.py
@@ -23,7 +23,6 @@
    tool,
    embed_many,
    cosine_similarity,
    anthropic,
    Agent,
)
from ai_sdk.types import CoreSystemMessage, CoreUserMessage, TextPart, AnyMessage