diff --git a/AIDojoCoordinator/coordinator.py b/AIDojoCoordinator/coordinator.py index 92d07034..9f0a2a8e 100644 --- a/AIDojoCoordinator/coordinator.py +++ b/AIDojoCoordinator/coordinator.py @@ -180,6 +180,8 @@ def __init__(self, game_host: str, game_port: int, service_host:str, service_por self._agent_starting_position = {} # current state per agent_addr (GameState) self._agent_states = {} + # goal state per agent_addr (GameState) + self._agent_goal_states = {} # last action played by agent (Action) self._agent_last_action = {} # False positives per agent (due to added blocks) @@ -462,11 +464,11 @@ async def _process_join_game_action(self, agent_addr: tuple, action: Action)->No agent_role = action.parameters["agent_info"].role if agent_role in self.ALLOWED_ROLES: # add agent to the world - new_agent_game_state = await self.register_agent(agent_addr, agent_role, self._starting_positions_per_role[agent_role]) + new_agent_game_state, new_agent_goal_state = await self.register_agent(agent_addr, agent_role, self._starting_positions_per_role[agent_role], self._win_conditions_per_role[agent_role]) if new_agent_game_state: # successful registration async with self._agents_lock: self.agents[agent_addr] = (agent_name, agent_role) - observation = self._initialize_new_player(agent_addr, new_agent_game_state) + observation = self._initialize_new_player(agent_addr, new_agent_game_state, new_agent_goal_state) self._agent_observations[agent_addr] = observation #if len(self.agents) == self._min_required_players: if sum(1 for v in self._agent_status.values() if v == AgentStatus.PlayingWithTimeout) >= self._min_required_players: @@ -720,10 +722,13 @@ async def _reset_game(self): async with self._agents_lock: self._store_trajectory_to_file(agent) self.logger.debug(f"Resetting agent {agent}") - new_state = await self.reset_agent(agent, self.agents[agent][1], self._agent_starting_position[agent]) + agent_role = self.agents[agent][1] + # reset the agent in the world + new_state, new_goal_state = await self.reset_agent(agent, agent_role, self._starting_positions_per_role[agent_role], self._win_conditions_per_role[agent_role]) new_observation = Observation(new_state, 0, False, {}) async with self._agents_lock: self._agent_states[agent] = new_state + self._agent_goal_states[agent] = new_goal_state self._agent_observations[agent] = new_observation self._episode_ends[agent] = False self._reset_requests[agent] = False @@ -741,7 +746,7 @@ async def _reset_game(self): self._reset_done_condition.notify_all() self.logger.info("\tReset game task stopped.") - def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState) -> Observation: + def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState, agent_current_goal_state:GameState) -> Observation: """ Method to initialize new player upon joining the game. Returns initial observation for the agent based on the agent's role @@ -753,6 +758,8 @@ def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState self._episode_ends[agent_addr] = False self._agent_starting_position[agent_addr] = self._starting_positions_per_role[agent_role] self._agent_states[agent_addr] = agent_current_state + self._agent_goal_states[agent_addr] = agent_current_goal_state + self._agent_last_action[agent_addr] = None self._agent_rewards[agent_addr] = 0 self._agent_false_positives[agent_addr] = 0 if agent_role.lower() == "attacker": @@ -764,7 +771,7 @@ def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState # create initial observation return Observation(self._agent_states[agent_addr], 0, False, {}) - async def register_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict)->GameState: + async def register_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict, agent_win_condition_view:dict)->tuple[GameState, GameState]: """ Domain specific method of the environment. Creates the initial state of the agent. """ @@ -775,8 +782,8 @@ async def remove_agent(self, agent_id:tuple, agent_state:GameState)->bool: Domain specific method of the environment. Creates the initial state of the agent. """ raise NotImplementedError - - async def reset_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict)->GameState: + + async def reset_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict, agent_win_condition_view:dict)->tuple[GameState, GameState]: raise NotImplementedError async def _remove_agent_from_game(self, agent_addr): @@ -788,6 +795,7 @@ async def _remove_agent_from_game(self, agent_addr): async with self._agents_lock: if agent_addr in self.agents: agent_info["state"] = self._agent_states.pop(agent_addr) + agent_info["goal_state"] = self._agent_goal_states.pop(agent_addr) agent_info["num_steps"] = self._agent_steps.pop(agent_addr) agent_info["agent_status"] = self._agent_status.pop(agent_addr) agent_info["false_positives"] = self._agent_false_positives.pop(agent_addr) @@ -816,7 +824,7 @@ async def _remove_agent_from_game(self, agent_addr): async def step(self, agent_id:tuple, agent_state:GameState, action:Action): raise NotImplementedError - async def reset(self): + async def reset(self)->bool: return NotImplemented def _initialize(self): @@ -846,16 +854,17 @@ def goal_dict_satistfied(goal_dict:dict, known_dict: dict)-> bool: return False return False self.logger.debug(f"Checking goal for agent {agent_addr}.") - goal_conditions = self._win_conditions_per_role[self.agents[agent_addr][1]] state = self._agent_states[agent_addr] # For each part of the state of the game, check if the conditions are met + target_goal_state = self._agent_goal_states[agent_addr] + self.logger.debug(f"\tGoal conditions: {target_goal_state}.") goal_reached = {} - goal_reached["networks"] = set(goal_conditions["known_networks"]) <= set(state.known_networks) - goal_reached["known_hosts"] = set(goal_conditions["known_hosts"]) <= set(state.known_hosts) - goal_reached["controlled_hosts"] = set(goal_conditions["controlled_hosts"]) <= set(state.controlled_hosts) - goal_reached["services"] = goal_dict_satistfied(goal_conditions["known_services"], state.known_services) - goal_reached["data"] = goal_dict_satistfied(goal_conditions["known_data"], state.known_data) - goal_reached["known_blocks"] = goal_dict_satistfied(goal_conditions["known_blocks"], state.known_blocks) + goal_reached["networks"] = target_goal_state.known_networks <= state.known_networks + goal_reached["known_hosts"] = target_goal_state.known_hosts <= state.known_hosts + goal_reached["controlled_hosts"] = target_goal_state.controlled_hosts <= state.controlled_hosts + goal_reached["services"] = goal_dict_satistfied(target_goal_state.known_services, state.known_services) + goal_reached["data"] = goal_dict_satistfied(target_goal_state.known_data, state.known_data) + goal_reached["known_blocks"] = goal_dict_satistfied(target_goal_state.known_blocks, state.known_blocks) self.logger.debug(f"\t{goal_reached}") return all(goal_reached.values()) diff --git a/AIDojoCoordinator/game_components.py b/AIDojoCoordinator/game_components.py index 98881f13..5ec6a505 100755 --- a/AIDojoCoordinator/game_components.py +++ b/AIDojoCoordinator/game_components.py @@ -551,8 +551,8 @@ def as_graph(self)->tuple: graph_nodes = {} node_features = [] controlled = [] + edges = [] try: - edges = [] #add known nets for net in self.known_networks: graph_nodes[net] = len(graph_nodes) @@ -738,6 +738,8 @@ def from_string(cls, string:str)->"GameStatus": return GameStatus.FORBIDDEN case "GameStatus.RESET_DONE": return GameStatus.RESET_DONE + case _: + raise ValueError(f"Invalid GameStatus string: {string}") def __repr__(self) -> str: """ Return the string representation of the GameStatus. diff --git a/AIDojoCoordinator/utils/utils.py b/AIDojoCoordinator/utils/utils.py index a1e5510e..c92b030c 100644 --- a/AIDojoCoordinator/utils/utils.py +++ b/AIDojoCoordinator/utils/utils.py @@ -13,6 +13,7 @@ import json import hashlib from cyst.api.configuration.network.node import NodeConfig +from typing import Optional def get_file_hash(filepath, hash_func='sha256', chunk_size=4096): """ @@ -111,7 +112,7 @@ def observation_as_dict(observation:Observation)->dict: } return observation_dict -def parse_log_content(log_content:str)->list: +def parse_log_content(log_content:str)->Optional[list]: try: logs = [] data = json.loads(log_content) @@ -154,7 +155,7 @@ def read_config_file(self, conf_file_name:str): self.logger.error(f'Error loading the configuration file{e}') pass - def read_env_action_data(self, action_name: str) -> dict: + def read_env_action_data(self, action_name: str) -> float: """ Generic function to read the known data for any agent and goal of position """ @@ -238,7 +239,7 @@ def read_agents_known_services(self, type_agent: str, type_data: str) -> dict: known_services = {} return known_services - def read_agents_known_networks(self, type_agent: str, type_data: str) -> dict: + def read_agents_known_networks(self, type_agent: str, type_data: str) -> set: """ Generic function to read the known networks for any agent and goal of position """ @@ -251,10 +252,10 @@ def read_agents_known_networks(self, type_agent: str, type_data: str) -> dict: host_part, net_part = net.split('/') known_networks.add(Network(host_part, int(net_part))) except (ValueError, TypeError, netaddr.AddrFormatError): - self.logger('Configuration problem with the known networks') + self.logger.error('Configuration problem with the known networks') return known_networks - def read_agents_known_hosts(self, type_agent: str, type_data: str) -> dict: + def read_agents_known_hosts(self, type_agent: str, type_data: str) -> set: """ Generic function to read the known hosts for any agent and goal of position """ @@ -274,7 +275,7 @@ def read_agents_known_hosts(self, type_agent: str, type_data: str) -> dict: self.logger.error(f'Configuration problem with the known hosts: {e}') return known_hosts - def read_agents_controlled_hosts(self, type_agent: str, type_data: str) -> dict: + def read_agents_controlled_hosts(self, type_agent: str, type_data: str) -> set: """ Generic function to read the controlled hosts for any agent and goal of position """ @@ -395,7 +396,7 @@ def get_win_conditions(self, agent_role): case _: raise ValueError(f"Unsupported agent role: {agent_role}") - def get_max_steps(self, role=str)->int: + def get_max_steps(self, role=str)->Optional[int]: """ Get the max steps based on agent's role """ @@ -409,7 +410,7 @@ def get_max_steps(self, role=str)->int: self.logger.warning(f"Unsupported value in 'coordinator.agents.{role}.max_steps': {e}. Setting value to default=None (no step limit)") return max_steps - def get_goal_description(self, agent_role)->dict: + def get_goal_description(self, agent_role)->str: """ Get goal description per role """ @@ -554,7 +555,7 @@ def get_starting_position_from_cyst_config(cyst_objects): if isinstance(obj, NodeConfig): for active_service in obj.active_services: if active_service.type == "netsecenv_agent": - print(f"startig processing {obj.id}.{active_service.name}") + print(f"starting processing {obj.id}.{active_service.name}") hosts = set() networks = set() for interface in obj.interfaces: diff --git a/AIDojoCoordinator/worlds/NSEGameCoordinator.py b/AIDojoCoordinator/worlds/NSEGameCoordinator.py index b845497c..76222313 100644 --- a/AIDojoCoordinator/worlds/NSEGameCoordinator.py +++ b/AIDojoCoordinator/worlds/NSEGameCoordinator.py @@ -10,6 +10,7 @@ from faker import Faker from pathlib import Path from typing import Iterable +from collections import defaultdict from AIDojoCoordinator.game_components import GameState, Action, ActionType, IP, Network, Data, Service from AIDojoCoordinator.coordinator import GameCoordinator @@ -44,7 +45,7 @@ def __init__(self, game_host, game_port, task_config:str, allowed_roles=["Attack self._seed = seed self.logger.info(f'Setting env seed to {seed}') - def _initialize(self) -> None: + def _initialize(self): """ Initializes the NetSecGame environment. @@ -67,87 +68,183 @@ def _initialize(self) -> None: self._data_content_original = copy.deepcopy(self._data_content) self._firewall_original = copy.deepcopy(self._firewall) self.logger.info("Environment initialization finished") - - def _get_controlled_hosts_from_view(self, view_controlled_hosts:Iterable)->set: + + def _get_hosts_from_view(self, view_hosts:Iterable, allowed_hosts=None)->set[IP]: """ Parses view and translates all keywords. Produces set of controlled host (IP) + Args: + view_hosts (Iterable): The view containing host information. + allowed_hosts (list, optional): A list of host to start from if 'random' is specified. Defaults to None. + Returns: + set: A set of controlled hosts. """ - controlled_hosts = set() + hosts = set() + self.logger.debug(f'\tParsing hosts from view: {view_hosts}') # controlled_hosts - for host in view_controlled_hosts: + for host in view_hosts: if isinstance(host, IP): - controlled_hosts.add(self._ip_mapping[host]) - self.logger.debug(f'\tThe attacker has control of host {self._ip_mapping[host]}.') + hosts.add(host) + self.logger.debug(f'\tAdding {host}.') elif host == 'random': # Random start - self.logger.debug('\tAdding random starting position of agent') - self.logger.debug(f'\t\tChoosing from {self.hosts_to_start}') - selected = random.choice(self.hosts_to_start) - controlled_hosts.add(selected) - self.logger.debug(f'\t\tMaking agent start in {selected}') + if allowed_hosts is not None: + self.logger.debug(f'\tChoosing randomly from {allowed_hosts}') + selected = random.choice(allowed_hosts) + else: + self.logger.debug(f'\tChoosing randomly from all available hosts {list(self._ip_to_hostname.keys())}') + selected = random.choice(list(self._ip_to_hostname.keys())) + hosts.add(selected) + self.logger.debug(f'\t\tAdding {selected}.') elif host == "all_local": # all local ips - self.logger.debug('\t\tAdding all local hosts to agent') - controlled_hosts = controlled_hosts.union(self._get_all_local_ips()) + self.logger.debug(f'\tAdding all local hosts') + hosts = hosts.union(self._get_all_local_ips()) else: - self.logger.error(f"Unsupported value encountered in start_position['controlled_hosts']: {host}") - return controlled_hosts + self.logger.error(f"Unsupported value encountered in view_hosts: {host}") + return hosts def _get_services_from_view(self, view_known_services:dict)->dict: """ Parses view and translates all keywords. Produces dict of known services {IP: set(Service)} - + Args: view_known_services (dict): The view containing known services information. Returns: dict: A dictionary mapping IP addresses to sets of known services. """ + # TODO: Add keyword scope parameter (like in _get_data_from_view) known_services = {} for ip, service_list in view_known_services.items(): - if self._ip_mapping[ip] not in known_services: - known_services[self._ip_mapping[ip]] = set() - for s in service_list: - if isinstance(s, Service): - known_services[self._ip_mapping[ip]].add(s) - elif isinstance(s, str): - if s == "random": # randomly select the service - self.logger.info(f"\tSelecting service randomly in {self._ip_mapping[ip]}") + self.logger.debug(f'\tParsing services from {ip}: {service_list}') + known_services[ip] = set() + for service in service_list: + if isinstance(service, Service): + known_services[ip].add(service) + self.logger.debug(f'\tAdding {service}.') + elif isinstance(service, str): + if service == "random": # randomly select the service + self.logger.info(f"\tSelecting service randomly in {ip}") # select candidates that are not explicitly listed - service_candidates = [s for s in self._services[self._ip_to_hostname[ip]] if s not in known_services[self._ip_mapping[ip]]] - # randomly select from candidates - known_services[self._ip_mapping[ip]].add(random.choice(service_candidates)) + service_candidates = [s for s in self._services[self._ip_to_hostname[ip]] if s not in known_services[ip]] + if len(service_candidates) == 0: + self.logger.warning("\t\tNo available services. Skipping") + else: + # randomly select from candidates + selected = random.choice(service_candidates) + self.logger.debug(f"\t\tAdding: {selected}") + known_services[ip].add(selected) + elif service == "all": + self.logger.info(f"\tSelecting all services in {ip}") + known_services[ip].update(self._services[self._ip_to_hostname[ip]]) + else: + self.logger.error(f"Unsupported value encountered in view_known_services: {service}") + # re-map all IPs based on current mapping in self._ip_mapping return known_services - def _get_data_from_view(self, view_known_data:dict)->dict: + def _get_data_from_view(self, view_known_data:dict, keyword_scope:str="host", exclude_types=["log"])->dict: """ Parses view and translates all keywords. Produces dict of known data {IP: set(Data)} Args: view_known_data (dict): The view containing known data information. - + keyword_scope (str, optional): Scope of keywords like 'random' or 'all'. Defaults to "host" (i.e., only data from the specified host are considered). + exclude_types (list, optional): List of data types to exclude when selecting data. Defaults to ["log"]. Returns: dict: A dictionary mapping IP addresses to sets of known data. """ known_data = {} for ip, data_list in view_known_data.items(): - if self._ip_mapping[ip] not in known_data: - known_data[self._ip_mapping[ip]] = set() + self.logger.debug(f'\tParsing data from {ip}: {data_list}') + known_data[ip] = set() for datum in data_list: if isinstance(datum, Data): - known_data[self._ip_mapping[ip]].add(datum) + known_data[ip].add(datum) + self.logger.debug(f'\tAdding {datum}.') elif isinstance(datum, str): - if datum == "random": # randomly select the data - self.logger.info(f"\tSelecting data randomly in {self._ip_mapping[ip]}") - # select candidates that are not explicitly listed - data_candidates = [d for d in self._data[self._ip_to_hostname[ip]] if d not in known_data[self._ip_mapping[ip]]] - if len(data_candidates) > 0: - # randomly select from candidates - known_data[self._ip_mapping[ip]].add(random.choice(data_candidates)) + # select candidates that are not explicitly listed + data_candidates = set() + if keyword_scope == "host": # scope of the keyword is the host only + for d in self._data[self._ip_to_hostname[ip]]: + if d.type not in exclude_types and d not in known_data[ip]: + data_candidates.add(d) + else: + # scope of the keyword is all hosts + for datapoints in self._data.values(): + for d in datapoints: + if d.type not in exclude_types and d not in known_data[ip]: + data_candidates.add(d) + if datum == "random": # randomly select the service + self.logger.info("\tSelecting data randomly") + if len(data_candidates) == 0: + self.logger.warning("\t\tNo available data. Skipping") else: - self.logger.warning("\tNo available data. Skipping") + # randomly select from candidates + selected = random.choice(list(data_candidates)) + self.logger.debug(f"\t\tAdding: {selected}") + known_data[ip].add(selected) + elif datum == "all": + self.logger.info(f"\tSelecting all data in {ip}") + known_data[ip].update(data_candidates) + else: + self.logger.error(f"Unsupported value encountered in view_known_data: {datum}") + else: + self.logger.error(f"Unsupported value encountered in view_known_data: {datum}") + # re-map all IPs based on current mapping in self._ip_mapping return known_data + def _get_networks_from_view(self, view_known_networks:Iterable)->set[Network]: + """ + Parses view and translates all keywords. Produces set of known networks (Network). + Args: + view_known_networks (Iterable): The view containing known networks information. + Returns: + set: A set of known networks. + """ + known_networks = set() + for net in view_known_networks: + if isinstance(net, Network): + known_networks.add(self._network_mapping[net]) + self.logger.debug(f'\tAdding network {self._network_mapping[net]}.') + elif net == 'random': + # Randomly select a network from the available ones + selected = random.choice(list(self._networks.keys())) + known_networks.add(self._network_mapping[selected]) + self.logger.debug(f'\tAdding randomly selected network: {self._network_mapping[selected]}.') + elif net == "all_local": + # all local networks + self.logger.debug('\t\tAdding all local private networks') + for n in self._networks.keys(): + if not n.is_private(): + known_networks.add(self._network_mapping[n]) + else: + self.logger.error(f"Unsupported value encountered in start_position['known_networks']: {net}") + return known_networks + + def _create_goal_state_from_view(self, view:dict, allowed_hosts=None)->GameState: + """ + Builds a GameState from given view (dict). All keywords are replaced by valid options. + Args: + view (dict): The view containing goal state information. + allowed_hosts (set, optional): A set of allowed hosts for random selection. Defaults to None. + Returns: + GameState: The generated goal state. + """ + self.logger.info(f'Generating goal state from view:{view}') + # process known networks + known_networks = self._get_networks_from_view(view_known_networks=view["known_networks"]) + # parse controlled hosts, replacing keywords if present + controlled_hosts = self._get_hosts_from_view(view_hosts=view["controlled_hosts"], allowed_hosts=allowed_hosts) + # parse known hosts + known_hosts = self._get_hosts_from_view(view_hosts=view["known_hosts"]) + # parse known services + known_services = self._get_services_from_view(view["known_services"]) + # parse known data + known_data = self._get_data_from_view(view["known_data"], keyword_scope="global", exclude_types=["logs"]) + goal_state = GameState(controlled_hosts, known_hosts, known_services, known_data, known_networks) + self.logger.info(f"Generated Goal GameState:{goal_state}") + return goal_state + def _create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->GameState: """ Builds a GameState from given view. @@ -157,10 +254,10 @@ def _create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->Ga """ self.logger.info(f'Generating state from view:{view}') # re-map all networks based on current mapping in self._network_mapping - known_networks = set([self._network_mapping[net] for net in view["known_networks"]]) + known_networks = self._get_networks_from_view(view_known_networks=view["known_networks"]) # parse controlled hosts - controlled_hosts = self._get_controlled_hosts_from_view(view["controlled_hosts"]) - known_hosts = set([self._ip_mapping[ip] for ip in view["known_hosts"]]) + controlled_hosts = self._get_hosts_from_view(view_hosts=view["controlled_hosts"], allowed_hosts=self.hosts_to_start) + known_hosts = self._get_hosts_from_view(view_hosts=view["known_hosts"], allowed_hosts=self.hosts_to_start) # Add all controlled hosts to known_hosts known_hosts = known_hosts.union(controlled_hosts) if add_neighboring_nets: @@ -377,7 +474,131 @@ def process_firewall()->dict: self.logger.info(f"\tintitial self._ip_mapping: {self._ip_mapping}") self.logger.info("CYST configuration processed successfully") - def _create_new_network_mapping(self)->tuple: + def _dynamic_ip_change(self, max_attempts:int=10)-> None: + """ + Changes the IP and network addresses in the environment + """ + self.logger.info("Changing IP and Network addresses in the environment") + # find a new IP and network mapping + mapping_nets, mapping_ips = self._create_new_network_mapping(max_attempts) + + # update ALL data structure in the environment with the new mappings + + # self._networks + new_self_networks = {} + for net, ips in self._networks.items(): + new_self_networks[mapping_nets[net]] = set() + for ip in ips: + new_self_networks[mapping_nets[net]].add(mapping_ips[ip]) + self._networks = new_self_networks + + #self._firewall_original (we do not care about the changes done during the episode) + new_self_firewall_original = {} + for ip, dst_ips in self._firewall_original.items(): + new_self_firewall_original[mapping_ips[ip]] = set() + for dst_ip in dst_ips: + new_self_firewall_original[mapping_ips[ip]].add(mapping_ips[dst_ip]) + self.logger.debug(f"New FW: {new_self_firewall_original}") + self._firewall_original = new_self_firewall_original + + # self._ip_to_hostname + new_self_ip_to_hostname = {} + for ip, hostname in self._ip_to_hostname.items(): + new_self_ip_to_hostname[mapping_ips[ip]] = hostname + self._ip_to_hostname = new_self_ip_to_hostname + + # Map hosts_to_start + new_self_host_to_start = [] + for ip in self.hosts_to_start: + new_self_host_to_start.append(mapping_ips[ip]) + self.hosts_to_start = new_self_host_to_start + + def apply_mapping(d: dict, mapping: dict) -> dict: + """ + Apply a mapping to a dictionary. + - Keys are remapped with mapping if present. + - Values: + * If iterable (set/list/tuple), each element is remapped. + * If string (or non-iterable), attempt direct remap. + """ + out = defaultdict(set) + for k, vals in d.items(): + nk = mapping.get(k, k) + + if isinstance(vals, str) or not isinstance(vals, Iterable): + # treat as a single atomic value + nv = {mapping.get(vals, vals)} + else: + nv = {mapping.get(v, v) for v in vals} + + out[nk].update(nv) + + return dict(out) + + # start_position per role + for role, start_position in self._starting_positions_per_role.items(): + # {'role': {'controlled_hosts': [...], 'known_hosts': [...], 'known_data': {...}, 'known_services': {...}, known_networks: [...], known_blocks: [...]}} + new_start_position = {} + new_start_position['known_networks'] = [mapping_nets.get(net, net) for net in start_position['known_networks']] + new_start_position['controlled_hosts'] = [mapping_ips.get(ip, ip) for ip in start_position['controlled_hosts']] + new_start_position['known_hosts'] = [mapping_ips.get(ip, ip) for ip in start_position['known_hosts']] + new_start_position['known_services'] = {mapping_ips.get(ip, ip): services for ip, services in start_position['known_services'].items()} + new_start_position["known_data"] = {mapping_ips.get(ip, ip): data for ip, data in start_position['known_data'].items()} + # known_blocks {IP:set(IP)} + new_start_position["known_blocks"] = apply_mapping(start_position.get("known_blocks", {}), mapping_ips) + self._starting_positions_per_role[role] = new_start_position + self.logger.debug(f"Updated starting position for role {role}: {self._starting_positions_per_role[role]}") + + # win_conditions_per_role + for role, win_condition in self._win_conditions_per_role.items(): + new_win_condition = {} + new_win_condition['known_networks'] = [mapping_nets.get(net, net) for net in win_condition['known_networks']] + new_win_condition['controlled_hosts'] = [mapping_ips.get(ip, ip) for ip in win_condition['controlled_hosts']] + new_win_condition['known_hosts'] = [mapping_ips.get(ip, ip) for ip in win_condition['known_hosts']] + new_win_condition['known_services'] = {mapping_ips.get(ip, ip): services for ip, services in win_condition['known_services'].items()} + new_win_condition["known_data"] = {mapping_ips.get(ip, ip): data for ip, data in win_condition['known_data'].items()} + new_win_condition["known_blocks"] = apply_mapping(win_condition.get("known_blocks", {}), mapping_ips) + self._win_conditions_per_role[role] = new_win_condition + self.logger.debug(f"Updated win condition for role {role}: {self._win_conditions_per_role[role]}") + + # goal_description_per_role + def replace_ips_in_text(text: str, ip_mapping: dict, net_mapping:dict) -> str: + """ + Replace IPs/CIDRs in text according to mapping {IP: IP}. + """ + # regex: matches IPv4 like 1.2.3.4 or 1.2.3.4/24 + ip_pattern = re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}(?:/\d{1,2})?\b") + + def replacer(match): + token = match.group(0) + if "/" in token: + try: + net_obj = Network(*token.split("/")) + return str(net_mapping.get(net_obj, token)) + except ValueError: + return token + else: + try: + net_obj = IP(token) + return str(ip_mapping.get(net_obj, token)) + except ValueError: + return token + + return ip_pattern.sub(replacer, text) + + new_goal_description = {role:replace_ips_in_text(description, mapping_ips, mapping_nets) for role, description in self._goal_description_per_role.items()} + self._goal_description_per_role = new_goal_description + self.logger.debug(f"Updated goal description per role: {self._goal_description_per_role}") + + # update mappings stored in the environment + for net, mapping in self._network_mapping.items(): + self._network_mapping[net] = mapping_nets[mapping] + self.logger.debug(f"self._network_mapping: {self._network_mapping}") + for ip, mapping in self._ip_mapping.items(): + self._ip_mapping[ip] = mapping_ips[mapping] + self.logger.debug(f"self._ip_mapping: {self._ip_mapping}") + + def _create_new_network_mapping(self, max_attempts:int=10)->tuple: """ Method that generates random IP and Network addreses while following the topology loaded in the environment. All internal data structures are updated with the newly generated addresses.""" @@ -419,8 +640,8 @@ def _create_new_network_mapping(self)->tuple: except IndexError as e: self.logger.info(f"Dynamic address sampling failed, re-trying. {e}") counter_iter +=1 - if counter_iter > 10: - self.logger.error("Dynamic address failed more than 10 times - stopping.") + if counter_iter > max_attempts: + self.logger.error(f"Dynamic address failed more than {max_attempts} times - stopping.") exit(-1) # Invalid IP address boundary self.logger.info(f"New network mapping:{mapping_nets}") @@ -432,101 +653,13 @@ def _create_new_network_mapping(self)->tuple: random.shuffle(ip_list) for i,ip in enumerate(ips): mapping_ips[ip] = IP(str(ip_list[i])) - # Always add random, in case random is selected for ips + # Always add keywords 'random' and 'all_local' 'all_attackers' to the mapping mapping_ips['random'] = 'random' - self.logger.info(f"Mapping IPs done:{mapping_ips}") - - # update ALL data structure in the environment with the new mappings - # self._networks - new_self_networks = {} - for net, ips in self._networks.items(): - new_self_networks[mapping_nets[net]] = set() - for ip in ips: - new_self_networks[mapping_nets[net]].add(mapping_ips[ip]) - self._networks = new_self_networks - - #self._firewall_original (we do not care about the changes done during the episode) - new_self_firewall_original = {} - for ip, dst_ips in self._firewall_original.items(): - new_self_firewall_original[mapping_ips[ip]] = set() - for dst_ip in dst_ips: - new_self_firewall_original[mapping_ips[ip]].add(mapping_ips[dst_ip]) - self.logger.debug(f"New FW: {new_self_firewall_original}") - self._firewall_original = new_self_firewall_original - - #self._ip_to_hostname - new_self_ip_to_hostname = {} - for ip, hostname in self._ip_to_hostname.items(): - new_self_ip_to_hostname[mapping_ips[ip]] = hostname - self._ip_to_hostname = new_self_ip_to_hostname + mapping_ips['all_local'] = 'all_local' + mapping_ips['all_attackers'] = 'all_attackers' - # Map hosts_to_start - new_self_host_to_start = [] - for ip in self.hosts_to_start: - new_self_host_to_start.append(mapping_ips[ip]) - self.hosts_to_start = new_self_host_to_start - - # map IPs and networks stored in the taskconfig file - # This is a quick fix, we should find some other solution - agents = self.task_config.config['coordinator']['agents'] - # Fields that are dictionaries with IP keys - dict_keys = ['known_data', 'blocked_ips', 'known_blocks'] - # Fields that are lists of IP strings - list_keys = ['known_hosts', 'controlled_hosts'] - ip_regex = re.compile(r'\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b') - - for agent in agents.values(): - for section_key in ['goal', 'start_position']: - section = agent.get(section_key, {}) - - # Remap IP addresses in the description field of the goal section - if section_key == 'goal' and 'description' in section: - description = section['description'] - def repl(match): - ip_str = match.group(0) - try: - new_ip = str(mapping_ips[IP(ip_str)]) - return new_ip - except (ValueError, KeyError): - return ip_str - section['description'] = ip_regex.sub(repl, description) - - # Remap dictionary keys - for key in dict_keys: - if key in section: - current_dict = section[key] - for ip in list(current_dict.keys()): - try: - # Convert the ip string to an IP object - new_ip = str(mapping_ips[IP(ip)]) - except (ValueError, KeyError): - # Skip if the IP is invalid or not found in mapping_ips - continue - current_dict[new_ip] = current_dict.pop(ip) - - # Remap list items - for key in list_keys: - if key in section: - new_list = [] - for ip in section[key]: - try: - new_ip = str(mapping_ips[IP(ip)]) - except (ValueError, KeyError): - # Keep the original if invalid or not in mapping_ips - new_ip = ip - new_list.append(new_ip) - section[key] = new_list - # update win conditions with the new IPs - self._win_conditions_per_role = self._get_win_condition_per_role() - self._goal_description_per_role = self._get_goal_description_per_role() - - #update mappings stored in the environment - for net, mapping in self._network_mapping.items(): - self._network_mapping[net] = mapping_nets[mapping] - self.logger.debug(f"self._network_mapping: {self._network_mapping}") - for ip, mapping in self._ip_mapping.items(): - self._ip_mapping[ip] = mapping_ips[mapping] - self.logger.debug(f"self._ip_mapping: {self._ip_mapping}") + self.logger.info(f"Mapping IPs done:{mapping_ips}") + return mapping_nets, mapping_ips def _get_services_from_host(self, host_ip:str, controlled_hosts:set)-> set: """ @@ -924,20 +1057,22 @@ def update_log_file(self, known_data:set, action, target_host:IP): new_content = json.dumps(new_content) self._data[hostaname].add(Data(owner="system", id="logfile", type="log", size=len(new_content) , content= new_content)) - async def register_agent(self, agent_id, agent_role, agent_initial_view)->GameState: - game_state = self._create_state_from_view(agent_initial_view) - return game_state - + async def register_agent(self, agent_id, agent_role, agent_initial_view:dict, agent_win_condition_view:dict)->tuple[GameState, GameState]: + start_game_state = self._create_state_from_view(agent_initial_view) + goal_state = self._create_goal_state_from_view(agent_win_condition_view) + return start_game_state, goal_state + async def remove_agent(self, agent_id, agent_state)->bool: # No action is required return True async def step(self, agent_id, agent_state, action)->GameState: return self._execute_action(agent_state, action, agent_id) - - async def reset_agent(self, agent_id, agent_role, agent_initial_view)->GameState: + + async def reset_agent(self, agent_id, agent_role, agent_initial_view:dict, agent_win_condition_view:dict)->tuple[GameState, GameState]: game_state = self._create_state_from_view(agent_initial_view) - return game_state + goal_state = self._create_goal_state_from_view(agent_win_condition_view) + return game_state, goal_state async def reset(self)->bool: """ @@ -951,7 +1086,7 @@ async def reset(self)->bool: if self.task_config.get_use_dynamic_addresses(): if all(self._randomize_topology_requests.values()): self.logger.info("All agents requested reset with randomized topology.") - self._create_new_network_mapping() + self._dynamic_ip_change() else: self.logger.info("Not all agents requested a topology randomization. Keeping the current one.") # reset self._data to orignal state diff --git a/AIDojoCoordinator/worlds/WhiteBoxNSGCoordinator.py b/AIDojoCoordinator/worlds/WhiteBoxNSGCoordinator.py index 88b5e1b7..d9ac3285 100644 --- a/AIDojoCoordinator/worlds/WhiteBoxNSGCoordinator.py +++ b/AIDojoCoordinator/worlds/WhiteBoxNSGCoordinator.py @@ -29,7 +29,7 @@ def _initialize(self): self._generate_all_actions() self._registration_info = { "all_actions": json.dumps([v.as_dict for v in self._all_actions]), - } + } if self._all_actions is not None else {} def _generate_all_actions(self)-> list: diff --git a/tests/coordinator/test_coordinator_core.py b/tests/coordinator/test_coordinator_core.py index 7d435bb5..0853693a 100644 --- a/tests/coordinator/test_coordinator_core.py +++ b/tests/coordinator/test_coordinator_core.py @@ -244,6 +244,7 @@ async def test_process_join_game_action_success(initialized_coordinator): # Minimal working state initialized_coordinator._starting_positions_per_role = {"Attacker": MagicMock()} initialized_coordinator._goal_description_per_role = {"Attacker": "Goal"} + initialized_coordinator._win_conditions_per_role = {"Attacker": MagicMock()} initialized_coordinator._steps_limit_per_role = {"Attacker": 10} initialized_coordinator._CONFIG_FILE_HASH = "abc123" initialized_coordinator._min_required_players = 1 @@ -251,7 +252,10 @@ async def test_process_join_game_action_success(initialized_coordinator): initialized_coordinator._episode_start_event.set() # Prevent wait action = MagicMock() - action.parameters = {"agent_info": MagicMock(name="AgentX", role="Attacker")} + agent_info = MagicMock() + agent_info.name = "AgentX" + agent_info.role = "Attacker" + action.parameters = {"agent_info": agent_info} observation = SimpleNamespace( state=SimpleNamespace(as_dict={}), # empty dict works here reward=0, @@ -259,7 +263,7 @@ async def test_process_join_game_action_success(initialized_coordinator): info={} ) - with patch.object(initialized_coordinator, "register_agent", new_callable=AsyncMock, return_value=MagicMock()), \ + with patch.object(initialized_coordinator, "register_agent", new_callable=AsyncMock, return_value=(MagicMock(),MagicMock())), \ patch.object(initialized_coordinator, "_initialize_new_player", return_value=observation), \ patch.object(initialized_coordinator.logger, "info"), \ patch.object(initialized_coordinator.logger, "debug"):