6 changes: 5 additions & 1 deletion eval_protocol/pytest/default_agent_rollout_processor.py
@@ -22,6 +22,7 @@
from openai.types import CompletionUsage
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig
+ from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
from pydantic import BaseModel
from typing import Optional

@@ -251,8 +252,11 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
"""Process a single row with agent rollout."""
start_time = time.perf_counter()

+ # Normalize Fireworks model names for LiteLLM routing
+ completion_params = normalize_fireworks_model_for_litellm(row.input_metadata.completion_params) or {}
+ row.input_metadata.completion_params = completion_params
agent = Agent(
- model=row.input_metadata.completion_params["model"],
+ model=completion_params["model"],
row=row,
config_path=config.mcp_config_path,
logger=config.logger,
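Every processor touched by this PR follows the two-step pattern introduced here: normalize the completion params first, then read the model from the normalized dict. A minimal sketch of the intended effect, assuming the new helper keeps the semantics of the normalize_fireworks_model function removed from evaluation_test_utils.py later in this diff (the model name is illustrative):

    from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm

    params = {"model": "accounts/fireworks/models/llama-v3p1-8b-instruct", "temperature": 0.0}
    params = normalize_fireworks_model_for_litellm(params) or {}
    # params["model"] == "fireworks_ai/accounts/fireworks/models/llama-v3p1-8b-instruct"
    # Already-prefixed and non-Fireworks names (e.g. "gpt-4o-mini") pass through unchanged.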
48 changes: 27 additions & 21 deletions eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py
@@ -11,6 +11,7 @@
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
+ from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm

from eval_protocol.pytest.default_agent_rollout_processor import Agent
from klavis import Klavis
@@ -30,15 +31,15 @@ def __init__(
self.server_name = server_name
self.initialize_data_factory = initialize_data_factory
self.klavis_client = Klavis(api_key=os.environ.get("KLAVIS_API_KEY"))

def _init_sandbox(self) -> CreateSandboxResponse:
try:
server_name_enum = SandboxMcpServer(self.server_name)
return self.klavis_client.sandbox.create_sandbox(server_name=server_name_enum)
except Exception as e:
logger.error(f"Error creating sandbox: {str(e)}", exc_info=True)
raise

@staticmethod
def create_mcp_config(server_url: str, server_key: str = "main", auth_token: str | None = None) -> str:
"""Create a temporary MCP config file and return its path."""
@@ -47,26 +48,24 @@ def create_mcp_config(server_url: str, server_key: str = "main", auth_token: str
server_key: {
"url": server_url,
"transport": "streamable_http",
- **({"authorization": f"Bearer {auth_token}"} if auth_token else {})
+ **({"authorization": f"Bearer {auth_token}"} if auth_token else {}),
}
}
}

# Create a temp file that persists for the session
fd, path = tempfile.mkstemp(suffix=".json", prefix="mcp_config_")
- with os.fdopen(fd, 'w') as f:
+ with os.fdopen(fd, "w") as f:
json.dump(config, f)
return path

- def __call__(
-     self, rows: List[EvaluationRow], config: RolloutProcessorConfig
- ) -> List[asyncio.Task[EvaluationRow]]:
+ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
"""Process evaluation rows with Klavis sandbox lifecycle management"""
semaphore = config.semaphore

async def process_row(row: EvaluationRow) -> EvaluationRow:
"""Process a single row with complete sandbox lifecycle"""

start_time = time.perf_counter()
agent: Agent | None = None
temp_config_path: str | None = None
@@ -88,25 +87,32 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
if row.input_metadata is not None
else None
)

if init_data:
- logger.info(f"Initializing {self.server_name} sandbox {sandbox.sandbox_id}")
+ logger.info(f"Initializing {self.server_name} sandbox {sandbox.sandbox_id}")  # pyright: ignore[reportOptionalMemberAccess]
  initialize_method = getattr(
-     self.klavis_client.sandbox, f"initialize_{sandbox.server_name.value}_sandbox"
+     self.klavis_client.sandbox,
+     f"initialize_{sandbox.server_name.value}_sandbox",  # pyright: ignore[reportOptionalMemberAccess]
  )
- init_response = initialize_method(sandbox_id=sandbox.sandbox_id, **init_data)
+ init_response = initialize_method(sandbox_id=sandbox.sandbox_id, **init_data)  # pyright: ignore[reportOptionalMemberAccess]
logger.info(f"Initialization response: {init_response}")

# Step 2: Create temporary MCP config with sandbox URL
temp_config_path = self.create_mcp_config(
- server_url=sandbox.server_url, server_key=sandbox.server_name.value
+ server_url=sandbox.server_url,  # pyright: ignore[reportOptionalMemberAccess]
+ server_key=sandbox.server_name.value,  # pyright: ignore[reportOptionalMemberAccess]
)
logger.info(f"MCP config created: {temp_config_path}")

# Step 3: Run agent with sandbox MCP server
- logger.info(f"Running agent for row {row.execution_metadata.rollout_id} with {self.server_name} sandbox")
+ logger.info(
+     f"Running agent for row {row.execution_metadata.rollout_id} with {self.server_name} sandbox"
+ )
+ # Normalize Fireworks model names for LiteLLM routing
+ completion_params = normalize_fireworks_model_for_litellm(row.input_metadata.completion_params) or {}
+ row.input_metadata.completion_params = completion_params
agent = Agent(
- model=row.input_metadata.completion_params["model"],
+ model=completion_params["model"],
row=row,
config_path=temp_config_path,
logger=config.logger,
@@ -124,16 +130,16 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
logger.info(f"Agent execution completed for row {row.execution_metadata.rollout_id}")

# Step 4: Export sandbox data
- dump_method = getattr(self.klavis_client.sandbox, f"dump_{sandbox.server_name.value}_sandbox")
- dump_response = dump_method(sandbox_id=sandbox.sandbox_id)
+ dump_method = getattr(self.klavis_client.sandbox, f"dump_{sandbox.server_name.value}_sandbox")  # pyright: ignore[reportOptionalMemberAccess]
+ dump_response = dump_method(sandbox_id=sandbox.sandbox_id)  # pyright: ignore[reportOptionalMemberAccess]
sandbox_data = dump_response.data
logger.info(f"Sandbox data: {sandbox_data}")

# Store sandbox data in row metadata for evaluation
if not row.execution_metadata.extra:
row.execution_metadata.extra = {}
row.execution_metadata.extra["sandbox_data"] = sandbox_data
- row.execution_metadata.extra["sandbox_id"] = sandbox.sandbox_id
+ row.execution_metadata.extra["sandbox_id"] = sandbox.sandbox_id  # pyright: ignore[reportOptionalMemberAccess]
row.execution_metadata.extra["server_name"] = self.server_name

except Exception as e:
@@ -149,7 +155,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
await agent.mcp_client.cleanup()
if temp_config_path and os.path.exists(temp_config_path):
os.unlink(temp_config_path)

# Release sandbox
if sandbox and sandbox.sandbox_id:
try:
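For reference, create_mcp_config above writes a JSON file shaped roughly like the sketch below. The outermost wrapper key is cut off by the diff hunk, so "mcpServers" is an assumption, and the URL and token are placeholders:

    {
        "mcpServers": {
            "main": {
                "url": "https://sandbox.example.com/mcp",
                "transport": "streamable_http",
                "authorization": "Bearer <auth_token>"
            }
        }
    }

The "authorization" entry appears only when an auth_token is supplied, and the temp file is deleted in the finally block once the rollout completes.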
18 changes: 11 additions & 7 deletions eval_protocol/pytest/default_mcp_gym_rollout_processor.py
@@ -14,6 +14,7 @@
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig, ServerMode
+ from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm


class MCPServerManager:
@@ -280,17 +281,20 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
"Cannot retry without existing server/environments. Call with start_server=True first."
)

- model_id = str((config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini")
- temperature = config.completion_params.get("temperature", 0.0)
- max_tokens = config.completion_params.get("max_tokens", 4096)
+ # Normalize Fireworks model names for LiteLLM routing
+ completion_params = normalize_fireworks_model_for_litellm(config.completion_params) or {}
+ # Update all rows with normalized completion_params
+ for row in rows:
+     row.input_metadata.completion_params = completion_params
+ model_id = str(completion_params.get("model") or "gpt-4o-mini")
+ temperature = completion_params.get("temperature", 0.0)
+ max_tokens = completion_params.get("max_tokens", 4096)

# Pass all other completion_params (e.g. stream=True) via kwargs
other_params = {
-     k: v
-     for k, v in (config.completion_params or {}).items()
-     if k not in ["model", "temperature", "max_tokens", "extra_body"]
+     k: v for k, v in completion_params.items() if k not in ["model", "temperature", "max_tokens", "extra_body"]
}
- extra_body = config.completion_params.get("extra_body", {}) or {}
+ extra_body = completion_params.get("extra_body", {}) or {}

self.policy = ep.LiteLLMPolicy(
model_id=model_id,
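To make the parameter split above concrete, here is how one assumed completion_params dict would be divided (values are illustrative):

    completion_params = {
        "model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
        "temperature": 0.2,
        "max_tokens": 1024,
        "stream": True,
        "extra_body": {"reasoning_effort": "low"},
    }
    # After normalization and extraction:
    #   model_id     -> "fireworks_ai/accounts/fireworks/models/llama-v3p1-8b-instruct"
    #   temperature  -> 0.2
    #   max_tokens   -> 1024
    #   other_params -> {"stream": True}  (everything except model/temperature/max_tokens/extra_body)
    #   extra_body   -> {"reasoning_effort": "low"}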
23 changes: 12 additions & 11 deletions eval_protocol/pytest/default_single_turn_rollout_process.py
@@ -17,6 +17,7 @@
from openai.types import CompletionUsage
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
+ from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm

logger = logging.getLogger(__name__)

@@ -63,7 +64,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
async def process_row(row: EvaluationRow) -> EvaluationRow:
"""Process a single row asynchronously."""
start_time = time.perf_counter()

if len(row.messages) == 0:
raise ValueError("Messages is empty. Please provide a non-empty dataset")

@@ -77,7 +78,10 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
# Use the Message class method that excludes unsupported fields
messages_payload = [message.dump_mdoel_for_chat_completion_request() for message in messages_for_request]

- request_params = {"messages": messages_payload, **config.completion_params}
+ # Normalize Fireworks model names for LiteLLM routing
+ completion_params = normalize_fireworks_model_for_litellm(config.completion_params) or {}
+ row.input_metadata.completion_params = completion_params
+ request_params = {"messages": messages_payload, **completion_params}
# Ensure caching is disabled only for this request (review feedback)
request_params["cache"] = {"no-cache": True}

@@ -87,18 +91,15 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
# Single-level reasoning effort: expect `reasoning_effort` only
effort_val = None

- if (
-     "reasoning_effort" in config.completion_params
-     and config.completion_params["reasoning_effort"] is not None
- ):
-     effort_val = str(config.completion_params["reasoning_effort"])  # flat shape
+ if "reasoning_effort" in completion_params and completion_params["reasoning_effort"] is not None:
+     effort_val = str(completion_params["reasoning_effort"])  # flat shape
elif (
-     isinstance(config.completion_params.get("extra_body"), dict)
-     and "reasoning_effort" in config.completion_params["extra_body"]
-     and config.completion_params["extra_body"]["reasoning_effort"] is not None
+     isinstance(completion_params.get("extra_body"), dict)
+     and "reasoning_effort" in completion_params["extra_body"]
+     and completion_params["extra_body"]["reasoning_effort"] is not None
):
# Accept if user passed it directly inside extra_body
- effort_val = str(config.completion_params["extra_body"]["reasoning_effort"])  # already in extra_body
+ effort_val = str(completion_params["extra_body"]["reasoning_effort"])  # already in extra_body

if effort_val:
# Always under extra_body so LiteLLM forwards to provider-specific param set
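The two accepted shapes for reasoning effort, with illustrative values; either one yields effort_val == "low", which the code then places under request_params["extra_body"]["reasoning_effort"] so LiteLLM forwards it to the provider:

    flat = {"model": "fireworks_ai/accounts/fireworks/models/llama-v3p1-8b-instruct", "reasoning_effort": "low"}
    nested = {"model": "fireworks_ai/accounts/fireworks/models/llama-v3p1-8b-instruct", "extra_body": {"reasoning_effort": "low"}}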
3 changes: 0 additions & 3 deletions eval_protocol/pytest/evaluation_test.py
@@ -55,7 +55,6 @@
AggregationMethod,
add_cost_metrics,
log_eval_status_and_rows,
- normalize_fireworks_model,
parse_ep_completion_params,
parse_ep_completion_params_overwrite,
parse_ep_max_concurrent_rollouts,
@@ -205,7 +204,6 @@ def evaluation_test(
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
completion_params = parse_ep_completion_params(completion_params)
completion_params = parse_ep_completion_params_overwrite(completion_params)
- completion_params = [normalize_fireworks_model(cp) for cp in completion_params]
original_completion_params = completion_params
passed_threshold = parse_ep_passed_threshold(passed_threshold)
data_loaders = parse_ep_dataloaders(data_loaders)
@@ -366,7 +364,6 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
row.input_metadata.row_id = generate_id(seed=0, index=index)

completion_params = kwargs["completion_params"] if "completion_params" in kwargs else None
- completion_params = normalize_fireworks_model(completion_params)
# Create eval metadata with test function info and current commit hash
eval_metadata = EvalMetadata(
name=test_func.__name__,
19 changes: 0 additions & 19 deletions eval_protocol/pytest/evaluation_test_utils.py
@@ -619,22 +619,3 @@ def build_rollout_processor_config(
server_script_path=None,
kwargs=rollout_processor_kwargs,
)


- def normalize_fireworks_model(completion_params: CompletionParams | None) -> CompletionParams | None:
-     """Fireworks model names like 'accounts/<org>/models/<model>' need the fireworks_ai/
-     prefix when routing through LiteLLM. This function adds the prefix if missing.
-     """
-     if completion_params is None:
-         return None
-
-     model = completion_params.get("model")
-     if (
-         model
-         and isinstance(model, str)
-         and not model.startswith("fireworks_ai/")
-         and re.match(r"^accounts/[^/]+/models/.+", model)
-     ):
-         completion_params = completion_params.copy()
-         completion_params["model"] = f"fireworks_ai/{model}"
-     return completion_params
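The replacement helper imported throughout this PR lives in eval_protocol/pytest/utils, whose body is not part of this diff. A minimal sketch, assuming it keeps the removed function's behavior under a name that makes the LiteLLM routing intent explicit:

    import re
    from typing import Optional

    def normalize_fireworks_model_for_litellm(completion_params: Optional[dict]) -> Optional[dict]:
        """Add the 'fireworks_ai/' prefix that LiteLLM expects for bare
        'accounts/<org>/models/<model>' names; leave all other params untouched."""
        if completion_params is None:
            return None
        model = completion_params.get("model")
        if (
            isinstance(model, str)
            and not model.startswith("fireworks_ai/")
            and re.match(r"^accounts/[^/]+/models/.+", model)
        ):
            completion_params = completion_params.copy()  # avoid mutating the caller's dict
            completion_params["model"] = f"fireworks_ai/{model}"
        return completion_params

The "or {}" and "or config.completion_params" fallbacks at the call sites exist because the helper, like its predecessor, returns None when given None.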
7 changes: 7 additions & 0 deletions eval_protocol/pytest/github_action_rollout_processor.py
@@ -11,6 +11,7 @@
from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig
from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace
+ from .utils import normalize_fireworks_model_for_litellm


class GithubActionRolloutProcessor(RolloutProcessor):
@@ -80,6 +81,12 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
if row.input_metadata.row_id is None:
raise ValueError("Row ID is required in GithubActionRolloutProcessor")

+ # Normalize Fireworks model names for LiteLLM routing
+ config.completion_params = (
+     normalize_fireworks_model_for_litellm(config.completion_params) or config.completion_params
+ )
+ row.input_metadata.completion_params = config.completion_params

init_request = build_init_request(row, config, self.model_base_url)

def _dispatch_workflow():
16 changes: 10 additions & 6 deletions eval_protocol/pytest/openenv_rollout_processor.py
@@ -24,6 +24,7 @@
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
+ from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm

logger = logging.getLogger(__name__)

@@ -177,15 +178,18 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
logger.debug("[OpenEnvRolloutProcessor] Environment client created successfully")

try:
+ # Normalize Fireworks model names for LiteLLM routing
+ completion_params = normalize_fireworks_model_for_litellm(config.completion_params) or {}
+ row.input_metadata.completion_params = completion_params
  # Get model config
- raw_model = config.completion_params.get("model", "gpt-4o-mini")
+ raw_model = completion_params.get("model", "gpt-4o-mini")
  model = raw_model
- temperature = config.completion_params.get("temperature", 0.0)
- max_tokens = config.completion_params.get("max_tokens", 100)
+ temperature = completion_params.get("temperature", 0.0)
+ max_tokens = completion_params.get("max_tokens", 100)
  # Optional: direct routing or provider overrides (e.g., base_url, api_key, top_p, stop, etc.)
- base_url = config.completion_params.get("base_url")
+ base_url = completion_params.get("base_url")
  # Forward any extra completion params to LiteLLMPolicy (they will be sent per-request)
- extra_params: Dict[str, Any] = dict(config.completion_params or {})
+ extra_params: Dict[str, Any] = dict(completion_params)
for _k in ("model", "temperature", "max_tokens", "base_url"):
try:
extra_params.pop(_k, None)
@@ -247,7 +251,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
messages = list(row.messages) # Copy initial messages
# Inject system prompt if provided and not already present
has_system = any(m.role == "system" for m in messages)
- system_prompt = config.completion_params.get("system_prompt")
+ system_prompt = completion_params.get("system_prompt")
if system_prompt and not has_system:
messages.insert(0, Message(role="system", content=system_prompt))
usage = {
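An illustrative completion_params for this processor; the keys are the ones the code above consumes, the values are assumptions:

    completion_params = {
        "model": "accounts/fireworks/models/llama-v3p1-8b-instruct",  # normalized to fireworks_ai/... first
        "temperature": 0.0,
        "max_tokens": 100,
        "base_url": "https://api.fireworks.ai/inference/v1",  # optional direct routing
        "system_prompt": "You are a concise agent.",  # injected as a system message if none exists
    }
    # model/temperature/max_tokens/base_url are consumed explicitly; everything else
    # is forwarded to LiteLLMPolicy as per-request extra params.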
7 changes: 7 additions & 0 deletions eval_protocol/pytest/remote_rollout_processor.py
@@ -11,6 +11,7 @@
from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig
from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace
+ from .utils import normalize_fireworks_model_for_litellm
import logging

import os
@@ -87,6 +88,12 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
if row.input_metadata.row_id is None:
raise ValueError("Row ID is required in RemoteRolloutProcessor")

+ # Normalize Fireworks model names for LiteLLM routing
+ config.completion_params = (
+     normalize_fireworks_model_for_litellm(config.completion_params) or config.completion_params
+ )
+ row.input_metadata.completion_params = config.completion_params

init_payload = build_init_request(row, config, model_base_url)

# Fire-and-poll