Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 74 additions & 37 deletions AIDojoCoordinator/coordinator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import jsonlines
import logging
import json
import asyncio
Expand All @@ -7,7 +6,7 @@

from AIDojoCoordinator.game_components import Action, Observation, ActionType, GameStatus, GameState, AgentStatus, ProtocolConfig
from AIDojoCoordinator.global_defender import GlobalDefender
from AIDojoCoordinator.utils.utils import observation_as_dict, get_str_hash, ConfigParser
from AIDojoCoordinator.utils.utils import observation_as_dict, get_str_hash, ConfigParser, store_trajectories_to_jsonl
import os
from aiohttp import ClientSession
from cyst.api.environment.environment import Environment
Expand Down Expand Up @@ -136,8 +135,44 @@ async def __call__(self, reader, writer):
class GameCoordinator:
"""
Class for creation, and management of agent interactions in AI Dojo.

Attributes:
host (str): Host address for the game server.
port (int): Port number for the game server.
logger (logging.Logger): Logger for the GameCoordinator.
_tasks (set): Set of active asyncio tasks.
shutdown_flag (asyncio.Event): Event to signal shutdown.
_reset_event (asyncio.Event): Event to signal game reset.
_episode_end_event (asyncio.Event): Event to signal episode end.
_episode_start_event (asyncio.Event): Event to signal episode start.
_episode_rewards_condition (asyncio.Condition): Condition for episode rewards assignment.
_reset_done_condition (asyncio.Condition): Condition for reset completion.
_reset_lock (asyncio.Lock): Lock for reset operations.
_agents_lock (asyncio.Lock): Lock for agent operations.
_service_host (str): Host for remote configuration service.
_service_port (int): Port for remote configuration service.
_task_config_file (str): Path to local task configuration file.
ALLOWED_ROLES (list): List of allowed agent roles.
_cyst_objects: CYST simulator initialization objects.
_cyst_object_string: String representation of CYST objects.
_agent_action_queue (asyncio.Queue): Queue for agent actions.
_agent_response_queues (dict): Mapping of agent addresses to their response queues.
agents (dict): Mapping of agent addresses to their information.
_agent_steps (dict): Step counters per agent address.
_reset_requests (dict): Reset requests per agent address.
_randomize_topology_requests (dict): Topology randomization requests per agent address.
_agent_status (dict): Status of each agent.
_episode_ends (dict): Episode end flags per agent address.
_agent_observations (dict): Observations per agent address.
_agent_starting_position (dict): Starting positions per agent address.
_agent_states (dict): Current states per agent address.
_agent_goal_states (dict): Goal states per agent address.
_agent_last_action (dict): Last actions played by agents.
_agent_false_positives (dict): False positives per agent.
_agent_rewards (dict): Rewards per agent address.
_agent_trajectories (dict): Trajectories per agent address.
"""
def __init__(self, game_host: str, game_port: int, service_host:str, service_port:int, allowed_roles=["Attacker", "Defender", "Benign"], task_config_file:str=None) -> None:
def __init__(self, game_host: str, game_port: int, service_host:str, service_port:int, task_config_file:str,allowed_roles=["Attacker", "Defender", "Benign"]) -> None:
self.host = game_host
self.port = game_port
self.logger = logging.getLogger("AIDojo-GameCoordinator")
Expand Down Expand Up @@ -190,8 +225,6 @@ def __init__(self, game_host: str, game_port: int, service_host:str, service_por
self._agent_rewards = {}
# trajectories per agent_addr
self._agent_trajectories = {}
# false_positives per agent_addr
self._agent_false_positives = {}

def _spawn_task(self, coroutine, *args, **kwargs)->asyncio.Task:
"Helper function to make sure all tasks are registered for proper termination"
Expand Down Expand Up @@ -317,6 +350,7 @@ async def start_tcp_server(self):
"""
Starts the TCP server for agent communication.
"""
server = None
try:
self.logger.info("Starting the server listening for agents")
server = await asyncio.start_server(
Expand All @@ -337,8 +371,9 @@ async def start_tcp_server(self):
except Exception as e:
self.logger.error(f"TCP server failed: {e}")
finally:
server.close()
await server.wait_closed()
if server:
server.close()
await server.wait_closed()
self.logger.info("\tTCP server task stopped")

async def start_tasks(self):
Expand Down Expand Up @@ -421,32 +456,33 @@ async def run_game(self):
agent_addr, message = await self._agent_action_queue.get()
if message is not None:
self.logger.info(f"Coordinator received from agent {agent_addr}: {message}.")

try: # Convert message to Action
action = Action.from_json(message)
self.logger.debug(f"\tConverted to: {action}.")
match action.type: # process action based on its type
case ActionType.JoinGame:
self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.JoinGame by {agent_addr}")
self.logger.debug(f"{action.type}, {action.type.value}, {action.type == ActionType.JoinGame}")
self._spawn_task(self._process_join_game_action, agent_addr, action)
case ActionType.QuitGame:
self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.QuitGame by {agent_addr}")
self._spawn_task(self._process_quit_game_action, agent_addr)
case ActionType.ResetGame:
self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.ResetGame by {agent_addr}")
self._spawn_task(self._process_reset_game_action, agent_addr, action)
case ActionType.ExfiltrateData | ActionType.FindData | ActionType.ScanNetwork | ActionType.FindServices | ActionType.ExploitService:
self.logger.debug(f"About agent {agent_addr}. Start processing of {action.type} by {agent_addr}")
self._spawn_task(self._process_game_action, agent_addr, action)
case ActionType.BlockIP:
self.logger.debug(f"About agent {agent_addr}. Start processing of {action.type} by {agent_addr}")
self._spawn_task(self._process_game_action, agent_addr, action)
case _:
self.logger.warning(f"About agent {agent_addr}. Unsupported action type: {action}!")
except Exception as e:
self.logger.error(
f"Error when converting msg to Action using Action.from_json():{e}, {message}"
)
match action.type: # process action based on its type
case ActionType.JoinGame:
self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.JoinGame by {agent_addr}")
self.logger.debug(f"{action.type}, {action.type.value}, {action.type == ActionType.JoinGame}")
self._spawn_task(self._process_join_game_action, agent_addr, action)
case ActionType.QuitGame:
self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.QuitGame by {agent_addr}")
self._spawn_task(self._process_quit_game_action, agent_addr)
case ActionType.ResetGame:
self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.ResetGame by {agent_addr}")
self._spawn_task(self._process_reset_game_action, agent_addr, action)
case ActionType.ExfiltrateData | ActionType.FindData | ActionType.ScanNetwork | ActionType.FindServices | ActionType.ExploitService:
self.logger.debug(f"About agent {agent_addr}. Start processing of {action.type} by {agent_addr}")
self._spawn_task(self._process_game_action, agent_addr, action)
case ActionType.BlockIP:
self.logger.debug(f"About agent {agent_addr}. Start processing of {action.type} by {agent_addr}")
self._spawn_task(self._process_game_action, agent_addr, action)
case _:
self.logger.warning(f"About agent {agent_addr}. Unsupported action type: {action}!")
self.logger.info("\tAction processing task stopped.")

async def _process_join_game_action(self, agent_addr: tuple, action: Action)->None:
Expand Down Expand Up @@ -945,7 +981,7 @@ def _reset_trajectory(self, agent_addr:tuple)->dict:
"agent_name":agent_name
}

def _add_step_to_trajectory(self, agent_addr:tuple, action:Action, reward:float, next_state:GameState, end_reason:str=None)-> None:
def _add_step_to_trajectory(self, agent_addr:tuple, action:Action, reward:float, next_state:GameState, end_reason:str|None=None)-> None:
"""
Method for adding one step to the agent trajectory.
"""
Expand All @@ -958,16 +994,17 @@ def _add_step_to_trajectory(self, agent_addr:tuple, action:Action, reward:float,
self._agent_trajectories[agent_addr]["end_reason"] = end_reason

def _store_trajectory_to_file(self, agent_addr:tuple, location="./logs/trajectories")-> None:
if not os.path.exists(location):
os.makedirs(location)
self.logger.debug(f"Created directory for storing trajectories: {location}")
self.logger.debug(f"Storing Trajectory of {agent_addr}in file")
if agent_addr in self._agent_trajectories:
agent_name, agent_role = self.agents[agent_addr]
filename = os.path.join(location, f"{datetime.now():%Y-%m-%d}_{agent_name}_{agent_role}.jsonl")
with jsonlines.open(filename, "a") as writer:
writer.write(self._agent_trajectories[agent_addr])
self.logger.info(f"Trajectory of {agent_addr} strored in {filename}")
"""
Method for storing the agent trajectory to a file.
"""
if agent_addr in self.agents:
agent_name, agent_role = self.agents[agent_addr]
filename =f"{datetime.now():%Y-%m-%d}_{agent_name}_{agent_role}"
trajectories = self._agent_trajectories[agent_addr]
store_trajectories_to_jsonl(trajectories, location, filename)
self.logger.info(f"Trajectories of {agent_addr} strored in {os.path.join(location, filename)}.jsonl")
else:
self.logger.warning(f"Agent {agent_addr} not found in agents list, can't store trajectory to file.")

def is_agent_benign(self, agent_addr:tuple)->bool:
"""
Expand Down
18 changes: 18 additions & 0 deletions AIDojoCoordinator/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import netaddr
import logging
import csv
import os
import jsonlines
from random import randint
import json
import hashlib
Expand Down Expand Up @@ -565,6 +567,22 @@ def get_starting_position_from_cyst_config(cyst_objects):
starting_positions[f"{obj.id}.{active_service.name}"] = {"known_hosts":hosts, "known_networks":networks}
return starting_positions

def store_trajectories_to_jsonl(trajectories:list, dir:str, filename:str)->None:
    """
    Append trajectories as a single record to a JSONL file.

    Args:
        trajectories (list): Trajectory data to store; written as one JSON line.
        dir (str): Directory where the file will be stored (created if missing).
        filename (str): Name of the file, with or without the ".jsonl" extension.
    """
    # exist_ok=True avoids the race between checking for the directory and creating it
    os.makedirs(dir, exist_ok=True)
    # Remove only a literal ".jsonl" suffix. str.rstrip("jsonl") would strip ANY
    # trailing run of the characters j/s/o/n/l and mangle names such as
    # "..._Benign" (-> "..._Benig") or "foo.jsonl" (-> "foo.").
    base_name = filename.removesuffix(".jsonl")
    full_path = os.path.join(dir, f"{base_name}.jsonl")
    # append mode so that repeated calls accumulate records in the same file
    with jsonlines.open(full_path, "a") as writer:
        writer.write(trajectories)

if __name__ == "__main__":
state = GameState(known_networks={Network("1.1.1.1", 24),Network("1.1.1.2", 24)},
Expand Down