From 4d45495c55359f64b61b0290367d56cec50cf801 Mon Sep 17 00:00:00 2001
From: ynbh
Date: Wed, 10 Dec 2025 19:57:31 -0500
Subject: [PATCH 1/4] feat: add Gemini provider with OpenAI compatibility

---
 examples/gemini_demo.py        |  33 ++++++++
 src/ai_sdk/__init__.py         |   2 +
 src/ai_sdk/providers/gemini.py | 139 +++++++++++++++++++++++++++++++++
 3 files changed, 174 insertions(+)
 create mode 100644 examples/gemini_demo.py
 create mode 100644 src/ai_sdk/providers/gemini.py

diff --git a/examples/gemini_demo.py b/examples/gemini_demo.py
new file mode 100644
index 0000000..8186d21
--- /dev/null
+++ b/examples/gemini_demo.py
@@ -0,0 +1,33 @@
+import asyncio
+import os
+from ai_sdk import gemini, generate_text, stream_text
+
+from dotenv import load_dotenv
+
+load_dotenv()
+
+if not os.getenv("GEMINI_API_KEY"):
+    print("Please set the GEMINI_API_KEY environment variable.")
+    exit(1)
+
+async def main():
+    model = gemini("gemini-2.5-flash")
+
+    print("--- Generating Text ---")
+    response = generate_text(
+        model=model,
+        prompt="Tell me a joke about programming.",
+    )
+    print(response.text)
+
+    print("\n--- Streaming Text ---")
+    result = stream_text(
+        model=model,
+        prompt="Write a haiku about Python. Make it LONG.",
+    )
+    async for delta in result.text_stream:
+        print(delta, end="", flush=True)
+    print()
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/src/ai_sdk/__init__.py b/src/ai_sdk/__init__.py
index 09d95f3..79a701f 100644
--- a/src/ai_sdk/__init__.py
+++ b/src/ai_sdk/__init__.py
@@ -5,6 +5,7 @@
 from .providers.openai import openai
 from .tool import tool, Tool
 from .providers.anthropic import anthropic
+from .providers.gemini import gemini
 from .agent import Agent
 
 """Public entry-point for the *Python* port of Vercel's AI SDK.
@@ -26,6 +27,7 @@
     "cosine_similarity",
     "openai",
     "anthropic",
+    "gemini",
     "tool",
     "Tool",
     "Agent",
diff --git a/src/ai_sdk/providers/gemini.py b/src/ai_sdk/providers/gemini.py
new file mode 100644
index 0000000..da91885
--- /dev/null
+++ b/src/ai_sdk/providers/gemini.py
@@ -0,0 +1,139 @@
+from __future__ import annotations
+
+import os
+from typing import Any, AsyncIterator, Dict, List, Optional
+from openai import OpenAI  # type: ignore[import]
+
+from .language_model import LanguageModel
+from .openai import _build_chat_messages
+
+
+class GeminiModel(LanguageModel):
+    """OpenAI SDK compatibility provider for Google Gemini models."""
+
+    def __init__(
+        self,
+        model: str,
+        *,
+        api_key: Optional[str] = None,
+        base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/",
+        **default_kwargs: Any,
+    ) -> None:
+        # Use OpenAI SDK to talk to Gemini endpoint
+        # Resolve API key from argument or environment variable
+        api_key = api_key or os.getenv("GEMINI_API_KEY")
+        self._client = OpenAI(api_key=api_key, base_url=base_url)
+        self._model = model
+        self._default_kwargs = default_kwargs
+
+    def generate_text(
+        self,
+        *,
+        prompt: str | None = None,
+        system: str | None = None,
+        messages: Optional[List[Dict[str, Any]]] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        """Generate a completion via the OpenAI SDK compatibility layer against Gemini."""
+        if prompt is None and not messages:
+            raise ValueError("Either 'prompt' or 'messages' must be provided.")
+
+        # Build chat messages using OpenAI helper
+        chat_messages = _build_chat_messages(
+            prompt=prompt, system=system, messages=messages
+        )
+
+        # Merge default kwargs with call-site overrides
+        request_kwargs: Dict[str, Any] = {**self._default_kwargs, **kwargs}
+
+        # Call via OpenAI SDK
+        resp = self._client.chat.completions.create(
+            model=self._model,
+            messages=chat_messages,
+            **request_kwargs,
+        )
+
+        choice = resp.choices[0]
+        text = choice.message.content or ""
+        finish_reason = choice.finish_reason or "unknown"
+
+        # Extract tool_calls if present
+        tool_calls = []
+        if getattr(choice.message, "tool_calls", None):
+            import json as _json
+
+            for call in choice.message.tool_calls:  # type: ignore[attr-defined]
+                try:
+                    args = _json.loads(call.function.arguments)
+                except Exception:
+                    args = {"raw": call.function.arguments}
+                tool_calls.append(
+                    {
+                        "tool_call_id": call.id,
+                        "tool_name": call.function.name,
+                        "args": args,
+                    }
+                )
+            finish_reason = "tool"
+
+        # Usage if available (resp.usage may be None, so guard on its value)
+        usage = resp.usage.model_dump() if getattr(resp, "usage", None) else None
+        return {
+            "text": text,
+            "finish_reason": finish_reason,
+            "usage": usage,
+            "raw_response": resp,
+            "tool_calls": tool_calls or None,
+        }
+
+    def stream_text(
+        self,
+        *,
+        prompt: str | None = None,
+        system: str | None = None,
+        messages: Optional[List[Dict[str, Any]]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        """Stream deltas via OpenAI SDK compatibility."""
+        if prompt is None and not messages:
+            raise ValueError("Either 'prompt' or 'messages' must be provided.")
+
+        # Build messages and merge kwargs
+        chat_messages = _build_chat_messages(
+            prompt=prompt, system=system, messages=messages
+        )
+        request_kwargs: Dict[str, Any] = {**self._default_kwargs, **kwargs}
+
+        # Use AsyncOpenAI for streaming to avoid threading issues
+        from openai import AsyncOpenAI
+
+        async_client = AsyncOpenAI(api_key=self._client.api_key, base_url=self._client.base_url)
+
+        async def _generator() -> AsyncIterator[str]:
+            stream = await async_client.chat.completions.create(
+                model=self._model,
+                messages=chat_messages,
+                stream=True,
+                **request_kwargs,
+            )
+            async for chunk in stream:
+                delta = chunk.choices[0].delta
+                content = getattr(delta, "content", None)
+                if content:
+                    yield content
+
+            await async_client.close()
+
+        return _generator()
+
+
+# Public factory helper
+def gemini(
+    model: str,
+    *,
+    api_key: Optional[str] = None,
+    base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/",
+    **default_kwargs: Any,
+) -> GeminiModel:
+    """Return a configured GeminiModel instance using OpenAI SDK compatibility."""
+    return GeminiModel(model, api_key=api_key, base_url=base_url, **default_kwargs)

From b59e23cecf2e9fba01af09fe971789449370e8df Mon Sep 17 00:00:00 2001
From: ynbh
Date: Wed, 10 Dec 2025 20:16:05 -0500
Subject: [PATCH 2/4] fix(tool): support decorator usage by making 'execute' optional

---
 src/ai_sdk/tool.py | 38 ++++++++++++++++++++++++++------------
 1 file changed, 26 insertions(+), 12 deletions(-)

diff --git a/src/ai_sdk/tool.py b/src/ai_sdk/tool.py
index a4f0265..bac4aa1 100644
--- a/src/ai_sdk/tool.py
+++ b/src/ai_sdk/tool.py
@@ -3,7 +3,7 @@
 
 A *Tool* couples a JSON schema (name, description, parameters) with a Python
 handler function. The :func:`tool` decorator behaves similar to the JavaScript
-version – it takes the manifest as its first call and then expects a function
+version - it takes the manifest as its first call and then expects a function
 that implements the tool logic::
 
     @tool({
@@ -103,8 +103,8 @@ def tool(
     name: str,
     description: str,
     parameters: Dict[str, Any] | Type[BaseModel],
-    execute: HandlerFn
-) -> "Tool":  # noqa: D401
+    execute: HandlerFn | None = None
+) -> "Tool" | Callable[[HandlerFn], "Tool"]:  # noqa: D401
     '''Create a :class:`ai_sdk.tool.Tool` from a Python callable.
 
     Parameters
@@ -159,9 +159,9 @@ def tool(
     ...     return x * 2
     '''
 
-    if not all([name, description, parameters, execute]):
+    if not all([name, description, parameters]):
         raise ValueError(
-            "'name', 'description', 'parameters', and 'execute' are required"
+            "'name', 'description', and 'parameters' are required"
         )
 
     # Handle Pydantic model vs JSON schema
@@ -176,10 +176,24 @@ def tool(
             "parameters must be either a JSON schema dict or a Pydantic model class"
         )
 
-    return Tool(
-        name=name,
-        description=description,
-        parameters=parameters_dict,
-        handler=execute,
-        _pydantic_model=pydantic_model
-    )
+    # If execute is provided (functional usage), return the Tool immediately
+    if execute is not None:
+        return Tool(
+            name=name,
+            description=description,
+            parameters=parameters_dict,
+            handler=execute,
+            _pydantic_model=pydantic_model
+        )
+
+    # Otherwise (decorator usage), return a wrapper that accepts the function
+    def wrapper(func: HandlerFn) -> Tool:
+        return Tool(
+            name=name,
+            description=description,
+            parameters=parameters_dict,
+            handler=func,
+            _pydantic_model=pydantic_model
+        )
+
+    return wrapper

From e7100c10f5613de0e0a5eb317050353dcfa3fd3a Mon Sep 17 00:00:00 2001
From: ynbh
Date: Wed, 10 Dec 2025 20:16:59 -0500
Subject: [PATCH 3/4] test: add tests for gemini provider

---
 tests/test_gemini.py | 87 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 87 insertions(+)
 create mode 100644 tests/test_gemini.py

diff --git a/tests/test_gemini.py b/tests/test_gemini.py
new file mode 100644
index 0000000..7cfad2f
--- /dev/null
+++ b/tests/test_gemini.py
@@ -0,0 +1,87 @@
+import os
+import pytest
+from ai_sdk import gemini, generate_text, stream_text, generate_object
+from pydantic import BaseModel
+from dotenv import load_dotenv
+from ai_sdk.tool import tool
+from pydantic import BaseModel, Field
+
+load_dotenv()
+
+# Skip all tests if GEMINI_API_KEY is missing
+pytestmark = pytest.mark.skipif(
+    not os.getenv("GEMINI_API_KEY"),
+    reason="GEMINI_API_KEY environment variable not set",
+)
+
+MODEL_ID = "gemini-2.5-flash"
+
+@pytest.fixture
+def model():
+    if not os.getenv("GEMINI_API_KEY"):
+        pytest.skip("GEMINI_API_KEY not set")
+    return gemini(MODEL_ID)
+
+
+@pytest.mark.asyncio
+async def test_gemini_generate_text(model):
+    """Basic prompt-only generation."""
+    res = generate_text(model=model, prompt="Say 'hello' and nothing else.")
+    assert "hello" in res.text.lower()
+    assert res.finish_reason is not None
+    assert res.usage is not None
+
+
+@pytest.mark.asyncio
+async def test_gemini_stream_text(model):
+    """Streaming generation."""
+    result = stream_text(model=model, prompt="Count to 3.")
+    collected = []
+    async for delta in result.text_stream:
+        collected.append(delta)
+    full_text = await result.text()
+    assert len(collected) > 0
+    assert len(full_text) > 0
+
+
+class MathResponse(BaseModel):
+    answer: int
+    steps: str
+
+
+@pytest.mark.asyncio
+async def test_gemini_generate_object(model):
+    """Structured output generation."""
+    res = generate_object(
+        model=model,
+        schema=MathResponse,
+        prompt="What is 2 + 2?",
+    )
+    assert isinstance(res.object, MathResponse)
+    assert res.object.answer == 4
+
+class AddParams(BaseModel):
+    a: float = Field(description="First number")
+    b: float = Field(description="Second number")
+
+@tool(
+    name="add",
+    description="Add two numbers.",
+    parameters=AddParams
+)
+def add_tool(a: float, b: float) -> float:
+    return a + b
+
+
+@pytest.mark.asyncio
+async def test_gemini_tool_call(model):
+    """Tool call generation."""
+    res = generate_text(
+        model=model,
+        prompt="What is 3+3?",
+        tools=[
+            add_tool,
+        ],
+    )
+    assert "6" in res.text
+

From b990a155291410a3988af2b3978029cb4643f2d6 Mon Sep 17 00:00:00 2001
From: ynbh
Date: Wed, 10 Dec 2025 20:44:10 -0500
Subject: [PATCH 4/4] style: apply ruff formatting and fix tests

---
 examples/gemini_demo.py        |  2 ++
 pytest.ini                     |  2 ++
 src/ai_sdk/embed.py            |  1 +
 src/ai_sdk/providers/gemini.py | 12 ++++----
 src/ai_sdk/tool.py             | 54 ++++++++++++++++------------------
 src/ai_sdk/types.py            |  1 +
 tests/manual_test.py           |  1 -
 tests/test_gemini.py           | 13 ++++----
 tests/test_tool_calling.py     | 34 +++++++++++++++++----
 9 files changed, 72 insertions(+), 48 deletions(-)
 create mode 100644 pytest.ini

diff --git a/examples/gemini_demo.py b/examples/gemini_demo.py
index 8186d21..153e5c1 100644
--- a/examples/gemini_demo.py
+++ b/examples/gemini_demo.py
@@ -10,6 +10,7 @@
     print("Please set the GEMINI_API_KEY environment variable.")
     exit(1)
 
+
 async def main():
     model = gemini("gemini-2.5-flash")
 
@@ -29,5 +30,6 @@ async def main():
         print(delta, end="", flush=True)
     print()
 
+
 if __name__ == "__main__":
     asyncio.run(main())
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..2f4c80e
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+asyncio_mode = auto
diff --git a/src/ai_sdk/embed.py b/src/ai_sdk/embed.py
index 91b0d6d..e90da61 100644
--- a/src/ai_sdk/embed.py
+++ b/src/ai_sdk/embed.py
@@ -93,6 +93,7 @@ def cosine_similarity(vec_a: Sequence[float], vec_b: Sequence[float]) -> float:
 
     return dot / (norm_a * norm_b)
 
+
 # ---------------------------------------------------------------------------
 # Public helpers
 # ---------------------------------------------------------------------------
diff --git a/src/ai_sdk/providers/gemini.py b/src/ai_sdk/providers/gemini.py
index da91885..653a688 100644
--- a/src/ai_sdk/providers/gemini.py
+++ b/src/ai_sdk/providers/gemini.py
@@ -56,7 +56,7 @@ def generate_text(
         choice = resp.choices[0]
         text = choice.message.content or ""
         finish_reason = choice.finish_reason or "unknown"
-        
+
         # Extract tool_calls if present
         tool_calls = []
         if getattr(choice.message, "tool_calls", None):
@@ -106,9 +106,11 @@ def stream_text(
 
         # Use AsyncOpenAI for streaming to avoid threading issues
         from openai import AsyncOpenAI
-        
-        async_client = AsyncOpenAI(api_key=self._client.api_key, base_url=self._client.base_url)
-        
+
+        async_client = AsyncOpenAI(
+            api_key=self._client.api_key, base_url=self._client.base_url
+        )
+
         async def _generator() -> AsyncIterator[str]:
             stream = await async_client.chat.completions.create(
                 model=self._model,
@@ -121,7 +123,7 @@ async def _generator() -> AsyncIterator[str]:
                 content = getattr(delta, "content", None)
                 if content:
                     yield content
-        
+
             await async_client.close()
 
         return _generator()
diff --git a/src/ai_sdk/tool.py b/src/ai_sdk/tool.py
index bac4aa1..2ffd1cb 100644
--- a/src/ai_sdk/tool.py
+++ b/src/ai_sdk/tool.py
@@ -20,10 +20,10 @@ def double(x: int) -> int:  # noqa: D401 – simple demo
 
     # Or using Pydantic models for better type safety:
     from pydantic import BaseModel
-    
+
     class DoubleParams(BaseModel):
         x: int
-    
+
     @tool(name="double", description="Double the given integer.", parameters=DoubleParams)
     def double(x: int) -> int:
         return x * 2
@@ -47,13 +47,13 @@ def double(x: int) -> int:
 def _pydantic_to_json_schema(model: Type[BaseModel]) -> Dict[str, Any]:
     """Convert a Pydantic model to JSON schema format."""
     schema = model.model_json_schema()
-    
+
     # Ensure we have the required structure for OpenAI function calling
     if "properties" not in schema:
         schema["properties"] = {}
     if "required" not in schema:
         schema["required"] = []
-    
+
     return schema
@@ -86,7 +86,7 @@ async def run(self, **kwargs: Any) -> Any:  # noqa: D401 – mirrors JS SDK
         if self._pydantic_model is not None:
             validated_data = self._pydantic_model(**kwargs)
             kwargs = validated_data.model_dump()
-        
+
         result = self.handler(**kwargs)
         if inspect.isawaitable(result):
             return await result  # type: ignore[return-value]
@@ -99,13 +99,13 @@ async def run(self, **kwargs: Any) -> Any:  # noqa: D401 – mirrors JS SDK
 
 def tool(
-    *, 
-    name: str, 
-    description: str, 
-    parameters: Dict[str, Any] | Type[BaseModel], 
-    execute: HandlerFn | None = None
+    *,
+    name: str,
+    description: str,
+    parameters: Dict[str, Any] | Type[BaseModel],
+    execute: HandlerFn | None = None,
 ) -> "Tool" | Callable[[HandlerFn], "Tool"]:  # noqa: D401
-    '''Create a :class:`ai_sdk.tool.Tool` from a Python callable.
+    """Create a :class:`ai_sdk.tool.Tool` from a Python callable.
 
     Parameters
     ----------
@@ -130,7 +130,7 @@ def tool(
     Examples
     --------
     Using JSON schema directly:
-    
+
     >>> @tool(
     ...     name="double",
     ...     description="Double the given integer.",
@@ -144,12 +144,12 @@ def tool(
     ...     return x * 2
 
     Using Pydantic model for better type safety:
-    
+
     >>> from pydantic import BaseModel
-    >>> 
+    >>>
     >>> class DoubleParams(BaseModel):
     ...     x: int
-    ... 
+    ...
     >>> @tool(
     ...     name="double",
     ...     description="Double the given integer.",
@@ -157,12 +157,10 @@ def tool(
     ...     )
     ... def double(x: int) -> int:
     ...     return x * 2
-    '''
+    """
 
     if not all([name, description, parameters]):
-        raise ValueError(
-            "'name', 'description', and 'parameters' are required"
-        )
+        raise ValueError("'name', 'description', and 'parameters' are required")
 
     # Handle Pydantic model vs JSON schema
     pydantic_model = None
@@ -176,21 +174,21 @@ def tool(
     # If execute is provided (functional usage), return the Tool immediately
     if execute is not None:
         return Tool(
-            name=name, 
-            description=description, 
-            parameters=parameters_dict, 
+            name=name,
+            description=description,
+            parameters=parameters_dict,
             handler=execute,
-            _pydantic_model=pydantic_model
+            _pydantic_model=pydantic_model,
         )
 
     # Otherwise (decorator usage), return a wrapper that accepts the function
     def wrapper(func: HandlerFn) -> Tool:
         return Tool(
-            name=name, 
-            description=description, 
-            parameters=parameters_dict, 
+            name=name,
+            description=description,
+            parameters=parameters_dict,
             handler=func,
-            _pydantic_model=pydantic_model
+            _pydantic_model=pydantic_model,
         )
-    
+
     return wrapper
diff --git a/src/ai_sdk/types.py b/src/ai_sdk/types.py
index 4fc7f72..387cf50 100644
--- a/src/ai_sdk/types.py
+++ b/src/ai_sdk/types.py
@@ -12,6 +12,7 @@
 previous dataclass-based implementation, so no changes are required in
 existing downstream code.
""" + from __future__ import annotations from datetime import datetime diff --git a/tests/manual_test.py b/tests/manual_test.py index 39c7917..59890a0 100644 --- a/tests/manual_test.py +++ b/tests/manual_test.py @@ -23,7 +23,6 @@ tool, embed_many, cosine_similarity, - anthropic, Agent, ) from ai_sdk.types import CoreSystemMessage, CoreUserMessage, TextPart, AnyMessage diff --git a/tests/test_gemini.py b/tests/test_gemini.py index 7cfad2f..c8cdf4b 100644 --- a/tests/test_gemini.py +++ b/tests/test_gemini.py @@ -1,7 +1,6 @@ import os import pytest from ai_sdk import gemini, generate_text, stream_text, generate_object -from pydantic import BaseModel from dotenv import load_dotenv from ai_sdk.tool import tool from pydantic import BaseModel, Field @@ -16,6 +15,7 @@ MODEL_ID = "gemini-2.5-flash" + @pytest.fixture def model(): if not os.getenv("GEMINI_API_KEY"): @@ -60,20 +60,18 @@ async def test_gemini_generate_object(model): assert isinstance(res.object, MathResponse) assert res.object.answer == 4 + class AddParams(BaseModel): a: float = Field(description="First number") b: float = Field(description="Second number") -@tool( - name="add", - description="Add two numbers.", - parameters=AddParams -) + +@tool(name="add", description="Add two numbers.", parameters=AddParams) def add_tool(a: float, b: float) -> float: return a + b -@pytest.mark.asyncio +@pytest.mark.asyncio async def test_gemini_tool_call(model): """Tool call generation.""" res = generate_text( @@ -84,4 +82,3 @@ async def test_gemini_tool_call(model): ], ) assert "6" in res.text - diff --git a/tests/test_tool_calling.py b/tests/test_tool_calling.py index b078745..906ab8e 100644 --- a/tests/test_tool_calling.py +++ b/tests/test_tool_calling.py @@ -22,7 +22,7 @@ def __exit__(self, *args): from ai_sdk import generate_text, tool from ai_sdk.providers.language_model import LanguageModel -from ai_sdk.types import CoreToolMessage + # --------------------------------------------------------------------------- # Dummy provider – emulates tool calling behaviour without external network @@ -69,9 +69,30 @@ def generate_text( # The dummy model *echoes* whatever the tool result was. In a real # conversation the LLM would continue reasoning here. 
         last_tool_msg = (
-            messages[-1] if messages else CoreToolMessage(role="tool", content=[])
+            messages[-1] if messages else {"role": "tool", "content": []}
         )
-        result_value = json.loads(last_tool_msg.content[0].result)
+        # Handle both dict and object access for robust testing
+        content = last_tool_msg.get("content") if isinstance(last_tool_msg, dict) else last_tool_msg.content
+
+        # Content might be a list of ToolResult objects or dicts
+        first_content = content[0]
+        if hasattr(first_content, "result"):
+            result_val = first_content.result
+        elif isinstance(first_content, dict):
+            result_val = first_content["result"]
+        else:
+            # Fallback if it's already the result value (simpler mocks)
+            result_val = first_content
+
+        # If result is json string, load it
+        if isinstance(result_val, str):
+            try:
+                result_value = json.loads(result_val)
+            except Exception:
+                result_value = result_val
+        else:
+            result_value = result_val
+
         return {
             "text": str(result_value),
             "finish_reason": "stop",
@@ -173,12 +194,13 @@ def test_tool_with_pydantic_model():
     assert "x" in schema["required"]
 
 
-def test_tool_execution_with_validation():
+@pytest.mark.asyncio
+async def test_tool_execution_with_validation():
    """Test that tool execution validates inputs against Pydantic model."""
     # Valid input
-    result = double_tool.run(x=5)
+    result = await double_tool.run(x=5)
     assert result == 10
 
     # Invalid input should raise validation error
     with pytest.raises(Exception):  # Pydantic validation error
-        double_tool.run(x="not an integer")
+        await double_tool.run(x="not an integer")
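
Usage sketch: with the series applied, the decorator form introduced in PATCH 2 composes with the Gemini provider from PATCH 1. A minimal example, modelled on tests/test_gemini.py (assumes GEMINI_API_KEY is set; the model id and prompt are illustrative only)::

    from ai_sdk import gemini, generate_text, tool

    # Decorator usage: calling tool(...) without 'execute' returns a wrapper
    # that turns the decorated function into a Tool (PATCH 2 behaviour).
    @tool(
        name="double",
        description="Double the given integer.",
        parameters={
            "type": "object",
            "properties": {"x": {"type": "integer"}},
            "required": ["x"],
        },
    )
    def double(x: int) -> int:
        return x * 2

    # The resulting Tool is passed to generate_text alongside the
    # OpenAI-compatible Gemini model from PATCH 1.
    res = generate_text(
        model=gemini("gemini-2.5-flash"),
        prompt="What is double of 21?",
        tools=[double],
    )
    print(res.text)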