diff --git a/packages/sdk/server-ai/src/ldai/__init__.py b/packages/sdk/server-ai/src/ldai/__init__.py index 0bc0e76..fe16129 100644 --- a/packages/sdk/server-ai/src/ldai/__init__.py +++ b/packages/sdk/server-ai/src/ldai/__init__.py @@ -2,14 +2,16 @@ from ldclient import log +from ldai.agent_graph import AgentGraphDefinition from ldai.chat import Chat from ldai.client import LDAIClient from ldai.judge import Judge from ldai.models import ( # Deprecated aliases for backward compatibility - AIAgentConfig, AIAgentConfigDefault, AIAgentConfigRequest, AIAgents, - AICompletionConfig, AICompletionConfigDefault, AIConfig, AIJudgeConfig, - AIJudgeConfigDefault, JudgeConfiguration, LDAIAgent, LDAIAgentConfig, - LDAIAgentDefaults, LDMessage, ModelConfig, ProviderConfig) + AIAgentConfig, AIAgentConfigDefault, AIAgentConfigRequest, + AIAgentGraphConfig, AIAgents, AICompletionConfig, + AICompletionConfigDefault, AIConfig, AIJudgeConfig, AIJudgeConfigDefault, + Edge, JudgeConfiguration, LDAIAgent, LDAIAgentConfig, LDAIAgentDefaults, + LDMessage, ModelConfig, ProviderConfig) from ldai.providers.types import EvalScore, JudgeResponse __all__ = [ @@ -18,12 +20,15 @@ 'AIAgentConfigDefault', 'AIAgentConfigRequest', 'AIAgents', + 'AIAgentGraphConfig', + 'Edge', 'AICompletionConfig', 'AICompletionConfigDefault', 'AIJudgeConfig', 'AIJudgeConfigDefault', 'Chat', 'EvalScore', + 'AgentGraphDefinition', 'Judge', 'JudgeConfiguration', 'JudgeResponse', diff --git a/packages/sdk/server-ai/src/ldai/agent_graph/__init__.py b/packages/sdk/server-ai/src/ldai/agent_graph/__init__.py new file mode 100644 index 0000000..a10ab23 --- /dev/null +++ b/packages/sdk/server-ai/src/ldai/agent_graph/__init__.py @@ -0,0 +1,273 @@ +"""Graph implementation for managing AI agent graphs.""" + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Set + +from ldclient import Context + +from ldai.models import AIAgentConfig, AIAgentGraphConfig, Edge + +DEFAULT_FALSE = 
AIAgentConfig(key="", enabled=False) + + +class AgentGraphNode: + """ + Node in an agent graph. + """ + + def __init__( + self, + key: str, + config: AIAgentConfig, + children: List[Edge], + ): + self._key = key + self._config = config + self._children = children + + def get_key(self) -> str: + """Get the key of the node.""" + return self._key + + def get_config(self) -> AIAgentConfig: + """Get the config of the node.""" + return self._config + + def is_terminal(self) -> bool: + """Check if the node is a terminal node.""" + return len(self._children) == 0 + + def get_edges(self) -> List[Edge]: + """Get the edges of the node.""" + return self._children + + +class AgentGraphDefinition: + """ + Graph implementation for managing AI agent graphs. + """ + enabled: bool + + def __init__( + self, + agent_graph: AIAgentGraphConfig, + nodes: Dict[str, AgentGraphNode], + context: Context, + enabled: bool, + ): + self._agent_graph = agent_graph + self._context = context + self._nodes = nodes + self.enabled = enabled + + def is_enabled(self) -> bool: + """Check if the graph is enabled.""" + return self.enabled + + @staticmethod + def build_nodes( + agent_graph: AIAgentGraphConfig, + graph_nodes: Dict[str, AIAgentConfig], + ) -> Dict[str, "AgentGraphNode"]: + """Build the nodes of the graph into AgentGraphNode objects.""" + nodes = { + agent_graph.root_config_key: AgentGraphNode( + agent_graph.root_config_key, + graph_nodes[agent_graph.root_config_key], + [ + edge + for edge in agent_graph.edges + if edge.source_config == agent_graph.root_config_key + ], + ), + } + + for edge in agent_graph.edges: + nodes[edge.target_config] = AgentGraphNode( + edge.target_config, + graph_nodes[edge.target_config], + [e for e in agent_graph.edges if e.source_config == edge.target_config], + ) + + return nodes + + def get_node(self, key: str) -> Optional[AgentGraphNode]: + """Get a node by its key.""" + return self._nodes.get(key) + + def _get_child_edges(self, config_key: str) -> List[Edge]: + 
"""Get the child edges of the given config.""" + return [ + edge for edge in self._agent_graph.edges if edge.source_config == config_key + ] + + def get_child_nodes(self, node_key: str) -> List[AgentGraphNode]: + """Get the child nodes of the given node key as AgentGraphNode objects.""" + nodes: List[AgentGraphNode] = [] + for edge in self._agent_graph.edges: + if edge.source_config == node_key: + node = self.get_node(edge.target_config) + if node is not None: + nodes.append(node) + return nodes + + def get_parent_nodes(self, node_key: str) -> List[AgentGraphNode]: + """Get the parent nodes of the given node key as AgentGraphNode objects.""" + nodes: List[AgentGraphNode] = [] + for edge in self._agent_graph.edges: + if edge.target_config == node_key: + node = self.get_node(edge.source_config) + if node is not None: + nodes.append(node) + return nodes + + def _collect_nodes( + self, + node: AgentGraphNode, + node_depths: Dict[str, int], + nodes_by_depth: Dict[int, List[AgentGraphNode]], + visited: Set[str], + max_depth: int, + ) -> None: + """Collect all reachable nodes from the given node and group them by depth.""" + node_key = node.get_key() + if node_key in visited: + return + visited.add(node_key) + + # Use max_depth for nodes not in node_depths to ensure they execute last + node_depth = node_depths.get(node_key, max_depth) + if node_depth not in nodes_by_depth: + nodes_by_depth[node_depth] = [] + nodes_by_depth[node_depth].append(node) + + for child in self.get_child_nodes(node_key): + self._collect_nodes(child, node_depths, nodes_by_depth, visited, max_depth) + + def terminal_nodes(self) -> List[AgentGraphNode]: + """Get the terminal nodes of the graph, meaning any nodes without children.""" + return [ + node + for node in self._nodes.values() + if len(self.get_child_nodes(node.get_key())) == 0 + ] + + def root(self) -> Optional[AgentGraphNode]: + """Get the root node of the graph.""" + return self._nodes.get(self._agent_graph.root_config_key) + + def 
traverse( + self, + fn: Callable[["AgentGraphNode", Dict[str, Any]], Any], + execution_context: Optional[Dict[str, Any]] = None, + ) -> Any: + """Traverse from the root down to terminal nodes, visiting nodes in order of depth. + Nodes with the longest paths from the root (deepest nodes) will always be visited last.""" + if execution_context is None: + execution_context = {} + + root_node = self.root() + if root_node is None: + return + + node_depths: Dict[str, int] = {root_node.get_key(): 0} + current_level: List[AgentGraphNode] = [root_node] + depth = 0 + max_depth_limit = 10 # Infinite loop protection limit + max_depth_encountered = 0 + seen_nodes: Set[str] = {root_node.get_key()} + + while current_level: + next_level: List[AgentGraphNode] = [] + depth += 1 + + for node in current_level: + node_key = node.get_key() + for child in self.get_child_nodes(node_key): + child_key = child.get_key() + if depth <= max_depth_limit: + # Defer this child to the next level if it's at a longer path + if child_key not in node_depths or depth > node_depths[child_key]: + node_depths[child_key] = depth + max_depth_encountered = max(max_depth_encountered, depth) + # Add to next level if not already visited (prevents cycles) + if child_key not in seen_nodes: + seen_nodes.add(child_key) + next_level.append(child) + else: + max_depth_encountered = max(max_depth_encountered, depth) + if child_key not in seen_nodes: + # Push this to the next level to be visited + seen_nodes.add(child_key) + next_level.append(child) + + current_level = next_level + + # Use max_depth_limit + 1 to ensure they execute after all recorded nodes + max_depth = max(max_depth_limit + 1, max_depth_encountered + 1) + + # Group all nodes by depth + nodes_by_depth: Dict[int, List[AgentGraphNode]] = {} + # New visited for children nodes + visited: Set[str] = set() + + self._collect_nodes(root_node, node_depths, nodes_by_depth, visited, max_depth) + # Execute the lambda at this level for the nodes at this depth + for 
depth_level in sorted(nodes_by_depth.keys()): + for node in nodes_by_depth[depth_level]: + execution_context[node.get_key()] = fn(node, execution_context) + + return execution_context[self._agent_graph.root_config_key] + + def reverse_traverse( + self, + fn: Callable[["AgentGraphNode", Dict[str, Any]], Any], + execution_context: Optional[Dict[str, Any]] = None, + ) -> Any: + """Traverse from terminal nodes up to the root, visiting nodes level by level. + The root node will always be visited last, even if multiple paths converge at it.""" + if execution_context is None: + execution_context = {} + + terminal_nodes = self.terminal_nodes() + if not terminal_nodes: + return + + visited: Set[str] = set() + current_level: List[AgentGraphNode] = terminal_nodes + root_key = self._agent_graph.root_config_key + root_node_seen = False + + while current_level: + next_level: List[AgentGraphNode] = [] + + for node in current_level: + node_key = node.get_key() + if node_key in visited: + continue + + visited.add(node_key) + # Skip the root node if we reach a terminus, it will be visited last + if node_key == root_key: + root_node_seen = True + continue + + execution_context[node_key] = fn(node, execution_context) + + for parent in self.get_parent_nodes(node_key): + parent_key = parent.get_key() + if parent_key not in visited: + next_level.append(parent) + + current_level = next_level + + # If we saw the root node, append it at the end as it'll always be the last node in a + # reverse traversal (this should always happen, non-contiguous graphs are invalid) + if root_node_seen: + root_node = self.root() + if root_node is not None: + execution_context[root_node.get_key()] = fn( + root_node, execution_context + ) + + return execution_context[self._agent_graph.root_config_key] diff --git a/packages/sdk/server-ai/src/ldai/client.py b/packages/sdk/server-ai/src/ldai/client.py index 47465ef..a139901 100644 --- a/packages/sdk/server-ai/src/ldai/client.py +++ 
b/packages/sdk/server-ai/src/ldai/client.py @@ -5,13 +5,15 @@ from ldclient.client import LDClient from ldai import log +from ldai.agent_graph import AgentGraphDefinition from ldai.chat import Chat from ldai.judge import Judge from ldai.models import (AIAgentConfig, AIAgentConfigDefault, - AIAgentConfigRequest, AIAgents, AICompletionConfig, - AICompletionConfigDefault, AIJudgeConfig, - AIJudgeConfigDefault, JudgeConfiguration, LDMessage, - ModelConfig, ProviderConfig) + AIAgentConfigRequest, AIAgentGraphConfig, AIAgents, + AICompletionConfig, AICompletionConfigDefault, + AIJudgeConfig, AIJudgeConfigDefault, Edge, + JudgeConfiguration, LDMessage, ModelConfig, + ProviderConfig) from ldai.providers.ai_provider_factory import AIProviderFactory from ldai.tracker import LDAIConfigTracker @@ -419,6 +421,102 @@ def agent_configs( return result + def agent_graph( + self, + key: str, + context: Context, + ) -> AgentGraphDefinition: + """ + Retrieve an AI agent graph. + """ + variation = self._client.variation(key, context, {}) + + if not variation.get("rootConfigKey"): + log.debug(f"Agent graph {key} is disabled, no root config key found") + return AgentGraphDefinition( + AIAgentGraphConfig( + key=key, + name="", + root_config_key="", + edges=[], + description="", + enabled=False, + ), + nodes={}, + context=context, + enabled=False, + ) + + all_agent_keys = [variation["rootConfigKey"]] + [ + edge.get("targetConfig", "") for edge in variation.get("edges", []) if edge.get("targetConfig") + ] + agent_configs = { + agent_key: self.agent_config(agent_key, context, AIAgentConfigDefault(enabled=False)) + for agent_key in all_agent_keys + } + + if not all(config.enabled for config in agent_configs.values()): + log.debug( + f"Agent graph {key} is disabled, not all agent configs are enabled" + ) + return AgentGraphDefinition( + AIAgentGraphConfig( + key=key, + name="", + root_config_key="", + edges=[], + description="", + enabled=False, + ), + nodes={}, + context=context, + enabled=False, + ) + + try: 
+ agent_graph_config = AIAgentGraphConfig( + key=variation["key"], + name=variation["name"], + root_config_key=variation["rootConfigKey"], + edges=[ + Edge( + key=edge.get("key", ""), + source_config=edge.get("sourceConfig", ""), + target_config=edge.get("targetConfig", ""), + handoff=edge.get("handoff", {}), + ) + for edge in variation["edges"] + ], + description=variation["description"], + ) + except Exception as e: + log.debug(f"Agent graph {key} is disabled, invalid agent graph config: {e}") + return AgentGraphDefinition( + AIAgentGraphConfig( + key=key, + name="", + root_config_key="", + edges=[], + description="", + enabled=False, + ), + nodes={}, + context=context, + enabled=False, + ) + + nodes = AgentGraphDefinition.build_nodes( + agent_graph_config, + agent_configs, + ) + + return AgentGraphDefinition( + agent_graph=agent_graph_config, + nodes=nodes, + context=context, + enabled=agent_graph_config.enabled, + ) + def agents( self, agent_configs: List[AIAgentConfigRequest], diff --git a/packages/sdk/server-ai/src/ldai/models.py b/packages/sdk/server-ai/src/ldai/models.py index 988d97d..6c058a7 100644 --- a/packages/sdk/server-ai/src/ldai/models.py +++ b/packages/sdk/server-ai/src/ldai/models.py @@ -338,6 +338,39 @@ class AIAgentConfigRequest: # Type alias for all AI Config variants AIConfigKind = Union[AIAgentConfig, AICompletionConfig, AIJudgeConfig] +# ============================================================================ +# AI Config Agent Graph Edge Type +# ============================================================================ + + +@dataclass +class Edge: + """ + Edge configuration for an agent graph. 
+ """ + + key: str + source_config: str + target_config: str + handoff: Optional[dict] = field(default_factory=dict) + + +# ============================================================================ +# AI Config Agent Graph +# ============================================================================ +@dataclass +class AIAgentGraphConfig: + """ + Agent graph configuration. + """ + + key: str + name: str + root_config_key: str + edges: List[Edge] + description: Optional[str] = "" + enabled: bool = True + # ============================================================================ # Deprecated Type Aliases for Backward Compatibility diff --git a/packages/sdk/server-ai/tests/test_agent_graph.py b/packages/sdk/server-ai/tests/test_agent_graph.py new file mode 100644 index 0000000..de584de --- /dev/null +++ b/packages/sdk/server-ai/tests/test_agent_graph.py @@ -0,0 +1,439 @@ +import pytest +from ldclient import Config, Context, LDClient +from ldclient.integrations.test_data import TestData + +from ldai import ( + LDAIClient, + AIAgentGraphConfig, + AgentGraphDefinition, + AIAgentConfig, + Edge, +) + + +@pytest.fixture +def td() -> TestData: + td = TestData.data_source() + # Agent graph with depth of 1 + td.update( + td.flag("test-agent-graph") + .variations( + { + "key": "test-agent-graph", + "name": "Test Agent Graph", + "rootConfigKey": "customer-support-agent", + "edges": [ + { + "key": "edge-customer-support-agent-personalized-agent", + "sourceConfig": "customer-support-agent", + "targetConfig": "personalized-agent", + }, + { + "key": "edge-customer-support-agent-multi-context-agent", + "sourceConfig": "customer-support-agent", + "targetConfig": "multi-context-agent", + }, + { + "key": "edge-customer-support-agent-minimal-agent", + "sourceConfig": "customer-support-agent", + "targetConfig": "minimal-agent", + }, + ], + "description": "Test agent graph", + "_ldMeta": { + "enabled": True, + "variationKey": "test-agent-graph", + "version": 1, + }, + } + ) + 
.variation_for_all(0) + ) + # Agent graph with depth of 3 + td.update( + td.flag("test-agent-graph-depth-3") + .variations( + { + "key": "test-agent-graph-depth-3", + "name": "Test Agent Graph with Depth of 3", + "rootConfigKey": "customer-support-agent", + "edges": [ + { + "key": "edge-customer-support-agent-personalized-agent", + "sourceConfig": "customer-support-agent", + "targetConfig": "personalized-agent", + "handoff": { "state": "from-root-to-personalized" } + }, + { + "key": "edge-personalized-agent-multi-context-agent", + "sourceConfig": "personalized-agent", + "targetConfig": "multi-context-agent", + }, + { + "key": "edge-multi-context-agent-minimal-agent", + "sourceConfig": "multi-context-agent", + "targetConfig": "minimal-agent", + "handoff": {"state": "from-multi-context-to-minimal"}, + }, + { + "key": "edge-customer-support-agent-minimal-agent", + "sourceConfig": "customer-support-agent", + "targetConfig": "minimal-agent", + "handoff": { "state": "from-root-to-minimal" } + }, + ], + "description": "Test agent graph with depth of 3", + "_ldMeta": { + "enabled": True, + "variationKey": "test-agent-graph-depth-3", + "version": 1, + }, + } + ) + .variation_for_all(0) + ) + + # Agent graph with disabled agent included - invalid + td.update( + td.flag("test-agent-graph-disabled-agent") + .variations( + { + "key": "test-agent-graph-disabled-agent", + "name": "Test Agent Graph with Disabled Agent", + "rootConfigKey": "customer-support-agent", + "edges": [ + { + "key": "edge-customer-support-agent-personalized-agent", + "sourceConfig": "customer-support-agent", + "targetConfig": "disabled-agent", + }, + ], + "description": "Test agent graph with disabled agent", + "_ldMeta": { + "enabled": True, + "variationKey": "test-agent-graph-disabled-agent", + "version": 1, + }, + } + ) + .variation_for_all(0) + ) + + # Agent graph with no root key - invalid + td.update( + td.flag("test-agent-graph-no-root-key") + .variations( + { + "name": "Test Agent Graph with No Root 
Key", + "key": "test-agent-graph-no-root-key", + "edges": [], + } + ) + .variation_for_all(0) + ) + + # Single agent with instructions + td.update( + td.flag("customer-support-agent") + .variations( + { + "model": { + "name": "gpt-4", + "parameters": {"temperature": 0.3, "maxTokens": 2048}, + }, + "provider": {"name": "openai"}, + "instructions": "You are a helpful customer support agent for {{company_name}}. Always be polite and professional.", + "_ldMeta": { + "enabled": True, + "variationKey": "agent-v1", + "version": 1, + "mode": "agent", + }, + } + ) + .variation_for_all(0) + ) + + # Agent with context interpolation + td.update( + td.flag("personalized-agent") + .variations( + { + "model": {"name": "claude-3", "parameters": {"temperature": 0.5}}, + "instructions": "Hello {{ldctx.name}}! I am your personal assistant. Your user key is {{ldctx.key}}.", + "_ldMeta": { + "enabled": True, + "variationKey": "personal-v1", + "version": 2, + "mode": "agent", + }, + } + ) + .variation_for_all(0) + ) + + # Agent with multi-context interpolation + td.update( + td.flag("multi-context-agent") + .variations( + { + "model": {"name": "gpt-3.5-turbo"}, + "instructions": "Welcome {{ldctx.user.name}} from {{ldctx.org.name}}! 
Your organization tier is {{ldctx.org.tier}}.", + "_ldMeta": { + "enabled": True, + "variationKey": "multi-v1", + "version": 1, + "mode": "agent", + }, + } + ) + .variation_for_all(0) + ) + + # Disabled agent + td.update( + td.flag("disabled-agent") + .variations( + { + "model": {"name": "gpt-4"}, + "instructions": "This agent is disabled.", + "_ldMeta": { + "enabled": False, + "variationKey": "disabled-v1", + "version": 1, + "mode": "agent", + }, + } + ) + .variation_for_all(0) + ) + + # Agent with minimal metadata + td.update( + td.flag("minimal-agent") + .variations( + { + "instructions": "Minimal agent configuration.", + "_ldMeta": {"enabled": True}, + } + ) + .variation_for_all(0) + ) + + return td + + +@pytest.fixture +def client(td: TestData) -> LDClient: + config = Config("sdk-key", update_processor_class=td, send_events=False) + return LDClient(config=config) + + +@pytest.fixture +def ldai_client(client: LDClient) -> LDAIClient: + return LDAIClient(client) + + +def test_agent_graph_method(ldai_client: LDAIClient): + graph = ldai_client.agent_graph("test-agent-graph", Context.create("user-key")) + + assert graph.enabled is True + assert graph is not None + assert graph.root() is not None + assert graph.root().get_key() == "customer-support-agent" + assert len(graph.get_child_nodes("customer-support-agent")) == 3 + assert len(graph.get_child_nodes("personalized-agent")) == 0 + assert len(graph.get_child_nodes("multi-context-agent")) == 0 + assert len(graph.get_child_nodes("minimal-agent")) == 0 + + +def test_agent_graph_method_disabled_agent(ldai_client: LDAIClient): + graph = ldai_client.agent_graph( + "test-agent-graph-disabled-agent", Context.create("user-key") + ) + + assert graph.enabled is False + assert graph.root() is None + + +def test_agent_graph_method_no_root_key(ldai_client: LDAIClient): + graph = ldai_client.agent_graph( + "test-agent-graph-no-root-key", Context.create("user-key") + ) + + assert graph.enabled is False + assert graph.root() is 
None + + +def test_agent_graph_build_nodes(ldai_client: LDAIClient): + graph_config = ldai_client._client.variation( + "test-agent-graph", Context.create("user-key"), {} + ) + + ai_graph_config = AIAgentGraphConfig( + key=graph_config["key"], + name=graph_config["name"], + root_config_key=graph_config["rootConfigKey"], + edges=[ + Edge( + key=edge.get("key", ""), + source_config=edge.get("sourceConfig", ""), + target_config=edge.get("targetConfig", ""), + handoff=edge.get("handoff", {}), + ) + for edge in graph_config["edges"] + ], + description=graph_config["description"], + ) + + nodes = AgentGraphDefinition.build_nodes( + ai_graph_config, + { + "customer-support-agent": AIAgentConfig( + key="customer-support-agent", enabled=True + ), + "personalized-agent": AIAgentConfig(key="personalized-agent", enabled=True), + "multi-context-agent": AIAgentConfig( + key="multi-context-agent", enabled=True + ), + "minimal-agent": AIAgentConfig(key="minimal-agent", enabled=True), + }, + ) + + assert nodes["customer-support-agent"] is not None + assert nodes["personalized-agent"] is not None + assert nodes["multi-context-agent"] is not None + assert nodes["minimal-agent"] is not None + + assert len(nodes["customer-support-agent"].get_edges()) == 3 + assert len(nodes["personalized-agent"].get_edges()) == 0 + assert len(nodes["multi-context-agent"].get_edges()) == 0 + assert len(nodes["minimal-agent"].get_edges()) == 0 + + assert type(nodes["customer-support-agent"].get_config()) is AIAgentConfig + assert type(nodes["personalized-agent"].get_config()) is AIAgentConfig + assert type(nodes["multi-context-agent"].get_config()) is AIAgentConfig + assert type(nodes["minimal-agent"].get_config()) is AIAgentConfig + + assert type(nodes["customer-support-agent"].get_edges()[0]) is Edge + + +def test_agent_graph_get_methods(ldai_client: LDAIClient): + graph = ldai_client.agent_graph("test-agent-graph", Context.create("user-key")) + + assert graph.root() is not None + assert 
graph.root().get_key() == "customer-support-agent" + assert graph.get_node("customer-support-agent") is not None + assert graph.get_node("personalized-agent") is not None + assert graph.get_node("multi-context-agent") is not None + + children = graph.get_child_nodes("customer-support-agent") + assert len(children) == 3 + assert children[0].get_key() == "personalized-agent" + assert children[1].get_key() == "multi-context-agent" + assert children[2].get_key() == "minimal-agent" + + parents = graph.get_parent_nodes("personalized-agent") + assert len(parents) == 1 + assert parents[0].get_key() == "customer-support-agent" + + parents = graph.get_parent_nodes("multi-context-agent") + assert len(parents) == 1 + assert parents[0].get_key() == "customer-support-agent" + + terminal = graph.terminal_nodes() + assert len(terminal) == 3 + assert terminal[0].get_key() == "personalized-agent" + assert terminal[1].get_key() == "multi-context-agent" + assert terminal[2].get_key() == "minimal-agent" + + assert graph.root().is_terminal() is False + assert graph.get_node("customer-support-agent").is_terminal() is False + assert graph.get_node("personalized-agent").is_terminal() is True + assert graph.get_node("multi-context-agent").is_terminal() is True + assert graph.get_node("minimal-agent").is_terminal() is True + + +def test_agent_graph_traverse(ldai_client: LDAIClient): + graph = ldai_client.agent_graph( + "test-agent-graph-depth-3", Context.create("user-key") + ) + + context = {} + order = [] + + def handle_traverse(node, context): + # Asserting that returned values are included in the context + for previousKey in order: + assert previousKey in context + assert context[previousKey] == previousKey + "-test" + order.append(node.get_key()) + return node.get_key() + "-test" + + graph.traverse(handle_traverse, context) + # Asserting that we traverse in the expected order + # This config specifically has nodes connecting from depth 2->3 and root->3 to ensure the root node is visited 
first + # and minimal-agent is visited last + assert order == [ + "customer-support-agent", + "personalized-agent", + "multi-context-agent", + "minimal-agent", + ] + + +def test_agent_graph_reverse_traverse(ldai_client: LDAIClient): + graph = ldai_client.agent_graph( + "test-agent-graph-depth-3", Context.create("user-key") + ) + + context = {} + order = [] + + def handle_reverse_traverse(node, context): + # Asserting that returned values are included in the context + for previousKey in order: + assert previousKey in context + assert context[previousKey] == previousKey + "-test" + order.append(node.get_key()) + return node.get_key() + "-test" + + graph.reverse_traverse(handle_reverse_traverse, context) + # Asserting that we traverse in the expected order + # This config specifically has nodes connecting from depth 2->3 and root->3 to ensure the root node is visited last + assert order == [ + "minimal-agent", + "multi-context-agent", + "personalized-agent", + "customer-support-agent", + ] + + +def test_agent_graph_handoff(ldai_client: LDAIClient): + graph = ldai_client.agent_graph( + "test-agent-graph-depth-3", Context.create("user-key") + ) + + context = {} + + def handle_traverse(node, context): + if node.get_key() == "multi-context-agent": + first_edge = node.get_edges()[0] + assert first_edge.handoff == {"state": "from-multi-context-to-minimal"} + assert first_edge.source_config == "multi-context-agent" + assert first_edge.target_config == "minimal-agent" + assert first_edge.key == "edge-multi-context-agent-minimal-agent" + if node.get_key() == "customer-support-agent": + first_edge = node.get_edges()[0] + second_edge = node.get_edges()[1] + assert first_edge.handoff == {"state": "from-root-to-personalized"} + assert second_edge.handoff == {"state": "from-root-to-minimal"} + assert first_edge.source_config == "customer-support-agent" + assert first_edge.target_config == "personalized-agent" + assert first_edge.key == 
"edge-customer-support-agent-personalized-agent" + assert second_edge.source_config == "customer-support-agent" + assert second_edge.target_config == "minimal-agent" + assert second_edge.key == "edge-customer-support-agent-minimal-agent" + return None + + graph.traverse(handle_traverse, context)