diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fa7092dd..9c044518 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.2.0 + rev: v5.0.0 hooks: - id: trailing-whitespace exclude: '.*\.pdb$' @@ -20,4 +20,12 @@ repos: hooks: # Run the linter. - id: ruff - args: [ --fix ] \ No newline at end of file + args: [ --fix ] +- repo: local + hooks: + - id: pyright + name: pyright + entry: pyright + language: system + require_serial: true + types: [python] \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 6a84cec3..e8dd44bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ ehrlichholo = [ "pytorch-holo", ] tdc = [ - "pytdc", + "pytdc==1.1.14", ] dockstring = [ "dockstring" @@ -65,7 +65,7 @@ rosetta_energy = [ "biopython", "pyrosetta-installer", ] -dev = ["black", "tox", "pytest", "bump-my-version"] +dev = ["black", "tox", "pytest", "bump-my-version", "pre-commit", "pyright"] docs = ["sphinx", "furo"] [project.urls] @@ -83,9 +83,15 @@ markers = [ "poli__rosetta_energy: marks tests that run in poli__rosetta_energy", "poli__ehrlich_holo: marks tests that run in poli__ehrlich_holo environment", "poli__dms: marks tests that run in poli__dms environment", + "isolation: marks tests that require isolation of the black box function", "unmarked: All other tests, which usually run in the base environment", ] +[tool.pyright] +include = ["src/poli"] +exclude = ["src/poli/core/util/proteins/rasp/inner_rasp", "src/poli/objective_repository/gfp_cbas", "examples", "src/poli/tests"] +reportIncompatibleMethodOverride = "none" + [tool.isort] profile = "black" @@ -148,4 +154,7 @@ replace = 'version: {new_version}' [dependency-groups] dev = [ "pre-commit>=4.2.0", + "pyright>=1.1.403", + "pytest>=8.4.0", + "ruff>=0.12.3", ] diff --git a/src/poli/benchmarks/guacamol.py b/src/poli/benchmarks/guacamol.py index 9cf85b8a..f4345d38 100644 --- a/src/poli/benchmarks/guacamol.py +++ b/src/poli/benchmarks/guacamol.py @@ -125,7 +125,7 @@ def _initialize_problem(self, index: int) -> Problem: problem_factory = self.problem_factories[index] problem = problem_factory.create( - string_representation=self.string_representation, + string_representation=self.string_representation, # type: ignore seed=self.seed, batch_size=self.batch_size, parallelize=self.parallelize, diff --git a/src/poli/benchmarks/pmo.py b/src/poli/benchmarks/pmo.py index a004b1fc..e0c3592c 100644 --- a/src/poli/benchmarks/pmo.py +++ b/src/poli/benchmarks/pmo.py @@ -75,7 +75,7 @@ def __init__( batch_size: Union[int, None] = None, parallelize: bool = False, num_workers: Union[int, None] = None, - evaluation_budget: int = None, + evaluation_budget: int | None = None, ) -> None: super().__init__( string_representation=string_representation, diff --git a/src/poli/benchmarks/toy_continuous_functions_benchmark.py b/src/poli/benchmarks/toy_continuous_functions_benchmark.py index efe1d426..b2876ac9 100644 --- a/src/poli/benchmarks/toy_continuous_functions_benchmark.py +++ b/src/poli/benchmarks/toy_continuous_functions_benchmark.py @@ -15,13 +15,14 @@ https://www.sfu.ca/~ssurjano/optimization.html. """ -from typing import List, Union +from typing import Sequence, Union, cast from poli.core.abstract_benchmark import AbstractBenchmark from poli.core.problem import Problem from poli.objective_repository import ToyContinuousProblemFactory from poli.objective_repository.toy_continuous_problem.toy_continuous_problem import ( POSSIBLE_FUNCTIONS, + POSSIBLE_FUNCTIONS_TYPE, SIX_DIMENSIONAL_PROBLEMS, TWO_DIMENSIONAL_PROBLEMS, ) @@ -49,12 +50,12 @@ def __init__( self, n_dimensions: int = 2, embed_in: Union[int, None] = None, - dimensions_to_embed_in: Union[List[int], None] = None, + dimensions_to_embed_in: Union[list[int], None] = None, seed: Union[int, None] = None, batch_size: Union[int, None] = None, parallelize: bool = False, num_workers: Union[int, None] = None, - evaluation_budget: Union[int, List[int]] = None, + evaluation_budget: int | None = None, ) -> None: super().__init__( seed=seed, @@ -66,7 +67,7 @@ def __init__( self.n_dimensions = n_dimensions self.embed_in = embed_in self.dimensions_to_embed_in = dimensions_to_embed_in - self.function_names = list( + self.function_names: Sequence[POSSIBLE_FUNCTIONS_TYPE] = list( # type: ignore ( set(POSSIBLE_FUNCTIONS) - set(TWO_DIMENSIONAL_PROBLEMS) @@ -78,7 +79,9 @@ def __init__( ) def _initialize_problem(self, index: int) -> Problem: - problem_factory: ToyContinuousProblemFactory = self.problem_factories[index] + problem_factory: ToyContinuousProblemFactory = cast( + ToyContinuousProblemFactory, self.problem_factories[index] + ) problem = problem_factory.create( function_name=self.function_names[index], @@ -121,7 +124,7 @@ def __init__( batch_size: Union[int, None] = None, parallelize: bool = False, num_workers: Union[int, None] = None, - evaluation_budget: Union[int, List[int]] = None, + evaluation_budget: int | None = None, ) -> None: super().__init__( seed, @@ -134,7 +137,9 @@ def __init__( self.problem_factories = [ToyContinuousProblemFactory()] * len(self.embed_in) def _initialize_problem(self, index: int) -> Problem: - problem_factory: ToyContinuousProblemFactory = self.problem_factories[index] + problem_factory: ToyContinuousProblemFactory = cast( + ToyContinuousProblemFactory, self.problem_factories[index] + ) problem = problem_factory.create( function_name="branin_2d", @@ -174,7 +179,7 @@ def __init__( batch_size: Union[int, None] = None, parallelize: bool = False, num_workers: Union[int, None] = None, - evaluation_budget: Union[int, List[int]] = None, + evaluation_budget: int | None = None, ) -> None: super().__init__( seed, @@ -187,7 +192,9 @@ def __init__( self.problem_factories = [ToyContinuousProblemFactory()] * len(self.embed_in) def _initialize_problem(self, index: int) -> Problem: - problem_factory: ToyContinuousProblemFactory = self.problem_factories[index] + problem_factory: ToyContinuousProblemFactory = cast( + ToyContinuousProblemFactory, self.problem_factories[index] + ) if index == 0: problem = problem_factory.create( diff --git a/src/poli/core/abstract_benchmark.py b/src/poli/core/abstract_benchmark.py index 032c97a8..146fde86 100644 --- a/src/poli/core/abstract_benchmark.py +++ b/src/poli/core/abstract_benchmark.py @@ -1,13 +1,13 @@ from __future__ import annotations -from typing import List, Union +from typing import Union from poli.core.abstract_problem_factory import AbstractProblemFactory from poli.core.problem import Problem class AbstractBenchmark: - problem_factories: List[AbstractProblemFactory] + problem_factories: list[AbstractProblemFactory] index: int = 0 def __init__( @@ -16,7 +16,7 @@ def __init__( batch_size: Union[int, None] = None, parallelize: bool = False, num_workers: Union[int, None] = None, - evaluation_budget: int = None, + evaluation_budget: int | None = None, ) -> None: self.seed = seed self.batch_size = batch_size @@ -46,7 +46,7 @@ def info(self) -> str: raise NotImplementedError @property - def problem_names(self) -> List[str]: + def problem_names(self) -> list[str]: return [ problem_factory.__module__.replace( "poli.objective_repository.", "" diff --git a/src/poli/core/abstract_black_box.py b/src/poli/core/abstract_black_box.py index 57b5e030..0aad58b0 100644 --- a/src/poli/core/abstract_black_box.py +++ b/src/poli/core/abstract_black_box.py @@ -5,9 +5,11 @@ from __future__ import annotations from multiprocessing import Pool, cpu_count +from typing import cast from warnings import warn import numpy as np +from numpy.typing import NDArray from poli.core.black_box_information import BlackBoxInformation from poli.core.exceptions import BudgetExhaustedException @@ -22,17 +24,17 @@ class AbstractBlackBox: Parameters ---------- - batch_size : int, optional + batch_size : int | None, optional The batch size for evaluating the black box function. Default is None. parallelize : bool, optional Flag indicating whether to evaluate the black box function in parallel. Default is False. - num_workers : int, optional + num_workers : int | None, optional The number of workers to use for parallel evaluation. Default is None, which uses half of the available CPU cores. - evaluation_budget : int, optional + evaluation_budget : int | None, optional The maximum number of evaluations allowed for the black box function. - Default is None). + Default is None, which means an infinite budget. Attributes ---------- @@ -44,7 +46,7 @@ class AbstractBlackBox: Flag indicating whether to evaluate the black box function in parallel. num_workers : int The number of workers to use for parallel evaluation. - batch_size : int or None + batch_size : int | None The batch size for evaluating the black box function. Methods @@ -84,13 +86,13 @@ def __init__( Parameters ---------- - batch_size : int, optional + batch_size : int | None, optional The batch size for parallel execution, by default None. parallelize : bool, optional Flag indicating whether to parallelize the execution, by default False. - num_workers : int, optional + num_workers : int | None, optional The number of workers for parallel execution, by default we use half the available CPUs. - evaluation_budget : int, optional + evaluation_budget : int | None, optional The maximum number of evaluations allowed for the black box function, by default it is None, which means no limit. """ self.observer = None @@ -145,13 +147,13 @@ def set_observer(self, observer: AbstractObserver): ) self.observer = observer - def set_observer_info(self, observer_info: object): + def set_observer_info(self, observer_info: dict[str, object] | None): """ Set the observer information after initialization. Parameters ---------- - observer_info : object + observer_info : dict[str, object] The information given by the observer after initialization. """ self.observer_info = observer_info @@ -160,7 +162,7 @@ def reset_evaluation_budget(self): """Resets the evaluation budget by setting the number of evaluations made to 0.""" self.num_evaluations = 0 - def __call__(self, x: np.array, context=None): + def __call__(self, x: NDArray[np.str_], context=None): """Calls the black box function. The purpose of this function is to enforce that inputs are equal across @@ -340,7 +342,7 @@ def terminate(self) -> None: Terminate the black box optimization problem. """ if hasattr(self, "inner_function"): - self.inner_function.terminate() + self.inner_function.terminate() # type: ignore # if self.observer is not None: # # NOTE: terminating a problem should gracefully end the observer process -> write the last state. # self.observer.finish() @@ -387,13 +389,13 @@ def __init__(self, f: AbstractBlackBox): batch_size=f.batch_size, parallelize=f.parallelize, num_workers=f.num_workers, - evaluation_budget=f.evaluation_budget, + evaluation_budget=cast(int | None, f.evaluation_budget), ) - def __call__(self, x, context=None): + def __call__(self, x: NDArray[np.str_], context=None): return -self.f.__call__(x, context) - def _black_box(self, x, context=None): + def _black_box(self, x: NDArray[np.str_], context=None): return self.f._black_box(x, context) def __str__(self) -> str: diff --git a/src/poli/core/abstract_problem_factory.py b/src/poli/core/abstract_problem_factory.py index f7caa825..f831520e 100644 --- a/src/poli/core/abstract_problem_factory.py +++ b/src/poli/core/abstract_problem_factory.py @@ -30,11 +30,11 @@ class AbstractProblemFactory(metaclass=MetaProblemFactory): def create( self, - seed: int = None, - batch_size: int = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/core/benchmark_information.py b/src/poli/core/benchmark_information.py index 98a88485..f034ffb6 100644 --- a/src/poli/core/benchmark_information.py +++ b/src/poli/core/benchmark_information.py @@ -10,7 +10,7 @@ def __init__( fixed_length: bool, deterministic: bool, alphabet: list, - log_transform_recommended: bool = None, + log_transform_recommended: bool | None = None, discrete: bool = True, fidelity: Union[Literal["high", "low"], None] = None, padding_token: str = "", @@ -111,7 +111,7 @@ def get_alphabet(self) -> list: """ return self.alphabet - def log_transform_recommended(self) -> bool: + def is_log_transform_recommended(self) -> bool | None: """ Returns whether the black-box recommends log-transforming the targets. diff --git a/src/poli/core/black_box_information.py b/src/poli/core/black_box_information.py index 93de034e..d5179893 100644 --- a/src/poli/core/black_box_information.py +++ b/src/poli/core/black_box_information.py @@ -12,6 +12,8 @@ - The alphabet of allowed characters. """ +from __future__ import annotations + from typing import Literal, Union import numpy as np @@ -21,18 +23,20 @@ class BlackBoxInformation: def __init__( self, name: str, - max_sequence_length: int, + max_sequence_length: int | Literal["inf"] | float, aligned: bool, fixed_length: bool, deterministic: bool, - alphabet: list, - log_transform_recommended: bool = None, + alphabet: list[str] | None = None, + log_transform_recommended: bool | None = None, discrete: bool = True, fidelity: Union[Literal["high", "low"], None] = None, - padding_token: str = "", + padding_token: str | None = None, ): self.name = name - self.max_sequence_length = max_sequence_length + self.max_sequence_length = ( + max_sequence_length if max_sequence_length != "inf" else np.inf + ) self.aligned = aligned self.fixed_length = fixed_length self.deterministic = deterministic @@ -40,7 +44,7 @@ def __init__( self.log_transform_recommended = log_transform_recommended self.discrete = discrete self.fidelity = fidelity - self.padding_token = padding_token + self.padding_token = padding_token if padding_token is not None else "" def get_problem_name(self) -> str: """Returns the problem's name. @@ -52,14 +56,15 @@ def get_problem_name(self) -> str: """ return self.name - def get_max_sequence_length(self) -> int: + def get_max_sequence_length(self) -> int | float: """ Returns the maximum sequence length allowed by the black-box. Returns -------- - max_sequence_length : int - The length of the longest sequence. + max_sequence_length : int | float + The length of the longest sequence. If the maximum sequence length is + infinity, it returns np.inf. """ return self.max_sequence_length @@ -116,7 +121,7 @@ def sequences_are_aligned(self) -> bool: """ return self.aligned - def get_alphabet(self) -> list: + def get_alphabet(self) -> list[str] | None: """ Returns the alphabet of allowed characters. @@ -127,13 +132,13 @@ def get_alphabet(self) -> list: """ return self.alphabet - def log_transform_recommended(self) -> bool: + def is_log_transform_recommended(self) -> bool | None: """ Returns whether the black-box recommends log-transforming the targets. Returns -------- - log_transform_recommended : bool + log_transform_recommended : bool | None Whether the black-box recommends log-transforming the targets. """ return self.log_transform_recommended diff --git a/src/poli/core/chemistry/data_packages/random_molecules_data_package.py b/src/poli/core/chemistry/data_packages/random_molecules_data_package.py index 99658db3..38e0e01a 100644 --- a/src/poli/core/chemistry/data_packages/random_molecules_data_package.py +++ b/src/poli/core/chemistry/data_packages/random_molecules_data_package.py @@ -40,7 +40,7 @@ def __init__( string_representation: Literal["SMILES", "SELFIES"], n_molecules: int = 10, seed: int | None = None, - tokenize_with: Callable[[str], list[str]] = None, + tokenize_with: Callable[[str], list[str]] | None = None, ): assert ( n_molecules <= 5000 diff --git a/src/poli/core/chemistry/tdc_black_box.py b/src/poli/core/chemistry/tdc_black_box.py index d3cf8c07..8546c934 100644 --- a/src/poli/core/chemistry/tdc_black_box.py +++ b/src/poli/core/chemistry/tdc_black_box.py @@ -78,12 +78,12 @@ def __init__( oracle_name: str, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, **kwargs_for_oracle, ): if parallelize: @@ -99,7 +99,9 @@ def __init__( ) self.oracle_name = oracle_name self.alphabet = alphabet - self.max_sequence_length = max_sequence_length + self.max_sequence_length = ( + max_sequence_length if max_sequence_length != "inf" else np.inf + ) self.string_representation = string_representation from_smiles = string_representation.upper() == "SMILES" diff --git a/src/poli/core/chemistry/tdc_isolated_function.py b/src/poli/core/chemistry/tdc_isolated_function.py index 52a23eff..cf472069 100644 --- a/src/poli/core/chemistry/tdc_isolated_function.py +++ b/src/poli/core/chemistry/tdc_isolated_function.py @@ -11,6 +11,8 @@ (October 2022): 1033-36. https://doi.org/10.1038/s41589-022-01131-2. """ +# pyright: reportMissingImports=false + import numpy as np from tdc import Oracle diff --git a/src/poli/core/chemistry/tdc_problem.py b/src/poli/core/chemistry/tdc_problem.py index 7a3ffa1a..c2393515 100644 --- a/src/poli/core/chemistry/tdc_problem.py +++ b/src/poli/core/chemistry/tdc_problem.py @@ -1,14 +1,28 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, cast + from poli.core.chemistry.data_packages import RandomMoleculesDataPackage from poli.core.chemistry.tdc_black_box import TDCBlackBox from poli.core.problem import Problem +if TYPE_CHECKING: + from poli.objective_repository.rdkit_logp.register import LogPBlackBox + from poli.objective_repository.rdkit_qed.register import QEDBlackBox + class TDCProblem(Problem): def __init__( - self, black_box: TDCBlackBox, x0, data_package=None, strict_validation=True + self, + black_box: TDCBlackBox | QEDBlackBox | LogPBlackBox, + x0, + data_package=None, + strict_validation=True, ): if data_package is None: - data_package = RandomMoleculesDataPackage(black_box.string_representation) + data_package = RandomMoleculesDataPackage( + cast(Literal["SELFIES", "SMILES"], black_box.string_representation) + ) super().__init__( black_box=black_box, diff --git a/src/poli/core/multi_objective_black_box.py b/src/poli/core/multi_objective_black_box.py index 17f05bc3..b6e4d91d 100644 --- a/src/poli/core/multi_objective_black_box.py +++ b/src/poli/core/multi_objective_black_box.py @@ -4,8 +4,6 @@ objective functions. """ -from typing import List - import numpy as np from poli.core.abstract_black_box import AbstractBlackBox @@ -22,12 +20,12 @@ class MultiObjectiveBlackBox(AbstractBlackBox): ----------- batch_size : int, optional The batch size for evaluating the black box function. Defaults to None. - objective_functions : List[AbstractBlackBox], required + objective_functions : list[AbstractBlackBox], required The list of objective functions to be evaluated. Defaults to None. Attributes ---------- - objective_functions : List[AbstractBlackBox] + objective_functions : list[AbstractBlackBox] The list of objective functions to be evaluated. Methods @@ -49,15 +47,15 @@ class MultiObjectiveBlackBox(AbstractBlackBox): def __init__( self, - objective_functions: List[AbstractBlackBox], - batch_size: int = None, + objective_functions: list[AbstractBlackBox], + batch_size: int | None = None, ) -> None: """ Initialize the MultiObjectiveBlackBox class. Parameters ----------- - objective_functions : List[AbstractBlackBox] + objective_functions : list[AbstractBlackBox] The list of objective functions. batch_size : int, optional The batch size. Defaults to None. diff --git a/src/poli/core/proteins/foldx_black_box.py b/src/poli/core/proteins/foldx_black_box.py index 7705d9f7..5b6efb1c 100644 --- a/src/poli/core/proteins/foldx_black_box.py +++ b/src/poli/core/proteins/foldx_black_box.py @@ -14,7 +14,7 @@ from multiprocessing import cpu_count from pathlib import Path from time import time -from typing import List, Union +from typing import Union from uuid import uuid4 from poli.core.abstract_black_box import AbstractBlackBox @@ -48,9 +48,9 @@ class FoldxBlackBox(AbstractBlackBox): Flag indicating whether to parallelize the simulations. (default: False) num_workers : int, optional The number of workers for parallelization. (default: None) - wildtype_pdb_path : Union[Path, List[Path]], required + wildtype_pdb_path : Union[Path, list[Path]], required The path(s) to the wildtype PDB file(s). (default: None) - alphabet : List[str], optional + alphabet : list[str], optional The list of allowed amino acids. (default: None) experiment_id : str, optional The experiment ID. (default: None) @@ -67,13 +67,13 @@ class FoldxBlackBox(AbstractBlackBox): The experiment ID. tmp_folder : Path The temporary folder path. - wildtype_pdb_paths : List[Path] + wildtype_pdb_paths : list[Path] The list of repaired wildtype PDB file paths. - wildtype_residues : List[List[Residue]] + wildtype_residues : list[list[Residue]] The list of wildtype residues for each PDB file. - wildtype_amino_acids : List[List[str]] + wildtype_amino_acids : list[list[str]] The list of wildtype amino acids for each PDB file. - wildtype_residue_strings : List[str] + wildtype_residue_strings : list[str] The list of wildtype residue strings for each PDB file. Methods @@ -85,15 +85,15 @@ class FoldxBlackBox(AbstractBlackBox): def __init__( self, - info: BlackBoxInformation = None, - batch_size: int = None, + info: BlackBoxInformation | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, - wildtype_pdb_path: Union[Path, List[Path]] = None, - alphabet: List[str] = None, - experiment_id: str = None, - tmp_folder: Path = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, + wildtype_pdb_path: Union[Path, list[Path]] | None = None, + alphabet: list[str] | None = None, + experiment_id: str | None = None, + tmp_folder: Path | None = None, eager_repair: bool = False, verbose: bool = False, ): @@ -112,9 +112,9 @@ def __init__( The number of workers for parallelization. (default: None) evaluation_budget : int, optional The evaluation budget. (default: float('inf')) - wildtype_pdb_path : Union[Path, List[Path]], optional + wildtype_pdb_path : Union[Path, list[Path]], optional The path(s) to the wildtype PDB file(s). (default: None) - alphabet : List[str], optional + alphabet : list[str], optional The list of allowed amino acids. (default: None) experiment_id : str, optional The experiment ID. (default: None) @@ -151,7 +151,6 @@ def __init__( batch_size = 1 super().__init__( - info=info, batch_size=batch_size, parallelize=parallelize, num_workers=num_workers, @@ -167,6 +166,11 @@ def __init__( self.tmp_folder = tmp_folder if tmp_folder is not None else DEFAULT_TMP_PATH if alphabet is None: + if info is None: + raise ValueError( + "Missing required keyword argument: alphabet: list[str]. " + "Alphabet must be provided if not in info." + ) alphabet = info.alphabet if isinstance(wildtype_pdb_path, str): diff --git a/src/poli/core/proteins/foldx_isolated_function.py b/src/poli/core/proteins/foldx_isolated_function.py index 0364bb2a..4a018e63 100644 --- a/src/poli/core/proteins/foldx_isolated_function.py +++ b/src/poli/core/proteins/foldx_isolated_function.py @@ -1,6 +1,6 @@ from pathlib import Path from time import time -from typing import List, Union +from typing import Union from uuid import uuid4 import numpy as np @@ -27,7 +27,7 @@ class FoldxIsolatedFunction(AbstractIsolatedFunction): Parameters ----------- - wildtype_pdb_path : Union[Path, List[Path]], required + wildtype_pdb_path : Union[Path, list[Path]], required The path(s) to the wildtype PDB file(s). (default: None) experiment_id : str, optional The experiment ID. (default: None) @@ -44,13 +44,13 @@ class FoldxIsolatedFunction(AbstractIsolatedFunction): The experiment ID. tmp_folder : Path The temporary folder path. - wildtype_pdb_paths : List[Path] + wildtype_pdb_paths : list[Path] The list of repaired wildtype PDB file paths. - wildtype_residues : List[List[Residue]] + wildtype_residues : list[list[Residue]] The list of wildtype residues for each PDB file. - wildtype_amino_acids : List[List[str]] + wildtype_amino_acids : list[list[str]] The list of wildtype amino acids for each PDB file. - wildtype_residue_strings : List[str] + wildtype_residue_strings : list[str] The list of wildtype residue strings for each PDB file. Methods @@ -62,9 +62,9 @@ class FoldxIsolatedFunction(AbstractIsolatedFunction): def __init__( self, - wildtype_pdb_path: Union[Path, List[Path]], - experiment_id: str = None, - tmp_folder: Path = None, + wildtype_pdb_path: Union[Path, list[Path]], + experiment_id: str | None = None, + tmp_folder: Path | None = None, eager_repair: bool = False, verbose: bool = False, ): @@ -73,7 +73,7 @@ def __init__( Parameters ----------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] The path(s) to the wildtype PDB file(s). experiment_id : str, optional The experiment ID. (default: None) diff --git a/src/poli/core/registry.py b/src/poli/core/registry.py index 7c3fd3ff..8d368e27 100644 --- a/src/poli/core/registry.py +++ b/src/poli/core/registry.py @@ -1,11 +1,12 @@ """This module contains utilities for registering problems and observers.""" +from __future__ import annotations + import configparser import warnings from pathlib import Path -from typing import List, Type, Union +from typing import Type, Union -from poli.core.abstract_black_box import AbstractBlackBox from poli.core.abstract_isolated_function import AbstractIsolatedFunction from poli.core.util.abstract_observer import AbstractObserver from poli.core.util.objective_management.make_run_script import ( @@ -32,9 +33,9 @@ def register_observer( observer: Union[AbstractObserver, Type[AbstractObserver]], - conda_environment_location: str = None, - python_paths: List[str] = None, - observer_name: str = None, + conda_environment_location: str | None = None, + python_paths: list[str] | None = None, + observer_name: str | None = None, set_as_default_observer: bool = True, ): """Defines an external observer to be run in a separate process. @@ -53,7 +54,7 @@ def register_observer( The observer to be registered. conda_environment_location : str The location of the conda environment to be used. - python_paths : List[str] + python_paths : list[str] A list of paths to append to the python path of the run script. observer_name : str The name of the observer to be registered. @@ -73,7 +74,7 @@ def register_observer( else: non_instance_observer = observer.__class__ if observer_name is None: - observer_name = observer.__name__ + observer_name = non_instance_observer.__name__ run_script_location = make_observer_script( non_instance_observer, conda_environment_location, python_paths ) @@ -122,10 +123,10 @@ def remove_default_observer(): def register_isolated_function( - isolated_function: Union[AbstractBlackBox, AbstractIsolatedFunction], + isolated_function: Union[type[AbstractIsolatedFunction], AbstractIsolatedFunction], name: str, - conda_environment_name: Union[str, Path] = None, - python_paths: List[str] = None, + conda_environment_name: Union[str, Path, None] = None, + python_paths: list[str] | None = None, force: bool = True, **kwargs, ): diff --git a/src/poli/core/util/abstract_observer.py b/src/poli/core/util/abstract_observer.py index 5e089898..485e59a2 100644 --- a/src/poli/core/util/abstract_observer.py +++ b/src/poli/core/util/abstract_observer.py @@ -55,9 +55,9 @@ def observe(self, x: np.ndarray, y: np.ndarray, context=None) -> None: def initialize_observer( self, problem_setup_info: BlackBoxInformation, - caller_info: object, - seed: int, - ) -> object: + caller_info: dict[str, object] | None, + seed: int | None, + ) -> dict[str, object]: """ Initialize the observer. diff --git a/src/poli/core/util/alignment/is_aligned.py b/src/poli/core/util/alignment/is_aligned.py index 3d6f7f38..15f98bfc 100644 --- a/src/poli/core/util/alignment/is_aligned.py +++ b/src/poli/core/util/alignment/is_aligned.py @@ -5,7 +5,9 @@ import numpy as np -def is_aligned_input(x: np.ndarray, maximum_sequence_length: int = None) -> bool: +def is_aligned_input( + x: np.ndarray, maximum_sequence_length: int | float | None = None +) -> bool: """Utility function to check if the input to an "aligned" problem is indeed aligned. diff --git a/src/poli/core/util/chemistry/string_to_molecule.py b/src/poli/core/util/chemistry/string_to_molecule.py index fbfaf229..cc541443 100644 --- a/src/poli/core/util/chemistry/string_to_molecule.py +++ b/src/poli/core/util/chemistry/string_to_molecule.py @@ -2,15 +2,17 @@ molecules into molecules in RDKit. """ -from typing import List +# pyright: reportAttributeAccessIssue=false + +from __future__ import annotations import selfies as sf from rdkit import Chem def translate_smiles_to_selfies( - smiles_strings: List[str], strict: bool = False -) -> List[str]: + smiles_strings: list[str], strict: bool = False +) -> list[str]: """Translates a list of SMILES strings to SELFIES strings. Given a list of SMILES strings, returns the translation @@ -24,14 +26,14 @@ def translate_smiles_to_selfies( Parameters ---------- - smiles_strings : List[str] + smiles_strings : list[str] A list of SMILES strings. strict : bool, optional If True, raise an error if a SMILES string in the list cannot be parsed. Returns ------- - List[str] + list[str] A list of SELFIES strings. """ selfies_strings = [] @@ -48,8 +50,8 @@ def translate_smiles_to_selfies( def translate_selfies_to_smiles( - selfies_strings: List[str], strict: bool = False -) -> List[str]: + selfies_strings: list[str], strict: bool = False +) -> list[str]: """Translates a list of SELFIES strings to SMILES strings. Given a list of SELFIES strings, returns the translation @@ -62,14 +64,14 @@ def translate_selfies_to_smiles( Parameters ---------- - selfies_strings : List[str] + selfies_strings : list[str] A list of SELFIES strings. strict : bool, optional If True, raise an error if a SELFIES string in the list cannot be parsed. Returns ------- - smiles_strings : List[str] + smiles_strings : list[str] A list of SMILES strings. """ smiles_strings = [] @@ -85,7 +87,9 @@ def translate_selfies_to_smiles( return smiles_strings -def smiles_to_molecules(smiles_strings: List[str], strict: bool = False) -> Chem.Mol: +def smiles_to_molecules( + smiles_strings: list[str], strict: bool = False +) -> list[Chem.Mol]: """Converts a list of SMILES strings to RDKit molecules. Converts a list of SMILES strings to RDKit molecules. If strict is True, @@ -93,7 +97,7 @@ def smiles_to_molecules(smiles_strings: List[str], strict: bool = False) -> Chem Parameters ---------- - smiles : List[str] + smiles : list[str] A list of SMILES string. Returns @@ -113,7 +117,7 @@ def smiles_to_molecules(smiles_strings: List[str], strict: bool = False) -> Chem return molecules -def selfies_to_molecules(selfies_strings: List[str]) -> Chem.Mol: +def selfies_to_molecules(selfies_strings: list[str]) -> list[Chem.Mol]: """Converts a list of selfies strings to RDKit molecules. Parameters @@ -133,7 +137,7 @@ def selfies_to_molecules(selfies_strings: List[str]) -> Chem.Mol: return molecule -def strings_to_molecules(molecule_strings: List[str], from_selfies: bool = False): +def strings_to_molecules(molecule_strings: list[str], from_selfies: bool = False): """ Convert a string representation of a molecule to an RDKit molecule. diff --git a/src/poli/core/util/external_observer.py b/src/poli/core/util/external_observer.py index 135be019..d741fa56 100644 --- a/src/poli/core/util/external_observer.py +++ b/src/poli/core/util/external_observer.py @@ -1,6 +1,6 @@ """External observer, which can be run in an isolated process.""" -from typing import Any +from typing import Any, cast import numpy as np @@ -32,7 +32,7 @@ class ExternalObserver(AbstractObserver): Retrieves the attribute of the underlying observer. """ - def __init__(self, observer_name: str = None, **kwargs_for_observer): + def __init__(self, observer_name: str | None = None, **kwargs_for_observer): """ Initialize the ExternalObserver object. @@ -72,10 +72,10 @@ def observe(self, x: np.ndarray, y: np.ndarray, context=None) -> None: """ # We send the observation - self.process_wrapper.send(["OBSERVATION", x, y, context]) + cast(ProcessWrapper, self.process_wrapper).send(["OBSERVATION", x, y, context]) # And we make sure the process received and logged it correctly - msg_type, *msg = self.process_wrapper.recv() + msg_type, *msg = cast(ProcessWrapper, self.process_wrapper).recv() if msg_type == "EXCEPTION": e, tb = msg print(tb) @@ -136,10 +136,10 @@ def initialize_observer( def log(self, algorithm_info: dict): # We send the observation - self.process_wrapper.send(["LOG", algorithm_info]) + cast(ProcessWrapper, self.process_wrapper).send(["LOG", algorithm_info]) # And we make sure the process received and logged it correctly - msg_type, *msg = self.process_wrapper.recv() + msg_type, *msg = cast(ProcessWrapper, self.process_wrapper).recv() if msg_type == "EXCEPTION": e, tb = msg print(tb) @@ -164,8 +164,8 @@ def __getattr__(self, __name: str) -> Any: black-box function by sending a message to the process w. the msg_type "ATTRIBUTE". """ - self.process_wrapper.send(["ATTRIBUTE", __name]) - msg_type, *msg = self.process_wrapper.recv() + cast(ProcessWrapper, self.process_wrapper).send(["ATTRIBUTE", __name]) + msg_type, *msg = cast(ProcessWrapper, self.process_wrapper).recv() if msg_type == "EXCEPTION": e, tb = msg print(tb) diff --git a/src/poli/core/util/files/download_files_from_github.py b/src/poli/core/util/files/download_files_from_github.py index 7a9ab100..fe20538a 100644 --- a/src/poli/core/util/files/download_files_from_github.py +++ b/src/poli/core/util/files/download_files_from_github.py @@ -8,6 +8,9 @@ https://gist.github.com/pdashford/2e4bcd4fc2343e2fd03efe4da17f577d?permalink_comment_id=4274705#gistcomment-4274705 """ +# pyright: reportMissingImports=false +# pyright: reportMissingModuleSource=false + import base64 import os from pathlib import Path @@ -63,9 +66,9 @@ def get_sha_for_tag(repository: Repository, tag: str) -> str: def download_file_from_github_repository( repository_name: str, file_path_in_repository: str, - download_path_for_file: str, + download_path_for_file: str | Path, tag: str = "master", - commit_sha: str = None, + commit_sha: str | None = None, exist_ok: bool = False, parent_folders_exist_ok: bool = True, verbose: bool = False, @@ -134,7 +137,7 @@ def _download_file_from_github_repo( repository: Repository, commit_sha: str, file_path_in_repository: str, - download_path_for_file: str, + download_path_for_file: str | Path, exist_ok: bool = False, parent_folders_exist_ok: bool = True, verbose: bool = False, diff --git a/src/poli/core/util/inter_process_communication/process_wrapper.py b/src/poli/core/util/inter_process_communication/process_wrapper.py index 590cd5e3..81460e27 100644 --- a/src/poli/core/util/inter_process_communication/process_wrapper.py +++ b/src/poli/core/util/inter_process_communication/process_wrapper.py @@ -5,12 +5,12 @@ import logging import subprocess import time -from multiprocessing.connection import Client, Listener +from multiprocessing.connection import Client, Connection, Listener from pathlib import Path from uuid import uuid4 -def get_connection(port: int, password: str) -> Client: +def get_connection(port: int, password: str) -> Connection: """ Get a connection to a server. @@ -86,7 +86,7 @@ def __init__(self, run_script, **kwargs_for_factory): self.listener = Listener(address, authkey=self.password.encode()) # TODO: very hacky way to read out the socket! (but the listener is not very cooperative) - self.port = self.listener._listener._socket.getsockname()[1] + self.port = self.listener._listener._socket.getsockname()[1] # type: ignore # here is a VERY crucial step # we expect the shell script to take port and password as arguments, as well as other arguments passed by the user # when calling objective_factory.create diff --git a/src/poli/core/util/isolation/external_function.py b/src/poli/core/util/isolation/external_function.py index 1d9ae149..357dee7f 100644 --- a/src/poli/core/util/isolation/external_function.py +++ b/src/poli/core/util/isolation/external_function.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, cast from poli.core.abstract_isolated_function import AbstractIsolatedFunction from poli.core.util.inter_process_communication.process_wrapper import ProcessWrapper @@ -45,8 +45,8 @@ def __call__(self, x, context=None): y : np.ndarray The output data points. """ - self.process_wrapper.send(["QUERY", x, context]) - msg_type, *val = self.process_wrapper.recv() + cast(ProcessWrapper, self.process_wrapper).send(["QUERY", x, context]) + msg_type, *val = cast(ProcessWrapper, self.process_wrapper).recv() if msg_type == "EXCEPTION": e, traceback_ = val print(traceback_) @@ -89,8 +89,8 @@ def __getattr__(self, __name: str) -> Any: attribute : Any The attribute of the underlying black-box function. """ - self.process_wrapper.send(["IS_METHOD", __name]) - msg_type, *msg = self.process_wrapper.recv() + cast(ProcessWrapper, self.process_wrapper).send(["IS_METHOD", __name]) + msg_type, *msg = cast(ProcessWrapper, self.process_wrapper).recv() if msg_type == "EXCEPTION": e, traceback_ = msg print(traceback_) @@ -102,8 +102,8 @@ def __getattr__(self, __name: str) -> Any: if is_method: return lambda *args, **kwargs: self._method_call(__name, *args, **kwargs) else: - self.process_wrapper.send(["ATTRIBUTE", __name]) - msg_type, *msg = self.process_wrapper.recv() + cast(ProcessWrapper, self.process_wrapper).send(["ATTRIBUTE", __name]) + msg_type, *msg = cast(ProcessWrapper, self.process_wrapper).recv() if msg_type == "EXCEPTION": e, traceback_ = msg print(traceback_) @@ -125,8 +125,10 @@ def _method_call(self, method_name: str, *args, **kwargs) -> Any: method_name : str The name of the method. """ - self.process_wrapper.send(["METHOD", method_name, args, kwargs]) - msg_type, *msg = self.process_wrapper.recv() + cast(ProcessWrapper, self.process_wrapper).send( + ["METHOD", method_name, args, kwargs] + ) + msg_type, *msg = cast(ProcessWrapper, self.process_wrapper).recv() if msg_type == "EXCEPTION": e, traceback_ = msg print(traceback_) diff --git a/src/poli/core/util/isolation/instancing.py b/src/poli/core/util/isolation/instancing.py index 83ba6cb8..300dc640 100644 --- a/src/poli/core/util/isolation/instancing.py +++ b/src/poli/core/util/isolation/instancing.py @@ -167,7 +167,7 @@ def __run_file_in_env(env_name: str, file_path: Path): def __register_isolated_file( environment_file: Path, isolated_file: Path, - name_for_show: str = None, + name_for_show: str | None = None, quiet: bool = False, ): """ diff --git a/src/poli/core/util/isolation/isolated_black_box.py b/src/poli/core/util/isolation/isolated_black_box.py deleted file mode 100644 index fc055b3d..00000000 --- a/src/poli/core/util/isolation/isolated_black_box.py +++ /dev/null @@ -1,20 +0,0 @@ -from poli.core.abstract_black_box import AbstractBlackBox - - -class IsolatedBlackBox(AbstractBlackBox): - def __init__( - self, - name: str = None, - batch_size: int = None, - parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, - **kwargs_for_black_box, - ): - - super().__init__( - batch_size=batch_size, - parallelize=parallelize, - num_workers=num_workers, - evaluation_budget=evaluation_budget, - ) diff --git a/src/poli/core/util/objective_management/make_run_script.py b/src/poli/core/util/objective_management/make_run_script.py index 94d664dd..bea42ac1 100644 --- a/src/poli/core/util/objective_management/make_run_script.py +++ b/src/poli/core/util/objective_management/make_run_script.py @@ -1,12 +1,15 @@ """This module contains utilities for creating run scripts for problems and observers.""" +from __future__ import annotations + import inspect import os import stat import sys from os.path import basename, dirname, join from pathlib import Path -from typing import List, Type, Union +from types import ModuleType +from typing import Type, Union, cast from poli import external_isolated_function_script from poli.core.abstract_isolated_function import AbstractIsolatedFunction @@ -22,9 +25,14 @@ def make_isolated_function_script( - isolated_function: AbstractIsolatedFunction, - conda_environment_name: Union[str, Path] = None, - python_paths: List[str] = None, + isolated_function: ( + type[AbstractIsolatedFunction] + | type[AbstractObserver] + | AbstractIsolatedFunction + | AbstractObserver + ), + conda_environment_name: Union[str, Path, None] = None, + python_paths: list[str] | None = None, cwd=None, **kwargs, ): @@ -33,12 +41,12 @@ def make_isolated_function_script( Parameters ---------- - black_box : AbstractBlackBox - The black box object to be executed. + isolated_function : AbstractIsolatedFunction + The isolated function object to be executed. conda_environment_name : str or Path, optional - The conda environment to activate before running the black box. - python_paths : List[str], optional - Additional Python paths to be added before running the black box. + The conda environment to activate before running the isolated function. + python_paths : list[str], optional + Additional Python paths to be added before running the isolated function. cwd : str or Path, optional The current working directory for the script execution. @@ -56,8 +64,8 @@ def make_isolated_function_script( def make_observer_script( observer: Type[AbstractObserver], - conda_environment: Union[str, Path] = None, - python_paths: List[str] = None, + conda_environment: Union[str, Path, None] = None, + python_paths: list[str] | None = None, cwd=None, ): """ @@ -69,7 +77,7 @@ def make_observer_script( The observer object to be executed. conda_environment : str or Path, optional The conda environment to activate before running the observer. - python_paths : List[str], optional + python_paths : list[str], optional Additional Python paths to be added before running the observer. cwd : str or Path, optional The current working directory for the script execution. @@ -88,9 +96,14 @@ def make_observer_script( def _make_run_script_from_template( command: str, - non_instantiated_object, - conda_environment_name: Union[str, Path], - python_paths: List[str], + non_instantiated_object: ( + type[AbstractIsolatedFunction] + | type[AbstractObserver] + | AbstractIsolatedFunction + | AbstractObserver + ), + conda_environment_name: Union[str, Path, None], + python_paths: list[str] | None, cwd=None, ): """ @@ -104,7 +117,7 @@ def _make_run_script_from_template( The instantiated object representing the problem factory. conda_environment_name : str or Path The name or path of the conda environment to be used. - python_paths : List[str] + python_paths : list[str] The list of python paths to be appended to the run script. cwd : str, optional The current working directory for the run script. If not provided, the current working directory is used. @@ -120,10 +133,15 @@ def _make_run_script_from_template( cwd = str(os.getcwd()) # class_object = instantiated_object.__class__ - class_object = non_instantiated_object + if isinstance( + non_instantiated_object, (AbstractIsolatedFunction, AbstractObserver) + ): + class_object = non_instantiated_object.__class__ + else: + class_object = non_instantiated_object problem_factory_name = class_object.__name__ # TODO: potential vulnerability? factory_location = inspect.getfile(class_object) - package_name = inspect.getmodule(non_instantiated_object).__name__ + package_name = cast(ModuleType, inspect.getmodule(non_instantiated_object)).__name__ if package_name == "__main__": package_name = basename(factory_location)[:-3] @@ -159,7 +177,7 @@ def _make_run_script_from_template( python_paths = [dirname(factory_location)] # TODO: check that location exists and is valid environment - python_paths = ":".join(python_paths) + python_paths_ = ":".join(python_paths) with open( join(dirname(__file__), "run_script_template.sht"), "r" @@ -171,7 +189,7 @@ def _make_run_script_from_template( run_script = run_script_template_file.read() % ( cwd, conda_environment_name, - python_paths, + python_paths_, ADDITIONAL_IMPORT_SEARCH_PATHES_KEY, command, full_problem_factory_name, diff --git a/src/poli/core/util/observers/mlflow_observer.py b/src/poli/core/util/observers/mlflow_observer.py index 2a31a503..acda9d95 100644 --- a/src/poli/core/util/observers/mlflow_observer.py +++ b/src/poli/core/util/observers/mlflow_observer.py @@ -1,3 +1,5 @@ +# pyright: reportMissingImports=false + from pathlib import Path import mlflow @@ -17,7 +19,7 @@ class MLFlowObserver(AbstractObserver): This observer uses mlflow as a backend. """ - def __init__(self, tracking_uri: Path = None): + def __init__(self, tracking_uri: Path | None = None): self.step = 0 if tracking_uri is not None: mlflow.set_tracking_uri(tracking_uri) diff --git a/src/poli/core/util/proteins/foldx.py b/src/poli/core/util/proteins/foldx.py index 901a0909..8405a17b 100644 --- a/src/poli/core/util/proteins/foldx.py +++ b/src/poli/core/util/proteins/foldx.py @@ -24,12 +24,16 @@ """ +# pyright: reportMissingImports=false + +from __future__ import annotations + import logging import os import shutil import subprocess from pathlib import Path -from typing import List, Union +from typing import Union from Bio.PDB import SASA from Bio.PDB.Residue import Residue @@ -263,7 +267,9 @@ def _repair_if_necessary_and_provide_path(self, pdb_file: Path) -> Path: self.repair(pdb_file) return self.working_dir / f"{pdb_file.stem}_Repair.pdb" - def _simulate_mutations(self, pdb_file: Path, mutations: List[str] = None) -> None: + def _simulate_mutations( + self, pdb_file: Path, mutations: list[str] | None = None + ) -> None: """Simulates mutations, starting from a wildtype PDB file. This method simulates mutations on a PDB file with FoldX. @@ -281,7 +287,7 @@ def _simulate_mutations(self, pdb_file: Path, mutations: List[str] = None) -> No ---------- pdb_file : Path The path to the PDB file to be repaired. - mutations : List[str], optional + mutations : list[str], optional The list of mutations to simulate. If None, we simulate the wildtype. Default is None. @@ -427,7 +433,9 @@ def _compute_sasa(self, pdb_file: Path) -> float: return mutated_structure.sasa - def compute_stability(self, pdb_file: Path, mutations: List[str] = None) -> float: + def compute_stability( + self, pdb_file: Path, mutations: list[str] | None = None + ) -> float: """ Compute the stability of a protein structure using FoldX. @@ -435,7 +443,7 @@ def compute_stability(self, pdb_file: Path, mutations: List[str] = None) -> floa ---------- pdb_file : Path The path to the PDB file of the protein structure. - mutations : List[str], optional + mutations : list[str], optional A list of mutations to be simulated. Only single mutations are supported. Pass no mutations to compute the energy of the wildtype. Returns @@ -462,7 +470,7 @@ def compute_stability(self, pdb_file: Path, mutations: List[str] = None) -> floa stability = -self._read_energy(pdb_file) return stability - def compute_sasa(self, pdb_file: Path, mutations: List[str] = None) -> float: + def compute_sasa(self, pdb_file: Path, mutations: list[str] | None = None) -> float: """ Compute the solvent-accessible surface area (SASA) score for a given protein structure. @@ -470,7 +478,7 @@ def compute_sasa(self, pdb_file: Path, mutations: List[str] = None) -> float: ---------- pdb_file : Path The path to the PDB file of the protein structure. - mutations : List[str], optional + mutations : list[str], optional A list of mutations to be simulated on the protein structure. Only single mutations are supported. Pass no mutations if you want to compute the SASA of the wildtype. @@ -495,7 +503,9 @@ def compute_sasa(self, pdb_file: Path, mutations: List[str] = None) -> float: sasa_score = self._compute_sasa(pdb_file) return sasa_score - def compute_stability_and_sasa(self, pdb_file: Path, mutations: List[str] = None): + def compute_stability_and_sasa( + self, pdb_file: Path, mutations: list[str] | None = None + ): """Computes stability and sasa with a single foldx run, instead of two separate runs. @@ -503,7 +513,7 @@ def compute_stability_and_sasa(self, pdb_file: Path, mutations: List[str] = None ---------- pdb_file : Path The path to the PDB file of the protein structure. - mutations : List[str], optional + mutations : list[str], optional A list of mutations to be simulated on the protein structure. Only single mutations are supported. Pass no mutations if you want to compute the SASA of the wildtype. """ @@ -545,16 +555,16 @@ def copy_foldx_files(self, pdb_file: Path): @staticmethod def write_mutations_to_file( - wildtype_resiudes: List[Residue], mutations: List[str], output_dir: Path + wildtype_resiudes: list[Residue], mutations: list[str], output_dir: Path ) -> None: """Writes the list of mutations to a file in the given directory. Parameters ---------- - wildtype_resiudes : List[Residue] + wildtype_resiudes : list[Residue] The list of wildtype residues. - mutations : List[str] + mutations : list[str] The list of mutations to simulate. output_dir : Path The directory to write the file to. diff --git a/src/poli/core/util/proteins/mutations.py b/src/poli/core/util/proteins/mutations.py index 332be5ce..34be532c 100644 --- a/src/poli/core/util/proteins/mutations.py +++ b/src/poli/core/util/proteins/mutations.py @@ -13,8 +13,10 @@ for more details. """ +# pyright: reportMissingImports=false + from pathlib import Path -from typing import List, Tuple, Union +from typing import Literal, Tuple, Union, overload import numpy as np from Bio.PDB.Residue import Residue @@ -25,7 +27,7 @@ def edits_between_strings( string_1: str, string_2: str, strict: bool = True -) -> List[Tuple[str, int, int]]: +) -> list[Tuple[str, int, int]]: # type: ignore """ Compute the edit operations between two strings. @@ -41,7 +43,7 @@ def edits_between_strings( Returns ------- - List[Tuple[str, int, int]] + list[Tuple[str, int, int]] A list of tuples representing the edit operations between the two strings. Each tuple contains the operation type ("replace"), the position in string_1, and the position in string_2. @@ -66,12 +68,12 @@ def edits_between_strings( ) for i, (a, b) in enumerate(zip(string_1, string_2)): if a != b: - yield ("replace", i, i) + yield ("replace", i, i) # type: ignore def mutations_from_wildtype_residues_and_mutant( - wildtype_residues: List[Residue], mutated_residue_string: str -) -> List[str]: + wildtype_residues: list[Residue], mutated_residue_string: str +) -> list[str]: """Computes the mutations from a wildtype list of residues and a mutated residue string. @@ -102,14 +104,14 @@ def mutations_from_wildtype_residues_and_mutant( Parameters ---------- - wildtype_residues : List[Residue] + wildtype_residues : list[Residue] The list of wildtype residues. mutated_residue_string : str The mutated residue string. Returns ------- - mutations: List[str] + mutations: list[str] The list of mutations in the format foldx expects. """ wildtype_residue_string = "".join( @@ -156,8 +158,24 @@ def mutations_from_wildtype_residues_and_mutant( return mutations_in_line +@overload +def find_closest_wildtype_pdb_file_to_mutant( + wildtype_pdb_files: list[Path], + mutated_residue_string: str, + return_hamming_distance: Literal[False] = False, +) -> Path: ... + + +@overload +def find_closest_wildtype_pdb_file_to_mutant( + wildtype_pdb_files: list[Path], + mutated_residue_string: str, + return_hamming_distance: Literal[True], +) -> Tuple[Path, int]: ... + + def find_closest_wildtype_pdb_file_to_mutant( - wildtype_pdb_files: List[Path], + wildtype_pdb_files: list[Path], mutated_residue_string: str, return_hamming_distance: bool = False, ) -> Union[Path, Tuple[Path, int]]: @@ -166,7 +184,7 @@ def find_closest_wildtype_pdb_file_to_mutant( Parameters ---------- - wildtype_pdb_files : List[Path] + wildtype_pdb_files : list[Path] A list of paths to wildtype PDB files. mutated_residue_string : str The mutated residue string. @@ -220,6 +238,6 @@ def find_closest_wildtype_pdb_file_to_mutant( ) if return_hamming_distance: - return best_candidate_pdb_file, min_hamming_distance + return best_candidate_pdb_file, int(min_hamming_distance) else: return best_candidate_pdb_file diff --git a/src/poli/core/util/proteins/pdb_parsing.py b/src/poli/core/util/proteins/pdb_parsing.py index 369dff60..90094a44 100644 --- a/src/poli/core/util/proteins/pdb_parsing.py +++ b/src/poli/core/util/proteins/pdb_parsing.py @@ -1,7 +1,8 @@ """This module contains utilities for loading PDB files and parsing them.""" +# pyright: reportMissingImports=false + from pathlib import Path -from typing import List from Bio import PDB from Bio.PDB.Residue import Residue @@ -38,7 +39,7 @@ def parse_pdb_as_structure( def parse_pdb_as_residues( path_to_pdb: Path, structure_name: str = "pdb", verbose: bool = False -) -> List[Residue]: +) -> list[Residue]: """ Parse a PDB file and return a list of Residue objects. @@ -53,7 +54,7 @@ def parse_pdb_as_residues( Returns -------- - residues: List[Residue] + residues: list[Residue] A list of Residue objects representing the parsed PDB file. """ structure = parse_pdb_as_structure(path_to_pdb, structure_name, verbose) @@ -62,7 +63,7 @@ def parse_pdb_as_residues( def parse_pdb_as_residue_strings( path_to_pdb: Path, structure_name: str = "pdb", verbose: bool = False -) -> List[str]: +) -> list[str]: """ Parse a PDB file and return a list of residue strings. @@ -77,7 +78,7 @@ def parse_pdb_as_residue_strings( Returns ------- - List[str] + list[str] A list of residue strings. """ residues = parse_pdb_as_residues(path_to_pdb, structure_name, verbose) diff --git a/src/poli/core/util/proteins/rasp/inner_rasp/cavity_model.py b/src/poli/core/util/proteins/rasp/inner_rasp/cavity_model.py index d487e9ca..bc6f1209 100644 --- a/src/poli/core/util/proteins/rasp/inner_rasp/cavity_model.py +++ b/src/poli/core/util/proteins/rasp/inner_rasp/cavity_model.py @@ -71,7 +71,7 @@ class ResidueEnvironmentsDataset(Dataset): Parameters ---------- - input_data: Union[List[str], List[ResidueEnvironment]] + input_data: Union[list[str], list[ResidueEnvironment]] List of parsed pdb filenames in .npz format or list of ResidueEnvironment objects transform: Callable @@ -89,7 +89,7 @@ def __init__( self.res_env_objects = self.parse_envs(input_data) else: raise ValueError( - "Input data is not of type" "Union[List[str], List[ResidueEnvironment]]" + "Input data is not of type" "Union[list[str], list[ResidueEnvironment]]" ) self.transformer = transformer diff --git a/src/poli/core/util/proteins/rasp/load_models.py b/src/poli/core/util/proteins/rasp/load_models.py index 17bfc84e..fefa3a1a 100644 --- a/src/poli/core/util/proteins/rasp/load_models.py +++ b/src/poli/core/util/proteins/rasp/load_models.py @@ -1,5 +1,7 @@ """Utilities for loading up the cavity and downstream models for RaSP.""" +# pyright: reportMissingImports=false + from pathlib import Path from typing import Tuple diff --git a/src/poli/core/util/proteins/rasp/rasp_interface.py b/src/poli/core/util/proteins/rasp/rasp_interface.py index 80f51790..b9789493 100644 --- a/src/poli/core/util/proteins/rasp/rasp_interface.py +++ b/src/poli/core/util/proteins/rasp/rasp_interface.py @@ -37,13 +37,15 @@ 33(suppl_2), W382-W388. """ +# pyright: reportMissingImports=false +# pyright: reportMissingModuleSource=false + import logging import os import stat import subprocess import traceback from pathlib import Path -from typing import List import numpy as np import pandas as pd @@ -478,7 +480,7 @@ def cleaned_to_parsed_pdb( ) def create_df_structure( - self, wildtype_pdb_path: Path, mutant_residue_strings: List[str] = None + self, wildtype_pdb_path: Path, mutant_residue_strings: list[str] | None = None ): """ This function creates a pandas dataframe with the @@ -549,7 +551,7 @@ def create_df_structure( pos_of_variant_column = df_structure.columns.get_loc("variant") for i in range(0, len(df_structure), 20): for j in range(20): - df_structure.iloc[i + j, pos_of_variant_column] = ( + df_structure.iloc[i + j, pos_of_variant_column] = ( # type: ignore df_structure.iloc[i + j, :]["variant"][:-1] + aa_list[j] ) diff --git a/src/poli/core/util/seeding/seeding.py b/src/poli/core/util/seeding/seeding.py index 987a43b4..172b9ce5 100644 --- a/src/poli/core/util/seeding/seeding.py +++ b/src/poli/core/util/seeding/seeding.py @@ -1,11 +1,13 @@ """Utilities for seeding random number generators.""" +# pyright: reportMissingImports=false + import random import numpy as np -def seed_numpy(seed: int = None) -> None: +def seed_numpy(seed: int | None = None) -> None: """ Seed the NumPy random number generator. @@ -18,7 +20,7 @@ def seed_numpy(seed: int = None) -> None: np.random.seed(seed) -def seed_python(seed: int = None) -> None: +def seed_python(seed: int | None = None) -> None: """ Seed the random number generator for Python. @@ -32,7 +34,7 @@ def seed_python(seed: int = None) -> None: random.seed(seed) -def seed_torch(seed: int = None) -> None: +def seed_torch(seed: int | None = None) -> None: """ Seed the random number generator for PyTorch. @@ -53,7 +55,7 @@ def seed_torch(seed: int = None) -> None: torch.cuda.manual_seed_all(seed) -def seed_python_numpy_and_torch(seed: int = None) -> None: +def seed_python_numpy_and_torch(seed: int | None = None) -> None: """ Seed all random number generators. diff --git a/src/poli/objective_factory.py b/src/poli/objective_factory.py index 59b3a4f8..dfc76025 100644 --- a/src/poli/objective_factory.py +++ b/src/poli/objective_factory.py @@ -42,13 +42,13 @@ def load_config(): def __create_problem_from_repository( name: str, - seed: int = None, - batch_size: int = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, - observer: AbstractObserver = None, + observer: AbstractObserver | None = None, **kwargs_for_factory, ) -> Problem: """Creates the objective function from the repository. @@ -114,14 +114,14 @@ def __create_problem_from_repository( def create( name: str, *, - seed: int = None, - observer_init_info: dict = None, - observer_name: str = None, + seed: int | None = None, + observer_init_info: dict[str, object] | None = None, + observer_name: str | None = None, force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, quiet: bool = False, **kwargs_for_factory, ) -> Problem: @@ -189,9 +189,9 @@ def create( def start( name: str, - seed: int = None, - caller_info: dict = None, - observer_name: str = None, + seed: int | None = None, + caller_info: dict | None = None, + observer_name: str | None = None, force_isolation: bool = False, **kwargs_for_factory, ) -> AbstractBlackBox: @@ -246,7 +246,9 @@ def start( return f -def _instantiate_observer(observer_name: str, quiet: bool = False) -> AbstractObserver: +def _instantiate_observer( + observer_name: str | None, quiet: bool = False +) -> AbstractObserver: """ This function attempts to locally instantiate an observer and if that fails starts the observer in the dedicated environment. diff --git a/src/poli/objective_repository/albuterol_similarity/register.py b/src/poli/objective_repository/albuterol_similarity/register.py index c8f5718c..d9349e38 100644 --- a/src/poli/objective_repository/albuterol_similarity/register.py +++ b/src/poli/objective_repository/albuterol_similarity/register.py @@ -88,12 +88,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Albuterol_Similarity", @@ -147,12 +147,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/aloha/register.py b/src/poli/objective_repository/aloha/register.py index ce30334f..724d412f 100644 --- a/src/poli/objective_repository/aloha/register.py +++ b/src/poli/objective_repository/aloha/register.py @@ -55,10 +55,10 @@ class AlohaBlackBox(AbstractBlackBox): def __init__( self, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): """ Initialize the aloha black box object. @@ -95,7 +95,7 @@ def get_black_box_info(self) -> BlackBoxInformation: ) # The only method you have to define - def _black_box(self, x: np.ndarray, context: dict = None) -> np.ndarray: + def _black_box(self, x: np.ndarray, context: dict | None = None) -> np.ndarray: """ Compute the distance of x to the sequence "ALOHA". @@ -143,11 +143,11 @@ class AlohaProblemFactory(AbstractProblemFactory): def create( self, - seed: int = None, - batch_size: int = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/amlodipine_mpo/register.py b/src/poli/objective_repository/amlodipine_mpo/register.py index 6ba0cb89..e690956d 100644 --- a/src/poli/objective_repository/amlodipine_mpo/register.py +++ b/src/poli/objective_repository/amlodipine_mpo/register.py @@ -83,12 +83,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Amlodipine_MPO", @@ -144,12 +144,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/celecoxib_rediscovery/register.py b/src/poli/objective_repository/celecoxib_rediscovery/register.py index 03a29764..fb1066cf 100644 --- a/src/poli/objective_repository/celecoxib_rediscovery/register.py +++ b/src/poli/objective_repository/celecoxib_rediscovery/register.py @@ -85,12 +85,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Celecoxib_Rediscovery", @@ -145,12 +145,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/deco_hop/register.py b/src/poli/objective_repository/deco_hop/register.py index 0d806d0d..51f83aed 100644 --- a/src/poli/objective_repository/deco_hop/register.py +++ b/src/poli/objective_repository/deco_hop/register.py @@ -80,12 +80,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Deco Hop", @@ -139,12 +139,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/dms_gb1/isolated_function.py b/src/poli/objective_repository/dms_gb1/isolated_function.py index f2fa5c9a..d15e7c9d 100644 --- a/src/poli/objective_repository/dms_gb1/isolated_function.py +++ b/src/poli/objective_repository/dms_gb1/isolated_function.py @@ -18,6 +18,9 @@ """ +# pyright: reportMissingImports=false +# pyright: reportMissingModuleSource=false + from __future__ import annotations from pathlib import Path @@ -38,7 +41,7 @@ class DMSGB1IsolatedLogic(AbstractIsolatedFunction): Parameters ---------- - alphabet : List[str], optional + alphabet : list[str], optional The alphabet for the problem, by default we use the amino acid list provided in poli.core.util.proteins.defaults. experiment_id : str, optional @@ -59,7 +62,7 @@ class DMSGB1IsolatedLogic(AbstractIsolatedFunction): def __init__( self, - experiment_id: str = None, + experiment_id: str | None = None, ): """ Initialize the GB1 Register object. diff --git a/src/poli/objective_repository/dms_gb1/register.py b/src/poli/objective_repository/dms_gb1/register.py index 968242a9..fcb72884 100644 --- a/src/poli/objective_repository/dms_gb1/register.py +++ b/src/poli/objective_repository/dms_gb1/register.py @@ -67,11 +67,11 @@ class DMSGB1BlackBox(AbstractBlackBox): def __init__( self, negative: bool = False, - experiment_id: str = None, - batch_size: int = None, + experiment_id: str | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): """ @@ -170,12 +170,12 @@ class DMSGB1ProblemFactory(AbstractProblemFactory): def create( self, negative: bool = False, - experiment_id: str = None, - seed: int = None, - batch_size: int = None, + experiment_id: str | None = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/dms_trpb/isolated_function.py b/src/poli/objective_repository/dms_trpb/isolated_function.py index fdf30b0a..1c477157 100644 --- a/src/poli/objective_repository/dms_trpb/isolated_function.py +++ b/src/poli/objective_repository/dms_trpb/isolated_function.py @@ -13,6 +13,9 @@ """ +# pyright: reportMissingImports=false +# pyright: reportMissingModuleSource=false + from __future__ import annotations from pathlib import Path @@ -51,7 +54,7 @@ class DMSTrpBIsolatedLogic(AbstractIsolatedFunction): def __init__( self, - experiment_id: str = None, + experiment_id: str | None = None, ): """ Initialize the GB1 Register object. diff --git a/src/poli/objective_repository/dms_trpb/register.py b/src/poli/objective_repository/dms_trpb/register.py index b8d23b31..4c8189fc 100644 --- a/src/poli/objective_repository/dms_trpb/register.py +++ b/src/poli/objective_repository/dms_trpb/register.py @@ -63,11 +63,11 @@ class DMSTrpBBlackBox(AbstractBlackBox): def __init__( self, negative: bool = False, - experiment_id: str = None, - batch_size: int = None, + experiment_id: str | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): """ @@ -166,12 +166,12 @@ class DMSTrpBProblemFactory(AbstractProblemFactory): def create( self, negative: bool = False, - experiment_id: str = None, - seed: int = None, - batch_size: int = None, + experiment_id: str | None = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/dockstring/isolated_function.py b/src/poli/objective_repository/dockstring/isolated_function.py index 93ed2ef5..d33c7030 100644 --- a/src/poli/objective_repository/dockstring/isolated_function.py +++ b/src/poli/objective_repository/dockstring/isolated_function.py @@ -1,3 +1,6 @@ +# pyright: reportMissingImports=false +# pyright: reportAttributeAccessIssue=false + from typing import Literal import numpy as np diff --git a/src/poli/objective_repository/dockstring/register.py b/src/poli/objective_repository/dockstring/register.py index e537b119..946038de 100644 --- a/src/poli/objective_repository/dockstring/register.py +++ b/src/poli/objective_repository/dockstring/register.py @@ -75,10 +75,10 @@ def __init__( self, target_name: str, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): """ @@ -199,11 +199,11 @@ def create( self, target_name: str, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", - seed: int = None, - batch_size: int = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """Creates a dockstring black box function and initial observations. diff --git a/src/poli/objective_repository/drd2_docking/register.py b/src/poli/objective_repository/drd2_docking/register.py index 8854879d..d186a4f8 100644 --- a/src/poli/objective_repository/drd2_docking/register.py +++ b/src/poli/objective_repository/drd2_docking/register.py @@ -85,12 +85,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="DRD2", @@ -145,12 +145,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/drd3_docking/register.py b/src/poli/objective_repository/drd3_docking/register.py index eb849ecb..12c3258d 100644 --- a/src/poli/objective_repository/drd3_docking/register.py +++ b/src/poli/objective_repository/drd3_docking/register.py @@ -64,12 +64,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="3pbl_docking", @@ -113,12 +113,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/ehrlich/register.py b/src/poli/objective_repository/ehrlich/register.py index 5c64f8fa..98d87011 100644 --- a/src/poli/objective_repository/ehrlich/register.py +++ b/src/poli/objective_repository/ehrlich/register.py @@ -107,15 +107,15 @@ def __init__( motif_length: int, n_motifs: int, quantization: int | None = None, - seed: int = None, + seed: int | None = None, return_value_on_unfeasible: float = -np.inf, feasibility_matrix_temperature: float = 0.5, feasibility_matrix_band_length: int | None = None, alphabet: list[str] = AMINO_ACIDS, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): warnings.warn( "This EhrlichBlackBox class is different from the original " @@ -235,7 +235,7 @@ def _is_feasible(self, sequence: str | np.ndarray) -> bool: return True def construct_random_motifs( - self, motif_length: int, n_motifs: int, seed: int = None + self, motif_length: int, n_motifs: int, seed: int | None = None ) -> np.ndarray: """ Creates a given number of random motifs of a certain length. @@ -265,7 +265,7 @@ def construct_random_offsets( self, motif_length: int, n_motifs: int, - seed: int = None, + seed: int | None = None, ) -> np.ndarray: """ Creates a given number of random offsets for the motifs. @@ -338,12 +338,12 @@ def _maximal_motif_matches( Counts the maximal motif match. """ assert sequence.ndim == 1 or sequence.shape[0] == 1 - sequence = "".join(sequence.flatten()) + sequence_ = "".join(sequence.flatten()) maximal_match = 0 - for seq_idx in range(len(sequence) - max(offset)): + for seq_idx in range(len(sequence_) - max(offset)): matches = 0 sequence_at_offset = np.array( - [sequence[seq_idx + offset_value] for offset_value in offset] + [sequence_[seq_idx + offset_value] for offset_value in offset] ) matches = sum(sequence_at_offset == motif) @@ -403,13 +403,13 @@ def create( motif_length: int, n_motifs: int, quantization: int | None = None, - seed: int = None, + seed: int | None = None, return_value_on_unfeasible: float = -np.inf, alphabet: list[str] = AMINO_ACIDS, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/ehrlich_holo/isolated_function.py b/src/poli/objective_repository/ehrlich_holo/isolated_function.py index 839e6cd5..679b0f22 100644 --- a/src/poli/objective_repository/ehrlich_holo/isolated_function.py +++ b/src/poli/objective_repository/ehrlich_holo/isolated_function.py @@ -2,6 +2,8 @@ The isolation entry-point for Ehrlich, as implemented in Holo. """ +# pyright: reportMissingImports=false + from __future__ import annotations import numpy as np @@ -31,8 +33,8 @@ def __init__( return_value_on_unfeasible: float = -np.inf, alphabet: list[str] = AMINO_ACIDS, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): self.sequence_length = sequence_length self.motif_length = motif_length diff --git a/src/poli/objective_repository/ehrlich_holo/register.py b/src/poli/objective_repository/ehrlich_holo/register.py index 4531bb7a..e018b0ea 100644 --- a/src/poli/objective_repository/ehrlich_holo/register.py +++ b/src/poli/objective_repository/ehrlich_holo/register.py @@ -87,14 +87,14 @@ def __init__( n_motifs: int, quantization: int | None = None, noise_std: float = 0.0, - seed: int = None, + seed: int | None = None, epistasis_factor: float = 0.0, return_value_on_unfeasible: float = -np.inf, alphabet: list[str] = AMINO_ACIDS, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): super().__init__(batch_size, parallelize, num_workers, evaluation_budget) @@ -215,14 +215,14 @@ def create( n_motifs: int, quantization: int | None = None, noise_std: float = 0.0, - seed: int = None, + seed: int | None = None, epistasis_factor: float = 0.0, return_value_on_unfeasible: float = -np.inf, alphabet: list[str] = AMINO_ACIDS, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/fexofenadine_mpo/register.py b/src/poli/objective_repository/fexofenadine_mpo/register.py index 5d1cdfa4..9ef84373 100644 --- a/src/poli/objective_repository/fexofenadine_mpo/register.py +++ b/src/poli/objective_repository/fexofenadine_mpo/register.py @@ -83,12 +83,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Fexofenadine_MPO", @@ -142,12 +142,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/foldx_rfp_lambo/isolated_function.py b/src/poli/objective_repository/foldx_rfp_lambo/isolated_function.py index 3279b2ca..b3d47701 100644 --- a/src/poli/objective_repository/foldx_rfp_lambo/isolated_function.py +++ b/src/poli/objective_repository/foldx_rfp_lambo/isolated_function.py @@ -1,5 +1,7 @@ """RFP objective factory and black box function.""" +# pyright: reportMissingImports=false + __author__ = "Simon Bartels" import logging @@ -172,7 +174,7 @@ def _download_assets_from_lambo(): class RFPWrapperIsolatedLogic(AbstractIsolatedFunction): def __init__( self, - seed: int = None, + seed: int | None = None, ): self.alphabet = AMINO_ACIDS self.problem_sequence = PROBLEM_SEQ @@ -220,7 +222,7 @@ def __init__( def __call__(self, x, context=None): best_b_cand = None - min_hd = np.infty # Hamming distance of best_b_cand to x + min_hd = np.inf # Hamming distance of best_b_cand to x # TODO: this assumes a batch_size of 1. Is that clear in the docs? seq = "".join(x[0]) # take out the string from the np array diff --git a/src/poli/objective_repository/foldx_rfp_lambo/register.py b/src/poli/objective_repository/foldx_rfp_lambo/register.py index 1d45c362..8c240329 100644 --- a/src/poli/objective_repository/foldx_rfp_lambo/register.py +++ b/src/poli/objective_repository/foldx_rfp_lambo/register.py @@ -18,11 +18,11 @@ class FoldXRFPLamboBlackBox(AbstractBlackBox): def __init__( self, - seed: int = None, + seed: int | None = None, parallelize: bool = False, - num_workers: int = None, - batch_size: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + batch_size: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): super().__init__( @@ -79,11 +79,11 @@ def __init__(self): def create( self, - seed: int = None, - batch_size: int = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/foldx_sasa/isolated_function.py b/src/poli/objective_repository/foldx_sasa/isolated_function.py index 0cf0a5d4..d4b71e0d 100644 --- a/src/poli/objective_repository/foldx_sasa/isolated_function.py +++ b/src/poli/objective_repository/foldx_sasa/isolated_function.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Union +from typing import Union import numpy as np @@ -15,9 +15,9 @@ class FoldXSASAIsolatedLogic(FoldxIsolatedFunction): Parameters ----------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] The path(s) to the wildtype PDB file(s). Default is None. - alphabet : List[str], optional + alphabet : list[str], optional The alphabet of amino acids. Default is None. experiment_id : str, optional The ID of the experiment. Default is None. @@ -37,9 +37,9 @@ class FoldXSASAIsolatedLogic(FoldxIsolatedFunction): def __init__( self, - wildtype_pdb_path: Union[Path, List[Path]], - experiment_id: str = None, - tmp_folder: Path = None, + wildtype_pdb_path: Union[Path, list[Path]], + experiment_id: str | None = None, + tmp_folder: Path | None = None, eager_repair: bool = False, verbose: bool = False, ): diff --git a/src/poli/objective_repository/foldx_sasa/register.py b/src/poli/objective_repository/foldx_sasa/register.py index b4d2cd69..ae0737ec 100644 --- a/src/poli/objective_repository/foldx_sasa/register.py +++ b/src/poli/objective_repository/foldx_sasa/register.py @@ -15,7 +15,7 @@ """ from pathlib import Path -from typing import List, Union +from typing import Union, cast import numpy as np @@ -35,7 +35,7 @@ class FoldXSASABlackBox(AbstractBlackBox): Parameters ----------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] The path(s) to the wildtype PDB file(s). Default is None. experiment_id : str, optional The ID of the experiment. Default is None. @@ -63,15 +63,15 @@ class FoldXSASABlackBox(AbstractBlackBox): def __init__( self, - wildtype_pdb_path: Union[Path, List[Path]], - experiment_id: str = None, - tmp_folder: Path = None, + wildtype_pdb_path: Union[Path, list[Path]], + experiment_id: str | None = None, + tmp_folder: Path | None = None, eager_repair: bool = False, verbose: bool = False, batch_size: int = 1, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): super().__init__( @@ -165,16 +165,16 @@ class FoldXSASAProblemFactory(AbstractProblemFactory): def create( self, - wildtype_pdb_path: Union[Path, List[Path]], - experiment_id: str = None, - tmp_folder: Path = None, + wildtype_pdb_path: Union[Path, list[Path]], + experiment_id: str | None = None, + tmp_folder: Path | None = None, eager_repair: bool = False, verbose: bool = False, - seed: int = None, + seed: int | None = None, batch_size: int = 1, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ @@ -182,7 +182,7 @@ def create( Parameters ---------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] Path or list of paths to the wildtype PDB files. experiment_id : str, optional Identifier for the experiment. @@ -232,7 +232,9 @@ def create( wildtype_pdb_path = [wildtype_pdb_path] elif isinstance(wildtype_pdb_path, list): if isinstance(wildtype_pdb_path[0], str): - wildtype_pdb_path = [Path(x.strip()) for x in wildtype_pdb_path] + wildtype_pdb_path = [ + Path(cast(str, x).strip()) for x in wildtype_pdb_path + ] elif isinstance(wildtype_pdb_path[0], Path): pass else: diff --git a/src/poli/objective_repository/foldx_stability/isolated_function.py b/src/poli/objective_repository/foldx_stability/isolated_function.py index 03f2d8c4..c8993f93 100644 --- a/src/poli/objective_repository/foldx_stability/isolated_function.py +++ b/src/poli/objective_repository/foldx_stability/isolated_function.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Union +from typing import Union import numpy as np @@ -16,9 +16,9 @@ class FoldXStabilityIsolatedLogic(FoldxIsolatedFunction): Parameters ----------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] The path(s) to the wildtype PDB file(s). Default is None. - alphabet : List[str], optional + alphabet : list[str], optional The alphabet of amino acids. Default is None. experiment_id : str, optional The ID of the experiment. Default is None. @@ -38,9 +38,9 @@ class FoldXStabilityIsolatedLogic(FoldxIsolatedFunction): def __init__( self, - wildtype_pdb_path: Union[Path, List[Path]], - experiment_id: str = None, - tmp_folder: Path = None, + wildtype_pdb_path: Union[Path, list[Path]], + experiment_id: str | None = None, + tmp_folder: Path | None = None, eager_repair: bool = False, verbose: bool = False, ): diff --git a/src/poli/objective_repository/foldx_stability/register.py b/src/poli/objective_repository/foldx_stability/register.py index 972cbe7b..b1e4489c 100644 --- a/src/poli/objective_repository/foldx_stability/register.py +++ b/src/poli/objective_repository/foldx_stability/register.py @@ -17,7 +17,7 @@ from __future__ import annotations from pathlib import Path -from typing import List, Union +from typing import Union, cast import numpy as np @@ -37,7 +37,7 @@ class FoldXStabilityBlackBox(AbstractBlackBox): Parameters ---------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] The path(s) to the wildtype PDB file(s). experiment_id : str, optional The ID of the experiment (default is None). @@ -70,15 +70,15 @@ class FoldXStabilityBlackBox(AbstractBlackBox): def __init__( self, - wildtype_pdb_path: Union[Path, List[Path]], - experiment_id: str = None, - tmp_folder: Path = None, + wildtype_pdb_path: Union[Path, list[Path]], + experiment_id: str | None = None, + tmp_folder: Path | None = None, eager_repair: bool = False, verbose: bool = False, batch_size: int = 1, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): super().__init__( @@ -113,7 +113,7 @@ def __init__( self.x0 = inner_function.x0 self.wildtype_amino_acids = inner_function.wildtype_amino_acids - def _black_box(self, x: np.ndarray, context: None) -> np.ndarray: + def _black_box(self, x: np.ndarray, context: None = None) -> np.ndarray: """ Runs the given input x and pdb files provided in the context through FoldX and returns the @@ -168,16 +168,16 @@ def get_black_box_info(self) -> BlackBoxInformation: class FoldXStabilityProblemFactory(AbstractProblemFactory): def create( self, - wildtype_pdb_path: Union[Path, List[Path]], - experiment_id: str = None, - tmp_folder: Path = None, + wildtype_pdb_path: Union[Path, list[Path]], + experiment_id: str | None = None, + tmp_folder: Path | None = None, eager_repair: bool = False, verbose: bool = False, - seed: int = None, + seed: int | None = None, batch_size: int = 1, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ @@ -185,9 +185,9 @@ def create( Parameters ---------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] Path(s) to the wildtype PDB file(s). - alphabet : List[str], optional + alphabet : list[str], optional List of amino acids to use as the alphabet. experiment_id : str, optional Identifier for the experiment. @@ -233,7 +233,9 @@ def create( wildtype_pdb_path = [wildtype_pdb_path] elif isinstance(wildtype_pdb_path, list): if isinstance(wildtype_pdb_path[0], str): - wildtype_pdb_path = [Path(x.strip()) for x in wildtype_pdb_path] + wildtype_pdb_path = [ + Path(cast(str, x).strip()) for x in wildtype_pdb_path + ] elif isinstance(wildtype_pdb_path[0], Path): pass else: diff --git a/src/poli/objective_repository/foldx_stability_and_sasa/isolated_function.py b/src/poli/objective_repository/foldx_stability_and_sasa/isolated_function.py index 36f5fbd0..bbcd1031 100644 --- a/src/poli/objective_repository/foldx_stability_and_sasa/isolated_function.py +++ b/src/poli/objective_repository/foldx_stability_and_sasa/isolated_function.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Union +from typing import Union import numpy as np @@ -15,9 +15,9 @@ class FoldXStabilitityAndSASAIsolatedLogic(FoldxIsolatedFunction): Parameters ----------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] The path(s) to the wildtype PDB file(s). Default is None. - alphabet : List[str], optional + alphabet : list[str], optional The alphabet of amino acids. Default is None. experiment_id : str, optional The ID of the experiment. Default is None. @@ -37,9 +37,9 @@ class FoldXStabilitityAndSASAIsolatedLogic(FoldxIsolatedFunction): def __init__( self, - wildtype_pdb_path: Union[Path, List[Path]], - experiment_id: str = None, - tmp_folder: Path = None, + wildtype_pdb_path: Union[Path, list[Path]], + experiment_id: str | None = None, + tmp_folder: Path | None = None, eager_repair: bool = False, verbose: bool = False, ): diff --git a/src/poli/objective_repository/foldx_stability_and_sasa/register.py b/src/poli/objective_repository/foldx_stability_and_sasa/register.py index a7f39bf9..13aec70c 100644 --- a/src/poli/objective_repository/foldx_stability_and_sasa/register.py +++ b/src/poli/objective_repository/foldx_stability_and_sasa/register.py @@ -17,7 +17,7 @@ """ from pathlib import Path -from typing import List, Union +from typing import Union, cast import numpy as np @@ -37,7 +37,7 @@ class FoldXStabilityAndSASABlackBox(AbstractBlackBox): Parameters ----------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] The path(s) to the wildtype PDB file(s). experiment_id : str, optional The ID of the experiment. Default is None. @@ -65,15 +65,15 @@ class FoldXStabilityAndSASABlackBox(AbstractBlackBox): def __init__( self, - wildtype_pdb_path: Union[Path, List[Path]], - experiment_id: str = None, - tmp_folder: Path = None, + wildtype_pdb_path: Union[Path, list[Path]], + experiment_id: str | None = None, + tmp_folder: Path | None = None, eager_repair: bool = False, verbose: bool = False, batch_size: int = 1, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): super().__init__( @@ -104,7 +104,7 @@ def __init__( ) self.wildtype_amino_acids = inner_function.wildtype_amino_acids - def _black_box(self, x: np.ndarray, context: None) -> np.ndarray: + def _black_box(self, x: np.ndarray, context: None = None) -> np.ndarray: """ Runs the given input x and pdb files provided in the context through FoldX and returns the @@ -166,16 +166,16 @@ class FoldXStabilityAndSASAProblemFactory(AbstractProblemFactory): def create( self, - wildtype_pdb_path: Union[Path, List[Path]], - experiment_id: str = None, - tmp_folder: Path = None, + wildtype_pdb_path: Union[Path, list[Path]], + experiment_id: str | None = None, + tmp_folder: Path | None = None, eager_repair: bool = False, verbose: bool = False, - seed: int = None, - batch_size: int = None, + seed: int | None = None, + batch_size: int = 1, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ @@ -183,7 +183,7 @@ def create( Parameters ---------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] Path or list of paths to the wildtype PDB files. experiment_id : str, optional Identifier for the experiment. @@ -231,7 +231,9 @@ def create( wildtype_pdb_path = [wildtype_pdb_path] elif isinstance(wildtype_pdb_path, list): if isinstance(wildtype_pdb_path[0], str): - wildtype_pdb_path = [Path(x.strip()) for x in wildtype_pdb_path] + wildtype_pdb_path = [ + Path(cast(str, x).strip()) for x in wildtype_pdb_path + ] elif isinstance(wildtype_pdb_path[0], Path): pass else: diff --git a/src/poli/objective_repository/gfp_cbas/isolated_function.py b/src/poli/objective_repository/gfp_cbas/isolated_function.py index dc0c7592..e92624d1 100644 --- a/src/poli/objective_repository/gfp_cbas/isolated_function.py +++ b/src/poli/objective_repository/gfp_cbas/isolated_function.py @@ -1,3 +1,5 @@ +# pyright: reportMissingImports=false +# pyright: reportMissingModuleSource=false from pathlib import Path from typing import Literal from warnings import warn @@ -5,6 +7,7 @@ import numpy as np import pandas as pd import torch +from numpy.typing import NDArray from poli.core.abstract_isolated_function import AbstractIsolatedFunction from poli.core.black_box_information import BlackBoxInformation @@ -22,7 +25,7 @@ def __init__( problem_type: Literal["gp", "vae", "elbo"], info: BlackBoxInformation, n_starting_points: int = 1, - seed: int = None, + seed: int | None = None, functional_only: bool = False, ignore_stops: bool = True, unique=True, @@ -134,7 +137,7 @@ def _vae_embedding(self, x: np.ndarray) -> np.ndarray: oh_x = one_hot_encode_aa_array(x) return self.model.predict(oh_x)[0] - def __call__(self, x: np.array, context=None) -> np.ndarray: + def __call__(self, x: NDArray[np.str_], context=None) -> np.ndarray: """ x is encoded sequence return function value given problem name """ diff --git a/src/poli/objective_repository/gfp_cbas/register.py b/src/poli/objective_repository/gfp_cbas/register.py index 6e113c06..4e3897af 100644 --- a/src/poli/objective_repository/gfp_cbas/register.py +++ b/src/poli/objective_repository/gfp_cbas/register.py @@ -2,6 +2,7 @@ from warnings import warn import numpy as np +from numpy.typing import NDArray from poli.core.abstract_black_box import AbstractBlackBox from poli.core.abstract_problem_factory import AbstractProblemFactory @@ -20,11 +21,11 @@ def __init__( ignore_stops: bool = True, unique=True, n_starting_points: int = 1, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - seed: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + seed: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, negate: bool = False, ): @@ -59,7 +60,7 @@ def __init__( ) self.x0 = inner_function.x0 - def _black_box(self, x: np.array, context=None) -> np.ndarray: + def _black_box(self, x: NDArray[np.str_], context=None) -> np.ndarray: """ x is encoded sequence return function value given problem name """ @@ -114,11 +115,11 @@ def create( n_starting_points: int = 1, functional_only: bool = False, unique: bool = True, - seed: int = None, - batch_size: int = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, negate: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/gfp_select/isolated_function.py b/src/poli/objective_repository/gfp_select/isolated_function.py index e9adfe6e..9aed1917 100644 --- a/src/poli/objective_repository/gfp_select/isolated_function.py +++ b/src/poli/objective_repository/gfp_select/isolated_function.py @@ -1,7 +1,10 @@ +# pyright: reportMissingImports=false +# pyright: reportMissingModuleSource=false from pathlib import Path import numpy as np import pandas as pd +from numpy.typing import NDArray from poli.core.abstract_isolated_function import AbstractIsolatedFunction @@ -9,7 +12,7 @@ class GFPSelectIsolatedLogic(AbstractIsolatedFunction): def __init__( self, - seed: int = None, + seed: int | None = None, ): gfp_df_path = Path(__file__).parent.resolve() / "assets" / "gfp_data.csv" self.seed = seed @@ -25,15 +28,15 @@ def __init__( self.x0 = x0 - def __call__(self, x: np.array, context=None) -> np.ndarray: + def __call__(self, x: NDArray[np.str_], context=None) -> np.ndarray: """ x is string sequence which we look-up in avilable df, return median Brightness """ if isinstance(x, np.ndarray): _arr = x.copy() - x = ["".join(_seq) for _seq in _arr] + x_ = ["".join(_seq) for _seq in _arr] ys = [] - for _x in x: + for _x in x_: seq_subsets = self.gfp_lookup_df[ self.gfp_lookup_df.aaSequence.str.lower() == _x.lower() ] diff --git a/src/poli/objective_repository/gfp_select/register.py b/src/poli/objective_repository/gfp_select/register.py index b5dbcecf..35bef8ad 100644 --- a/src/poli/objective_repository/gfp_select/register.py +++ b/src/poli/objective_repository/gfp_select/register.py @@ -1,4 +1,5 @@ import numpy as np +from numpy.typing import NDArray from poli.core.abstract_black_box import AbstractBlackBox from poli.core.abstract_problem_factory import AbstractProblemFactory @@ -12,11 +13,11 @@ class GFPSelectionBlackBox(AbstractBlackBox): def __init__( self, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, - seed: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, + seed: int | None = None, force_isolation: bool = False, ): super().__init__( @@ -41,7 +42,7 @@ def __init__( name="gfp_select__isolated", seed=seed ) - def _black_box(self, x: np.array, context=None) -> np.ndarray: + def _black_box(self, x: NDArray[np.str_], context=None) -> np.ndarray: """ x is string sequence which we look-up in avilable df, return median Brightness """ @@ -66,11 +67,11 @@ class GFPSelectionProblemFactory(AbstractProblemFactory): def create( self, - seed: int = None, - batch_size: int = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: if seed is not None: diff --git a/src/poli/objective_repository/gsk3_beta/register.py b/src/poli/objective_repository/gsk3_beta/register.py index 9856b250..07804aa6 100644 --- a/src/poli/objective_repository/gsk3_beta/register.py +++ b/src/poli/objective_repository/gsk3_beta/register.py @@ -97,12 +97,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="GSK3B", @@ -162,12 +162,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/isomer_c7h8n2o2/register.py b/src/poli/objective_repository/isomer_c7h8n2o2/register.py index 2d823d84..d5ee0232 100644 --- a/src/poli/objective_repository/isomer_c7h8n2o2/register.py +++ b/src/poli/objective_repository/isomer_c7h8n2o2/register.py @@ -83,12 +83,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Isomers_C7H8N2O2", @@ -142,12 +142,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/isomer_c9h10n2o2pf2cl/register.py b/src/poli/objective_repository/isomer_c9h10n2o2pf2cl/register.py index 9f35ebb1..50633fa8 100644 --- a/src/poli/objective_repository/isomer_c9h10n2o2pf2cl/register.py +++ b/src/poli/objective_repository/isomer_c9h10n2o2pf2cl/register.py @@ -85,12 +85,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Isomers_C9H10N2O2PF2Cl", @@ -145,12 +145,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/jnk3/register.py b/src/poli/objective_repository/jnk3/register.py index fe5e352d..af3f2bc0 100644 --- a/src/poli/objective_repository/jnk3/register.py +++ b/src/poli/objective_repository/jnk3/register.py @@ -94,12 +94,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="JNK3", @@ -158,12 +158,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/median_1/register.py b/src/poli/objective_repository/median_1/register.py index 3260c392..9b6c4027 100644 --- a/src/poli/objective_repository/median_1/register.py +++ b/src/poli/objective_repository/median_1/register.py @@ -82,12 +82,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Median 1", @@ -141,12 +141,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/median_2/register.py b/src/poli/objective_repository/median_2/register.py index f857236b..8921f06c 100644 --- a/src/poli/objective_repository/median_2/register.py +++ b/src/poli/objective_repository/median_2/register.py @@ -81,12 +81,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Median 2", @@ -140,12 +140,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/mestranol_similarity/register.py b/src/poli/objective_repository/mestranol_similarity/register.py index d366d108..0bbc07c2 100644 --- a/src/poli/objective_repository/mestranol_similarity/register.py +++ b/src/poli/objective_repository/mestranol_similarity/register.py @@ -86,12 +86,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Mestranol_Similarity", @@ -145,12 +145,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/osimetrinib_mpo/register.py b/src/poli/objective_repository/osimetrinib_mpo/register.py index 45b3c79e..42bfe9f0 100644 --- a/src/poli/objective_repository/osimetrinib_mpo/register.py +++ b/src/poli/objective_repository/osimetrinib_mpo/register.py @@ -82,12 +82,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Osimetrinib_MPO", @@ -141,12 +141,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/penalized_logp_lambo/isolated_function.py b/src/poli/objective_repository/penalized_logp_lambo/isolated_function.py index 020d768b..bcf40fc2 100644 --- a/src/poli/objective_repository/penalized_logp_lambo/isolated_function.py +++ b/src/poli/objective_repository/penalized_logp_lambo/isolated_function.py @@ -11,6 +11,8 @@ arXiv, July 12, 2022. http://arxiv.org/abs/2203.12742. """ +# pyright: reportMissingImports=false + import logging import os from pathlib import Path @@ -75,7 +77,7 @@ def __init__( self.penalized = penalized _download_assets_from_lambo() - def __call__(self, x: np.ndarray, context: dict = None): + def __call__(self, x: np.ndarray, context: dict | None = None): """ Assuming that x is an array of strings (of shape [b,L]), we concatenate, translate to smiles if it's diff --git a/src/poli/objective_repository/penalized_logp_lambo/register.py b/src/poli/objective_repository/penalized_logp_lambo/register.py index ed0d545d..dc51ae1a 100644 --- a/src/poli/objective_repository/penalized_logp_lambo/register.py +++ b/src/poli/objective_repository/penalized_logp_lambo/register.py @@ -11,7 +11,7 @@ arXiv, July 12, 2022. http://arxiv.org/abs/2203.12742. """ -from typing import Literal, Tuple +from typing import Literal import numpy as np @@ -38,10 +38,10 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", penalized: bool = True, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): super().__init__( @@ -66,7 +66,7 @@ def __init__( penalized=penalized, ) - def _black_box(self, x: np.ndarray, context: dict = None): + def _black_box(self, x: np.ndarray, context: dict | None = None): """ Assuming that x is an array of strings (of shape [b,L]), we concatenate, translate to smiles if it's @@ -105,14 +105,14 @@ class PenalizedLogPLamboProblemFactory(AbstractProblemFactory): def create( self, penalized: bool = True, - string_representation: str = "SMILES", - seed: int = None, - batch_size: int = None, + string_representation: Literal["SMILES", "SELFIES"] = "SMILES", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, - ) -> Tuple[AbstractBlackBox, np.ndarray, np.ndarray]: + ) -> Problem: if seed is not None: seed_python_numpy_and_torch(seed) @@ -123,7 +123,7 @@ def create( ) f = PenalizedLogPLamboBlackBox( - string_representation=string_representation.upper(), + string_representation=string_representation, penalized=penalized, batch_size=batch_size, parallelize=parallelize, diff --git a/src/poli/objective_repository/perindopril_mpo/register.py b/src/poli/objective_repository/perindopril_mpo/register.py index cd3ee4d1..13c5aa79 100644 --- a/src/poli/objective_repository/perindopril_mpo/register.py +++ b/src/poli/objective_repository/perindopril_mpo/register.py @@ -81,12 +81,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Perindopril_MPO", @@ -140,12 +140,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/ranolazine_mpo/register.py b/src/poli/objective_repository/ranolazine_mpo/register.py index 450ef325..0983a9bc 100644 --- a/src/poli/objective_repository/ranolazine_mpo/register.py +++ b/src/poli/objective_repository/ranolazine_mpo/register.py @@ -82,12 +82,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Ranolazine_MPO", @@ -141,12 +141,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/rasp/isolated_function.py b/src/poli/objective_repository/rasp/isolated_function.py index b86b75e9..d711d11c 100644 --- a/src/poli/objective_repository/rasp/isolated_function.py +++ b/src/poli/objective_repository/rasp/isolated_function.py @@ -21,11 +21,12 @@ from pathlib import Path from time import time -from typing import List, Union +from typing import Union, cast from uuid import uuid4 import numpy as np -import torch +import torch # type: ignore[reportMissingImports] +from numpy.typing import NDArray from poli.core.abstract_isolated_function import AbstractIsolatedFunction from poli.core.util.proteins.mutations import find_closest_wildtype_pdb_file_to_mutant @@ -80,21 +81,21 @@ class RaspIsolatedLogic(AbstractIsolatedFunction): Parameters ---------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] The path(s) to the wildtype PDB file(s), by default None. additive : bool, optional Whether we treat multiple mutations as additive, by default False. If you are interested in running this black box with multiple mutations, you should set this to True. Otherwise, it will raise an error if you pass a sequence with more than one mutation. - chains_to_keep : List[str], optional + chains_to_keep : list[str], optional The chains to keep in the PDB file(s), by default we keep the chain "A" for all pdbs passed. penalize_unfeasible_with: float, optional The value to return when the input is unfeasible, by default None, which means that we raise an error when an unfeasible sequence (e.g. one with a length different from the wildtypes) is passed. - alphabet : List[str], optional + alphabet : list[str], optional The alphabet for the problem, by default we use the amino acid list provided in poli.core.util.proteins.defaults. experiment_id : str, optional @@ -128,12 +129,12 @@ class RaspIsolatedLogic(AbstractIsolatedFunction): def __init__( self, - wildtype_pdb_path: Union[Path, List[Path]], + wildtype_pdb_path: Union[Path, list[Path]], additive: bool = False, - chains_to_keep: List[str] = None, + chains_to_keep: list[str] | None = None, penalize_unfeasible_with: float | None = None, - experiment_id: str = None, - tmp_folder: Path = None, + experiment_id: str | None = None, + tmp_folder: Path | None = None, device: str | torch.device | None = None, ): """ @@ -141,21 +142,21 @@ def __init__( Parameters: ----------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] The path(s) to the wildtype PDB file(s). additive : bool, optional Whether we treat multiple mutations as additive, by default False. If you are interested in running this black box with multiple mutations, you should set this to True. Otherwise, it will raise an error if you pass a sequence with more than one mutation. - chains_to_keep : List[str], optional + chains_to_keep : list[str], optional The chains to keep in the PDB file(s), by default we keep the chain "A" for all pdbs passed. penalize_unfeasible_with: float, optional The value to return when the input is unfeasible, by default None, which means that we raise an error when an unfeasible sequence (e.g. one with a length different from the wildtypes) is passed. - alphabet : List[str], optional + alphabet : list[str], optional The alphabet for the problem, by default we use the amino acid list provided in poli.core.util.proteins.defaults. experiment_id : str, optional @@ -189,7 +190,9 @@ def __init__( if isinstance(wildtype_pdb_path, list): if isinstance(wildtype_pdb_path[0], str): # Assuming that wildtype_pdb_path is a list of strings - wildtype_pdb_path = [Path(x.strip()) for x in wildtype_pdb_path] + wildtype_pdb_path = [ + Path(cast(str, x).strip()) for x in wildtype_pdb_path + ] elif isinstance(wildtype_pdb_path[0], Path): pass @@ -217,19 +220,22 @@ def __init__( # Validating the chains to keep if isinstance(chains_to_keep, type(None)): # Defaulting to always keeping chain A. - chains_to_keep = ["A"] * len(self.wildtype_pdb_paths) - - if isinstance(chains_to_keep, str): - chains_to_keep = [chains_to_keep] * len(self.wildtype_pdb_paths) - - if isinstance(chains_to_keep, list): + chains_to_keep_ = ["A"] * len(self.wildtype_pdb_paths) + elif isinstance(chains_to_keep, str): + chains_to_keep_ = [chains_to_keep] * len(self.wildtype_pdb_paths) + elif isinstance(chains_to_keep, list): assert len(chains_to_keep) == len(self.wildtype_pdb_paths), ( "The number of chains to keep must be the same as the number of wildtypes." " You can specify a single chain to keep for all wildtypes, or a list of chains." ) + chains_to_keep_ = chains_to_keep + else: + raise TypeError( + "chains_to_keep must be a string, a list of strings, or None." + ) # At this point, we are sure that chains_to_keep is a list of strings - self.chains_to_keep = chains_to_keep + self.chains_to_keep = chains_to_keep_ self.penalize_unfeasible_with = penalize_unfeasible_with @@ -252,7 +258,7 @@ def __init__( self._clean_wildtype_pdb_files() x0_pre_array = [] - for clean_wildtype_pdb_file in self.clean_wildtype_pdb_files: + for clean_wildtype_pdb_file in cast(list[Path], self.clean_wildtype_pdb_files): # Loads up the wildtype pdb files as strings wildtype_string = self.parse_pdb_as_residue_strings(clean_wildtype_pdb_file) x0_pre_array.append(list(wildtype_string)) @@ -307,12 +313,12 @@ def _clean_wildtype_pdb_files(self): for wildtype_pdb_path in self.wildtype_pdb_paths ] - def parse_pdb_as_residue_strings(self, pdb_file: Path) -> List[str]: + def parse_pdb_as_residue_strings(self, pdb_file: Path) -> list[str]: return parse_pdb_as_residue_strings(pdb_file) def _compute_mutant_residue_string_ddg( self, mutant_residue_string: str - ) -> np.ndarray: + ) -> NDArray[np.float64]: for i, char in enumerate(mutant_residue_string): if char not in self.rasp_interface.alphabet: raise ValueError( @@ -321,14 +327,12 @@ def _compute_mutant_residue_string_ddg( f"in the alphabet: {self.rasp_interface.alphabet}." ) try: - ( - closest_wildtype_pdb_file, - hamming_distance, - ) = find_closest_wildtype_pdb_file_to_mutant( - self.clean_wildtype_pdb_files, + res = find_closest_wildtype_pdb_file_to_mutant( + cast(list[Path], self.clean_wildtype_pdb_files), mutant_residue_string, return_hamming_distance=True, ) + closest_wildtype_pdb_file, hamming_distance = cast(tuple[Path, int], res) except ValueError as e: # This means that the mutant is unfeasible if self.penalize_unfeasible_with is not None: @@ -354,7 +358,7 @@ def _compute_mutant_residue_string_ddg( # Loading the models in preparation for inference cavity_model_net, ds_model_net = load_cavity_and_downstream_models( - device=self.device + device=self.device # type: ignore ) dataset_key = "predictions" @@ -388,13 +392,13 @@ def _compute_mutant_residue_string_ddg( " https://github.com/MachineLearningLifeScience/poli/issues" ) - result = np.sum(sliced_values_for_mutant, keepdims=True) + result = np.sum(sliced_values_for_mutant, keepdims=True) # type: ignore else: result = sliced_values_for_mutant else: result = sliced_values_for_mutant - return result + return cast(NDArray[np.float64], result) def __call__(self, x, context=None): """ @@ -427,7 +431,7 @@ def __call__(self, x, context=None): # and each of the wildtypes in self.wildtype_residue_strings. # closest_wildtypes will be a dictionary - # of the form {wildtype_path: List[str] of mutations} + # of the form {wildtype_path: list[str] of mutations} # closest_wildtypes = defaultdict(list) # mutant_residue_strings = [] # mutant_residue_to_hamming_distances = dict() diff --git a/src/poli/objective_repository/rasp/register.py b/src/poli/objective_repository/rasp/register.py index d5ea4f07..66a09a7b 100644 --- a/src/poli/objective_repository/rasp/register.py +++ b/src/poli/objective_repository/rasp/register.py @@ -20,7 +20,7 @@ from __future__ import annotations from pathlib import Path -from typing import List, Union +from typing import Union, cast from poli.core.abstract_black_box import AbstractBlackBox from poli.core.abstract_problem_factory import AbstractProblemFactory @@ -37,14 +37,14 @@ class RaspBlackBox(AbstractBlackBox): Parameters ---------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] The path(s) to the wildtype PDB file(s), by default None. additive : bool, optional Whether we treat multiple mutations as additive, by default False. If you are interested in running this black box with multiple mutations, you should set this to True. Otherwise, it will raise an error if you pass a sequence with more than one mutation. - chains_to_keep : List[str], optional + chains_to_keep : list[str], optional The chains to keep in the PDB file(s), by default we keep the chain "A" for all pdbs passed. experiment_id : str, optional @@ -87,17 +87,17 @@ class RaspBlackBox(AbstractBlackBox): def __init__( self, - wildtype_pdb_path: Union[Path, List[Path]], + wildtype_pdb_path: Union[Path, list[Path]], additive: bool = False, - chains_to_keep: List[str] = None, + chains_to_keep: list[str] | None = None, penalize_unfeasible_with: float | None = None, device: str | None = None, - experiment_id: str = None, - tmp_folder: Path = None, - batch_size: int = None, + experiment_id: str | None = None, + tmp_folder: Path | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): """ @@ -105,14 +105,14 @@ def __init__( Parameters: ----------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] The path(s) to the wildtype PDB file(s). additive : bool, optional Whether we treat multiple mutations as additive, by default False. If you are interested in running this black box with multiple mutations, you should set this to True. Otherwise, it will raise an error if you pass a sequence with more than one mutation. - chains_to_keep : List[str], optional + chains_to_keep : list[str], optional The chains to keep in the PDB file(s), by default we keep the chain "A" for all pdbs passed. penalize_unfeasible_with : float | None, optional @@ -164,7 +164,11 @@ def __init__( evaluation_budget=evaluation_budget, ) self.force_isolation = force_isolation - self.wildtype_pdb_path = wildtype_pdb_path + self.wildtype_pdb_path = ( + wildtype_pdb_path + if isinstance(wildtype_pdb_path, list) + else [wildtype_pdb_path] + ) self.chains_to_keep = chains_to_keep self.experiment_id = experiment_id self.tmp_folder = tmp_folder @@ -236,18 +240,18 @@ def get_black_box_info(self) -> BlackBoxInformation: class RaspProblemFactory(AbstractProblemFactory): def create( self, - wildtype_pdb_path: Union[Path, List[Path]], + wildtype_pdb_path: Union[Path, list[Path]], additive: bool = False, - chains_to_keep: List[str] = None, + chains_to_keep: list[str] | None = None, penalize_unfeasible_with: float | None = None, device: str | None = None, - experiment_id: str = None, - tmp_folder: Path = None, - seed: int = None, - batch_size: int = None, + experiment_id: str | None = None, + tmp_folder: Path | None = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ @@ -256,14 +260,14 @@ def create( Parameters ---------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] The path(s) to the wildtype PDB file(s). additive: bool, optional Whether we treat multiple mutations as additive, by default False. If you are interested in running this black box with multiple mutations, you should set this to True. Otherwise, it will raise an error if you pass a sequence with more than one mutation. - chains_to_keep : List[str], optional + chains_to_keep : list[str], optional The chains to keep in the PDB file(s), by default we keep the chain "A" for all pdbs passed. penalize_unfeasible_with : float | None, optional @@ -314,7 +318,9 @@ def create( wildtype_pdb_path = [wildtype_pdb_path] elif isinstance(wildtype_pdb_path, list): if isinstance(wildtype_pdb_path[0], str): - wildtype_pdb_path = [Path(x.strip()) for x in wildtype_pdb_path] + wildtype_pdb_path = [ + Path(cast(str, x).strip()) for x in wildtype_pdb_path + ] elif isinstance(wildtype_pdb_path[0], Path): pass else: diff --git a/src/poli/objective_repository/rdkit_logp/register.py b/src/poli/objective_repository/rdkit_logp/register.py index 88e97728..ca0a00bb 100644 --- a/src/poli/objective_repository/rdkit_logp/register.py +++ b/src/poli/objective_repository/rdkit_logp/register.py @@ -77,11 +77,11 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): """ @@ -110,7 +110,9 @@ def __init__( self.from_selfies = string_representation.upper() == "SELFIES" self.from_smiles = string_representation.upper() == "SMILES" self.alphabet = alphabet - self.max_sequence_length = max_sequence_length + self.max_sequence_length = ( + max_sequence_length if max_sequence_length != "inf" else float("inf") + ) self.string_representation = string_representation super().__init__( @@ -121,7 +123,7 @@ def __init__( ) # The only method you have to define - def _black_box(self, x: np.ndarray, context: dict = None) -> np.ndarray: + def _black_box(self, x: np.ndarray, context: dict | None = None) -> np.ndarray: """Computes the logP of a molecule x (array of strings). Assuming that x is an array of integers of length L, @@ -147,7 +149,7 @@ def _black_box(self, x: np.ndarray, context: dict = None) -> np.ndarray: for molecule in molecules: if molecule is not None: - logp_value = Descriptors.MolLogP(molecule) + logp_value = Descriptors.MolLogP(molecule) # type: ignore # If the qed value is not a float, return NaN if not isinstance(logp_value, float): @@ -182,12 +184,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """Creates a logP problem instance. @@ -216,7 +218,7 @@ def create( if seed is not None: seed_python_numpy_and_torch(seed) - if string_representation.upper() not in ["SMILES", "SELFIES"]: + if string_representation not in ["SMILES", "SELFIES"]: raise ValueError( "Missing required keyword argument: string_representation: str. " "String representation must be either 'SMILES' or 'SELFIES'." @@ -225,7 +227,7 @@ def create( self.string_representation = string_representation f = LogPBlackBox( - string_representation=string_representation.upper(), + string_representation=string_representation, alphabet=alphabet, max_sequence_length=max_sequence_length, batch_size=batch_size, @@ -240,6 +242,6 @@ def create( else: x0 = np.array([["[C]" * 10]]) - problem = TDCProblem(f, x0) + problem = TDCProblem(f, x0) # type: ignore return problem diff --git a/src/poli/objective_repository/rdkit_qed/register.py b/src/poli/objective_repository/rdkit_qed/register.py index 543852d8..24a082c0 100644 --- a/src/poli/objective_repository/rdkit_qed/register.py +++ b/src/poli/objective_repository/rdkit_qed/register.py @@ -78,11 +78,11 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): """ Initialize the QEDBlackBox. @@ -118,7 +118,9 @@ def __init__( self.string_representation = string_representation self.alphabet = alphabet - self.max_sequence_length = max_sequence_length + self.max_sequence_length = ( + max_sequence_length if max_sequence_length != "inf" else float("inf") + ) super().__init__( batch_size=batch_size, @@ -128,7 +130,7 @@ def __init__( ) # The only method you have to define - def _black_box(self, x: np.ndarray, context: dict = None) -> np.ndarray: + def _black_box(self, x: np.ndarray, context: dict | None = None) -> np.ndarray: """Computes the qed of the molecule in x. Parameters @@ -218,12 +220,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """Creates a QED black box function and initial observations. @@ -261,7 +263,7 @@ def create( self.string_representation = string_representation f = QEDBlackBox( - string_representation=string_representation.upper(), + string_representation=string_representation, alphabet=alphabet, max_sequence_length=max_sequence_length, batch_size=batch_size, @@ -276,4 +278,4 @@ def create( else: x0 = np.array([["[C]" * 10]]) - return TDCProblem(f, x0) + return TDCProblem(f, x0) # type: ignore diff --git a/src/poli/objective_repository/rfp_foldx_stability/register.py b/src/poli/objective_repository/rfp_foldx_stability/register.py index 29bb829c..51017d72 100644 --- a/src/poli/objective_repository/rfp_foldx_stability/register.py +++ b/src/poli/objective_repository/rfp_foldx_stability/register.py @@ -12,15 +12,15 @@ class RFPFoldXStabilityBlackBox(FoldXStabilityBlackBox): def __init__( self, - experiment_id=None, - tmp_folder=None, - eager_repair=False, - verbose=False, - batch_size=1, - parallelize=False, - num_workers=None, - evaluation_budget=None, - force_isolation=False, + experiment_id: str | None = None, + tmp_folder: Path | None = None, + eager_repair: bool = False, + verbose: bool = False, + batch_size: int = 1, + parallelize: bool = False, + num_workers: int | None = None, + evaluation_budget: int | None = None, + force_isolation: bool = False, ): RFP_FOLDX_ASSETS_DIR = Path(__file__).parent / "assets" diff --git a/src/poli/objective_repository/rfp_foldx_stability_and_sasa/register.py b/src/poli/objective_repository/rfp_foldx_stability_and_sasa/register.py index d4351262..4920daba 100644 --- a/src/poli/objective_repository/rfp_foldx_stability_and_sasa/register.py +++ b/src/poli/objective_repository/rfp_foldx_stability_and_sasa/register.py @@ -18,7 +18,7 @@ import warnings from pathlib import Path -from typing import List, Union +from typing import List, Union, cast import numpy as np @@ -42,18 +42,18 @@ class RFPFoldXStabilityAndSASAProblemFactory(AbstractProblemFactory): def create( self, - wildtype_pdb_path: Union[Path, List[Path]], - n_starting_points: int = None, + wildtype_pdb_path: Union[Path, list[Path]], + n_starting_points: int | None = None, strict: bool = False, - experiment_id: str = None, - tmp_folder: Path = None, + experiment_id: str | None = None, + tmp_folder: Path | None = None, eager_repair: bool = False, verbose: bool = False, - seed: int = None, + seed: int | None = None, batch_size: int = 1, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ @@ -61,7 +61,7 @@ def create( Parameters ---------- - wildtype_pdb_path : Union[Path, List[Path]] + wildtype_pdb_path : Union[Path, list[Path]] Path or list of paths to the wildtype PDB files. n_starting_points: int, optional Size of D_0. Default is all available data. @@ -113,7 +113,9 @@ def create( wildtype_pdb_path = [wildtype_pdb_path] elif isinstance(wildtype_pdb_path, list): if isinstance(wildtype_pdb_path[0], str): - wildtype_pdb_path = [Path(x.strip()) for x in wildtype_pdb_path] + wildtype_pdb_path = [ + Path(cast(str, x).strip()) for x in wildtype_pdb_path + ] elif isinstance(wildtype_pdb_path[0], Path): pass else: @@ -162,7 +164,7 @@ def create( remaining_wildtype_pdb_files = list( set(wildtype_pdb_path) - set(pareto_pdb_files) ) - np.random.shuffle(remaining_wildtype_pdb_files) + np.random.shuffle(remaining_wildtype_pdb_files) # type: ignore remaining_wildtype_pdb_files = remaining_wildtype_pdb_files[ :remaining_n_starting_points ] # subselect w.r.t. requested number of sequences diff --git a/src/poli/objective_repository/rfp_rasp/register.py b/src/poli/objective_repository/rfp_rasp/register.py index 0e33397d..fd52f07c 100644 --- a/src/poli/objective_repository/rfp_rasp/register.py +++ b/src/poli/objective_repository/rfp_rasp/register.py @@ -11,16 +11,16 @@ class RFPRaspBlackBox(RaspBlackBox): def __init__( self, - additive=True, - penalize_unfeasible_with=None, - device=None, - experiment_id=None, - tmp_folder=None, - batch_size=None, - parallelize=False, - num_workers=None, - evaluation_budget=None, - force_isolation=False, + additive: bool = True, + penalize_unfeasible_with: float | None = None, + device: str | None = None, + experiment_id: str | None = None, + tmp_folder: Path | None = None, + batch_size: int | None = None, + parallelize: bool = False, + num_workers: int | None = None, + evaluation_budget: int | None = None, + force_isolation: bool = False, ): RFP_PDB_PATH = Path(__file__).parent / "assets" wildtype_pdb_path = [ @@ -54,13 +54,13 @@ def create( additive: bool = True, penalize_unfeasible_with: float | None = None, device: str | None = None, - experiment_id: str = None, - tmp_folder: Path = None, - seed: int = None, - batch_size: int = None, + experiment_id: str | None = None, + tmp_folder: Path | None = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): if seed is not None: diff --git a/src/poli/objective_repository/rmf_landscape/isolated_function.py b/src/poli/objective_repository/rmf_landscape/isolated_function.py index 6d602f22..84d35216 100644 --- a/src/poli/objective_repository/rmf_landscape/isolated_function.py +++ b/src/poli/objective_repository/rmf_landscape/isolated_function.py @@ -1,10 +1,9 @@ from __future__ import annotations import logging -from typing import List import numpy as np -from scipy.stats import genpareto +from scipy.stats import genpareto # type: ignore[reportMissingImports] from poli.core.abstract_isolated_function import AbstractIsolatedFunction from poli.core.util.proteins.defaults import AMINO_ACIDS, ENCODING @@ -16,11 +15,11 @@ class RMFIsolatedLogic(AbstractIsolatedFunction): Parameters ---------- - wildtype : List[str] + wildtype : list[str] String sequence of the reference, default: None. c : float, optional - alphabet : List[str] + alphabet : list[str] Alphabet for the problem, by default AA list provided from poli.core.util.proteins.defaults stochasticity: str, optional Methods @@ -36,11 +35,11 @@ class RMFIsolatedLogic(AbstractIsolatedFunction): def __init__( self, - wildtype: List[str], - wt_val: float | None = 0.0, + wildtype: list[str], + wt_val: float = 0.0, c: float | None = None, - kappa: float | None = 0.1, - alphabet: List[str] | None = None, + kappa: float = 0.1, + alphabet: list[str] | None = None, seed: int | None = 0, ) -> None: """ @@ -51,16 +50,18 @@ def __init__( "Did you forget to pass it to the create of the black box?" ) if not isinstance(wildtype, np.ndarray): - wildtype = np.array(list(wildtype)) - self.wildtype = wildtype + wildtype_ = np.array(list(wildtype)) + else: + wildtype_ = wildtype + self.wildtype = wildtype_ self.seed = seed if alphabet is None: logging.info("using default alphabet AAs.") alphabet = AMINO_ACIDS assert all( - [aa in ENCODING.keys() for aa in wildtype] + [aa in ENCODING.keys() for aa in wildtype_] ), "Input wildtype elements not in encoding alphabet." - self.wt_int = np.array([ENCODING.get(aa) for aa in wildtype]) + self.wt_int = np.array([ENCODING.get(aa) for aa in wildtype_]) if c is None: c = 1 / (len(alphabet) - 1) else: diff --git a/src/poli/objective_repository/rmf_landscape/register.py b/src/poli/objective_repository/rmf_landscape/register.py index a4253de3..eeb6c9f8 100644 --- a/src/poli/objective_repository/rmf_landscape/register.py +++ b/src/poli/objective_repository/rmf_landscape/register.py @@ -11,8 +11,6 @@ from __future__ import annotations -from typing import List - import numpy as np from poli.core.abstract_black_box import AbstractBlackBox @@ -30,7 +28,7 @@ class RMFBlackBox(AbstractBlackBox): Parameters ---------- - wildtype : str + wildtype : list[str] The wildtype amino-acid sequence (aka reference sequence) against which all RMF values are computed against. wt_val : float , optional The reference value for the WT, zero if observations are standardized, else float value e.g. ddGs @@ -41,7 +39,7 @@ class RMFBlackBox(AbstractBlackBox): Determines what type of distribution will be sampled from exponential family, Weibull, etc. seed : int, optional Random seed for replicability of results, by default None. - alphabet : List[str], optional + alphabet : list[str], optional Type of alphabet of the sequences, by default Amino Acids. Nucleic Acids possible. batch_size : int, optional @@ -58,14 +56,14 @@ class RMFBlackBox(AbstractBlackBox): def __init__( self, - wildtype: str, + wildtype: list[str], wt_val: float = 0.0, c: float | None = None, kappa: float = 0.1, seed: int | None = None, - alphabet: List[str] | None = None, + alphabet: list[str] | None = None, batch_size: int | None = None, - parallelize: bool | None = False, + parallelize: bool = False, num_workers: int | None = None, evaluation_budget: int | None = None, force_isolation: bool = False, @@ -112,7 +110,7 @@ def __init__( force_isolation=self.force_isolation, ) - def _black_box(self, x: np.ndarray, context: None) -> np.ndarray: + def _black_box(self, x: np.ndarray, context: dict | None = None) -> np.ndarray: """ Runs the given input x provided in the context with the RMF function and returns the @@ -169,16 +167,16 @@ class RMFProblemFactory(AbstractProblemFactory): def create( self, - wildtype: List[str] | str, - wt_val: float | None = 0.0, + wildtype: list[str] | str, + wt_val: float = 0.0, c: float | None = None, kappa: float = 0.1, - alphabet: List[str] | None = None, - seed: int = None, - batch_size: int = None, + alphabet: list[str] | None = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ @@ -186,7 +184,7 @@ def create( Parameters ---------- - wildtype : List[str] | str + wildtype : list[str] | str Reference (wild-type) sequence is pseudo-optimum on start. wt_val : float, optional Reference function value (standardized observations) of WT. @@ -195,7 +193,7 @@ def create( If None passed default value is regularizing 1/(len(alphabet)-1) . kappa: float Determines generalized Pareto continuous RV. - alphabet: List[str], optional + alphabet: list[str], optional Problem alphabet used, if None is passed default: AMINO_ACIDS. seed : int, optional Seed for random number generators. If None is passed, diff --git a/src/poli/objective_repository/rosetta_energy/isolated_function.py b/src/poli/objective_repository/rosetta_energy/isolated_function.py index b3169602..65341352 100644 --- a/src/poli/objective_repository/rosetta_energy/isolated_function.py +++ b/src/poli/objective_repository/rosetta_energy/isolated_function.py @@ -1,3 +1,4 @@ +# type: ignore from __future__ import annotations import logging diff --git a/src/poli/objective_repository/rosetta_energy/register.py b/src/poli/objective_repository/rosetta_energy/register.py index 0c567959..fb723964 100644 --- a/src/poli/objective_repository/rosetta_energy/register.py +++ b/src/poli/objective_repository/rosetta_energy/register.py @@ -18,7 +18,7 @@ from __future__ import annotations from pathlib import Path -from typing import Callable, List +from typing import Callable import numpy as np @@ -124,7 +124,7 @@ class RosettaEnergyBlackBox(AbstractBlackBox): def __init__( self, - wildtype_pdb_path: Path | List[Path], + wildtype_pdb_path: Path | list[Path], score_function: str = "default", seed: int = 0, unit: str = "DDG", @@ -135,10 +135,10 @@ def __init__( cycle: int = 3, constraint_weight: float = 5.0, n_threads: int = 4, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): super().__init__( @@ -179,9 +179,9 @@ def __init__( n_threads=self.n_threads, ) self.inner_function = opt_in_wrapper(inner_function) - self.x0 = self.inner_function.x0 + self.x0 = self.inner_function.x0 # type: ignore - def _black_box(self, x: np.ndarray, context: dict = None) -> np.ndarray: + def _black_box(self, x: np.ndarray, context: dict | None = None) -> np.ndarray: """ Computes the stability of the mutant(s) in x. @@ -225,7 +225,7 @@ def get_setup_information(self) -> BlackBoxInformation: def create( self, - wildtype_pdb_path: Path | List[Path], + wildtype_pdb_path: Path | list[Path], score_function: str = "default", seed: int = 0, unit: str = "DDG", @@ -236,10 +236,10 @@ def create( cycle: int = 3, constraint_weight: int | float = 5, n_threads: int = 4, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ @@ -324,6 +324,6 @@ def create( ) # Your first input (an np.array[str] of shape [b, L] or [b,]) - x0 = f.inner_function.x0 + x0 = f.inner_function.x0 # type: ignore return Problem(f, x0) diff --git a/src/poli/objective_repository/sa_tdc/register.py b/src/poli/objective_repository/sa_tdc/register.py index 3769015f..49dd78f2 100644 --- a/src/poli/objective_repository/sa_tdc/register.py +++ b/src/poli/objective_repository/sa_tdc/register.py @@ -51,11 +51,11 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): """ @@ -120,12 +120,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/scaffold_hop/register.py b/src/poli/objective_repository/scaffold_hop/register.py index f136a03e..597a2e9b 100644 --- a/src/poli/objective_repository/scaffold_hop/register.py +++ b/src/poli/objective_repository/scaffold_hop/register.py @@ -82,12 +82,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Scaffold Hop", @@ -141,12 +141,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/sitagliptin_mpo/register.py b/src/poli/objective_repository/sitagliptin_mpo/register.py index cf230d16..7ad29ce0 100644 --- a/src/poli/objective_repository/sitagliptin_mpo/register.py +++ b/src/poli/objective_repository/sitagliptin_mpo/register.py @@ -81,12 +81,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Sitagliptin_MPO", @@ -140,12 +140,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/super_mario_bros/isolated_function.py b/src/poli/objective_repository/super_mario_bros/isolated_function.py index 3a8c4b67..e070292f 100644 --- a/src/poli/objective_repository/super_mario_bros/isolated_function.py +++ b/src/poli/objective_repository/super_mario_bros/isolated_function.py @@ -7,8 +7,10 @@ """ +from __future__ import annotations + from pathlib import Path -from typing import List +from typing import cast import numpy as np @@ -58,7 +60,7 @@ class SMBIsolatedLogic(AbstractIsolatedFunction): def __init__( self, - alphabet: List[str] = smb_info.alphabet, + alphabet: list[str] = cast(list[str], smb_info.alphabet), max_time: int = 30, visualize: bool = False, value_on_unplayable: float = np.nan, diff --git a/src/poli/objective_repository/super_mario_bros/level_utils.py b/src/poli/objective_repository/super_mario_bros/level_utils.py index 373d66fb..37add73d 100644 --- a/src/poli/objective_repository/super_mario_bros/level_utils.py +++ b/src/poli/objective_repository/super_mario_bros/level_utils.py @@ -1,12 +1,14 @@ """Utilities for transforming levels to arrays and back.""" +from __future__ import annotations + from itertools import product -from typing import List import numpy as np +from numpy.typing import NDArray -def level_to_list(level_txt: str) -> List[List[str]]: +def level_to_list(level_txt: str) -> list[list[str]]: """ Takes a level as a string and returns a list of lists of individual tokens. @@ -24,7 +26,9 @@ def level_to_array(level_txt: str) -> np.ndarray: return np.array(level_to_list(level_txt)) -def levels_to_onehot(levels: np.ndarray, n_sprites: int = 11) -> np.ndarray: +def levels_to_onehot( + levels: NDArray[np.int_], n_sprites: int = 11 +) -> NDArray[np.float64]: """Transforms an array [b, w, h] of integers into a one-hot array [b, n_sprites, w, h].""" batch_size, w, h = levels.shape y_onehot = np.zeros((batch_size, n_sprites, h, w)) @@ -45,7 +49,9 @@ def vectorized(prob_matrix, items): return items[k] -def onehot_to_levels(levels_onehot: np.ndarray, sampling=False, seed=0) -> np.ndarray: +def onehot_to_levels( + levels_onehot: NDArray[np.float64], sampling=False, seed=0 +) -> NDArray[np.int_]: """ Transforms a level from probits to integers. """ @@ -77,7 +83,9 @@ def onehot_to_levels(levels_onehot: np.ndarray, sampling=False, seed=0) -> np.nd return levels -def add_padding_to_level(level: np.ndarray, n_padding: int = 1) -> np.ndarray: +def add_padding_to_level( + level: NDArray[np.int_], n_padding: int = 1 +) -> NDArray[np.int_]: """ Adds padding to the left of the level, giving room for the agent to land. @@ -85,12 +93,12 @@ def add_padding_to_level(level: np.ndarray, n_padding: int = 1) -> np.ndarray: h, w = level.shape padding = 2 * np.ones((h, n_padding)) # Starting with emptyness. padding[-1, :] = 0 # Adding the ground. - level_with_padding = np.concatenate((padding, level), axis=1) + level_with_padding = np.concatenate((padding, level), axis=1).astype(np.int_) return level_with_padding -def clean_level(level: np.ndarray) -> List[List[int]]: +def clean_level(level: NDArray[np.int_]) -> list[list[int]]: """ Cleans a level by removing Mario (token id: 11), and replacing it with empty space. diff --git a/src/poli/objective_repository/super_mario_bros/register.py b/src/poli/objective_repository/super_mario_bros/register.py index d6eb1d6d..fa06f3df 100644 --- a/src/poli/objective_repository/super_mario_bros/register.py +++ b/src/poli/objective_repository/super_mario_bros/register.py @@ -7,6 +7,8 @@ """ +from __future__ import annotations + from pathlib import Path import numpy as np @@ -61,10 +63,10 @@ def __init__( max_time: int = 30, visualize: bool = False, value_on_unplayable: float = np.nan, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ): """ @@ -146,11 +148,11 @@ def create( max_time: int = 30, visualize: bool = False, value_on_unplayable: float = np.nan, - seed: int = None, - batch_size: int = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """Creates a new instance of the Super Mario Bros problem. diff --git a/src/poli/objective_repository/super_mario_bros/simulator.py b/src/poli/objective_repository/super_mario_bros/simulator.py index 8f0b101a..883fb429 100644 --- a/src/poli/objective_repository/super_mario_bros/simulator.py +++ b/src/poli/objective_repository/super_mario_bros/simulator.py @@ -9,23 +9,25 @@ from pathlib import Path import numpy as np -from level_utils import clean_level +from numpy.typing import NDArray + +from poli.objective_repository.super_mario_bros.level_utils import clean_level filepath = Path(__file__).parent.resolve() JARFILE_PATH = f"{filepath}/simulator.jar" def test_level_from_int_array( - level: np.ndarray, + level: NDArray[np.int_], human_player: bool = False, max_time: int = 45, visualize: bool = False, ) -> dict: - level = clean_level(level) - level = str(level) + level_ = clean_level(level) + level_ = str(level_) return run_level( - level, human_player=human_player, max_time=max_time, visualize=visualize + level_, human_player=human_player, max_time=max_time, visualize=visualize ) @@ -35,10 +37,10 @@ def test_level_from_str_array( max_time: int = 45, visualize: bool = False, ) -> dict: - level = str(level) + level_ = str(level) return run_level( - level, human_player=human_player, max_time=max_time, visualize=visualize + level_, human_player=human_player, max_time=max_time, visualize=visualize ) @@ -68,7 +70,7 @@ def run_level( stdout=subprocess.PIPE, ) - lines = java.stdout.readlines() + lines = java.stdout.readlines() # type: ignore res = lines[-1] res = json.loads(res.decode("utf8")) res["level"] = level diff --git a/src/poli/objective_repository/thiothixene_rediscovery/register.py b/src/poli/objective_repository/thiothixene_rediscovery/register.py index 8cc3236e..c7cbb2f9 100644 --- a/src/poli/objective_repository/thiothixene_rediscovery/register.py +++ b/src/poli/objective_repository/thiothixene_rediscovery/register.py @@ -84,12 +84,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Thiothixene_Rediscovery", @@ -142,12 +142,12 @@ def create( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + max_sequence_length: int | Literal["inf"] = "inf", + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/toy_continuous_problem/definitions.py b/src/poli/objective_repository/toy_continuous_problem/definitions.py index 4e530d8f..daa9ba7a 100644 --- a/src/poli/objective_repository/toy_continuous_problem/definitions.py +++ b/src/poli/objective_repository/toy_continuous_problem/definitions.py @@ -14,10 +14,15 @@ Test Functions and Datasets. [https://www.sfu.ca/~ssurjano/optimization.html] """ +from __future__ import annotations + +from typing import cast + import numpy as np +from numpy.typing import NDArray -def ackley_function_01(x: np.ndarray) -> np.ndarray: +def ackley_function_01(x: NDArray[np.float64]) -> NDArray[np.float64]: if len(x.shape) == 1: # Add a batch dimension if it's missing x = x.reshape(-1, x.shape[0]) @@ -38,7 +43,7 @@ def ackley_function_01(x: np.ndarray) -> np.ndarray: return res -def alpine_01(x: np.ndarray) -> np.ndarray: +def alpine_01(x: NDArray[np.float64]) -> NDArray[np.float64]: if len(x.shape) == 1: # Add a batch dimension if it's missing x = x.reshape(-1, x.shape[0]) @@ -72,7 +77,7 @@ def alpine_02(x: np.ndarray) -> np.ndarray: return res -def bent_cigar(x: np.ndarray) -> np.ndarray: +def bent_cigar(x: NDArray[np.float64]) -> NDArray[np.float64]: if len(x.shape) == 1: # Add a batch dimension if it's missing x = x.reshape(-1, x.shape[0]) @@ -91,7 +96,7 @@ def bent_cigar(x: np.ndarray) -> np.ndarray: return res -def brown(x: np.ndarray) -> np.ndarray: +def brown(x: NDArray[np.float64]) -> NDArray[np.float64]: if len(x.shape) == 1: # Add a batch dimension if it's missing x = x.reshape(-1, x.shape[0]) @@ -111,7 +116,7 @@ def brown(x: np.ndarray) -> np.ndarray: return res -def chung_reynolds(x: np.ndarray) -> np.ndarray: +def chung_reynolds(x: NDArray[np.float64]) -> NDArray[np.float64]: if len(x.shape) == 1: # Add a batch dimension if it's missing x = x.reshape(-1, x.shape[0]) @@ -128,7 +133,7 @@ def chung_reynolds(x: np.ndarray) -> np.ndarray: return res -def cosine_mixture(x: np.ndarray) -> np.ndarray: +def cosine_mixture(x: NDArray[np.float64]) -> NDArray[np.float64]: if len(x.shape) == 1: # Add a batch dimension if it's missing x = x.reshape(-1, x.shape[0]) @@ -148,7 +153,7 @@ def cosine_mixture(x: np.ndarray) -> np.ndarray: return res -def deb_01(x: np.ndarray) -> np.ndarray: +def deb_01(x: NDArray[np.float64]) -> NDArray[np.float64]: if len(x.shape) == 1: # Add a batch dimension if it's missing x = x.reshape(-1, x.shape[0]) @@ -166,7 +171,7 @@ def deb_01(x: np.ndarray) -> np.ndarray: return res -def deb_02(x: np.ndarray) -> np.ndarray: +def deb_02(x: NDArray[np.float64]) -> NDArray[np.float64]: if len(x.shape) == 1: # Add a batch dimension if it's missing x = x.reshape(-1, x.shape[0]) @@ -185,8 +190,8 @@ def deb_02(x: np.ndarray) -> np.ndarray: def deflected_corrugated_spring( - x: np.ndarray, alpha: float = 5.0, k: float = 5.0 -) -> np.ndarray: + x: NDArray[np.float64], alpha: float = 5.0, k: float = 5.0 +) -> NDArray[np.float64]: if len(x.shape) == 1: # Add a batch dimension if it's missing x = x.reshape(-1, x.shape[0]) @@ -204,7 +209,9 @@ def deflected_corrugated_spring( return res -def styblinski_tang(x: np.ndarray, normalize: bool = True) -> np.ndarray: +def styblinski_tang( + x: NDArray[np.float64], normalize: bool = True +) -> NDArray[np.float64]: """ This function is maximized at (-2.903534, ..., -2.903534), with a value of -39.16599 * d. @@ -222,7 +229,7 @@ def styblinski_tang(x: np.ndarray, normalize: bool = True) -> np.ndarray: return -0.5 * np.sum(y, axis=1) -def easom(xy: np.ndarray) -> np.ndarray: +def easom(xy: NDArray[np.float64]) -> NDArray[np.float64]: """ Easom is very flat, with a maxima at (pi, pi). @@ -232,10 +239,11 @@ def easom(xy: np.ndarray) -> np.ndarray: assert xy.shape[1] == 2, "Easom only works in 2D. " x = xy[..., 0] y = xy[..., 1] - return np.cos(x) * np.cos(y) * np.exp(-((x - np.pi) ** 2 + (y - np.pi) ** 2)) + res = np.cos(x) * np.cos(y) * np.exp(-((x - np.pi) ** 2 + (y - np.pi) ** 2)) + return cast(NDArray[np.float64], res) -def cross_in_tray(xy: np.ndarray) -> np.ndarray: +def cross_in_tray(xy: NDArray[np.float64]) -> NDArray[np.float64]: """ Cross-in-tray has several local maxima in a quilt-like pattern. @@ -251,7 +259,7 @@ def cross_in_tray(xy: np.ndarray) -> np.ndarray: ) -def egg_holder(xy: np.ndarray) -> np.ndarray: +def egg_holder(xy: NDArray[np.float64]) -> NDArray[np.float64]: """ The egg holder is especially difficult. @@ -266,7 +274,7 @@ def egg_holder(xy: np.ndarray) -> np.ndarray: ) -def shifted_sphere(x: np.ndarray) -> np.ndarray: +def shifted_sphere(x: NDArray[np.float64]) -> NDArray[np.float64]: """ The usual squared norm, but shifted away from the origin by a bit. Maximized at (1, 1, ..., 1) @@ -287,7 +295,7 @@ def shifted_sphere(x: np.ndarray) -> np.ndarray: return res -def camelback_2d(x: np.ndarray) -> np.ndarray: +def camelback_2d(x: NDArray[np.float64]) -> NDArray[np.float64]: """ Taken directly from the LineBO repository [1]. @@ -305,7 +313,7 @@ def camelback_2d(x: np.ndarray) -> np.ndarray: return np.maximum(-y, -2.5) -def hartmann_6d(x: np.ndarray) -> np.ndarray: +def hartmann_6d(x: NDArray[np.float64]) -> NDArray[np.float64]: """ The 6 dimensional Hartmann function. @@ -351,7 +359,7 @@ def hartmann_6d(x: np.ndarray) -> np.ndarray: return np.array(res).reshape(-1, 1) -def branin_2d(x: np.ndarray) -> np.ndarray: +def branin_2d(x: NDArray[np.float64]) -> NDArray[np.float64]: """ The 2D Branin function. @@ -376,7 +384,9 @@ def branin_2d(x: np.ndarray) -> np.ndarray: return -y -def rosenbrock(x: np.ndarray, a: float = 1.0, b: float = 100.0): +def rosenbrock( + x: NDArray[np.float64], a: float = 1.0, b: float = 100.0 +) -> NDArray[np.float64]: """ Compute the Rosenbrock function. @@ -399,7 +409,7 @@ def rosenbrock(x: np.ndarray, a: float = 1.0, b: float = 100.0): ) -def levy(x: np.ndarray): +def levy(x: NDArray[np.float64]) -> NDArray[np.float64]: """ Compute the Levy function. @@ -430,7 +440,7 @@ def levy(x: np.ndarray): return -(term1 + term2 + term3) -def himmelblau(x: np.ndarray): +def himmelblau(x: NDArray[np.float64]) -> NDArray[np.float64]: """ Compute the Himmelblau function. @@ -455,8 +465,3 @@ def himmelblau(x: np.ndarray): x2 = x[:, 1] return -((x1**2 + x2 - 11) ** 2 + (x1 + x2**2 - 7) ** 2) - - -if __name__ == "__main__": - b = branin_2d - maximal_b = b(np.array([[-np.pi, 12.275], [np.pi, 2.275], [9.42478, 2.475]])) diff --git a/src/poli/objective_repository/toy_continuous_problem/register.py b/src/poli/objective_repository/toy_continuous_problem/register.py index b89838e5..1d92d70c 100644 --- a/src/poli/objective_repository/toy_continuous_problem/register.py +++ b/src/poli/objective_repository/toy_continuous_problem/register.py @@ -12,7 +12,7 @@ (see the environment.yml file in this folder). """ -from typing import List +from __future__ import annotations import numpy as np @@ -22,7 +22,11 @@ from poli.core.problem import Problem from poli.core.util.seeding import seed_python_numpy_and_torch -from .toy_continuous_problem import POSSIBLE_FUNCTIONS, ToyContinuousProblem +from .toy_continuous_problem import ( + POSSIBLE_FUNCTIONS, + POSSIBLE_FUNCTIONS_TYPE, + ToyContinuousProblem, +) class ToyContinuousBlackBox(AbstractBlackBox): @@ -38,7 +42,7 @@ class ToyContinuousBlackBox(AbstractBlackBox): embed_in : int, optional If not None, the continuous problem is randomly embedded in this dimension. By default, None. - dimensions_to_embed_in: List[int], optional + dimensions_to_embed_in: list[int], optional The dimensions in which to embed the problem, by default None. Only has an effect if embed_in is not None. batch_size : int, optional The batch size for parallel evaluation, by default None. @@ -71,14 +75,14 @@ class ToyContinuousBlackBox(AbstractBlackBox): def __init__( self, - function_name: str, + function_name: POSSIBLE_FUNCTIONS_TYPE, n_dimensions: int = 2, - embed_in: int = None, - dimensions_to_embed_in: List[int] = None, - batch_size: int = None, + embed_in: int | None = None, + dimensions_to_embed_in: list[int] | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): assert ( @@ -104,7 +108,7 @@ def __init__( evaluation_budget=evaluation_budget, ) - def _black_box(self, x: np.ndarray, context: dict = None) -> np.ndarray: + def _black_box(self, x: np.ndarray, context: dict | None = None) -> np.ndarray: """ Evaluates the toy continuous problem on a continuous input x. @@ -150,15 +154,15 @@ def get_black_box_info(self) -> BlackBoxInformation: class ToyContinuousProblemFactory(AbstractProblemFactory): def create( self, - function_name: str, + function_name: POSSIBLE_FUNCTIONS_TYPE, n_dimensions: int = 2, - embed_in: int = None, - dimensions_to_embed_in: List[int] = None, - seed: int = None, - batch_size: int = None, + embed_in: int | None = None, + dimensions_to_embed_in: list[int] | None = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ @@ -173,7 +177,7 @@ def create( embed_in : int, optional If not None, the continuous problem is randomly embedded in this dimension. By default, None. - dimensions_to_embed_in: List[int], optional + dimensions_to_embed_in: list[int], optional The dimensions in which to embed the problem, by default None. Only has an effect if embed_in is not None. seed : int, optional The seed for the random number generator, by default None. diff --git a/src/poli/objective_repository/toy_continuous_problem/toy_continuous_problem.py b/src/poli/objective_repository/toy_continuous_problem/toy_continuous_problem.py index 83173ba6..7dba1a13 100644 --- a/src/poli/objective_repository/toy_continuous_problem/toy_continuous_problem.py +++ b/src/poli/objective_repository/toy_continuous_problem/toy_continuous_problem.py @@ -5,9 +5,12 @@ https://en.wikipedia.org/wiki/Test_functions_for_optimization """ -from typing import List, Literal +from __future__ import annotations + +from typing import Literal import numpy as np +from numpy.typing import NDArray from .definitions import ( ackley_function_01, @@ -68,6 +71,30 @@ ] SIX_DIMENSIONAL_PROBLEMS = ["hartmann_6d"] +POSSIBLE_FUNCTIONS_TYPE = Literal[ + "ackley_function_01", + "alpine_01", + "alpine_02", + "bent_cigar", + "brown", + "chung_reynolds", + "cosine_mixture", + "deb_01", + "deb_02", + "deflected_corrugated_spring", + "styblinski_tang", + "shifted_sphere", + "easom", + "cross_in_tray", + "egg_holder", + "camelback_2d", + "hartmann_6d", + "branin_2d", + "rosenbrock", + "levy", + "himmelblau", +] + class ToyContinuousProblem: """ @@ -80,31 +107,10 @@ class ToyContinuousProblem: def __init__( self, - name: Literal[ - "ackley_function_01", - "alpine_01", - "alpine_02", - "bent_cigar", - "brown", - "chung_reynolds", - "cosine_mixture", - "deb_01", - "deb_02", - "deflected_corrugated_spring", - "styblinski_tang", - "shifted_sphere", - "easom", - "cross_in_tray", - "egg_holder", - "camelback_2d", - "branin_2d", - "hartmann_6d", - "rosenbrock", - "levy", - ], + name: POSSIBLE_FUNCTIONS_TYPE, n_dims: int = 2, - embed_in: int = None, - dimensions_to_embed_in: List[int] = None, + embed_in: int | None = None, + dimensions_to_embed_in: list[int] | None = None, ) -> None: self.maximize = True self.known_optima = True @@ -300,8 +306,10 @@ def __init__( f" but received {n_dims}." ) - def evaluate_objective(self, x: np.array, **kwargs) -> np.array: + def evaluate_objective( + self, x: NDArray[np.float64], **kwargs + ) -> NDArray[np.float64]: return self.function(x) - def __call__(self, x: np.array) -> np.array: + def __call__(self, x: NDArray[np.float64]) -> NDArray[np.float64]: return self.function(x).reshape(-1, 1) diff --git a/src/poli/objective_repository/troglitazone_rediscovery/register.py b/src/poli/objective_repository/troglitazone_rediscovery/register.py index e1021ca0..36e6bfa3 100644 --- a/src/poli/objective_repository/troglitazone_rediscovery/register.py +++ b/src/poli/objective_repository/troglitazone_rediscovery/register.py @@ -84,12 +84,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Troglitazone_Rediscovery", @@ -140,15 +140,15 @@ class TroglitazoneRediscoveryProblemFactory(AbstractProblemFactory): def create( self, - string_representation: Literal["SMILES", "SELFIES"] = "SMILES", - alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, + string_representation: Literal["SMILES", "SELFIES"] = "SMILES", + alphabet: list[str] | None = None, + max_sequence_length: int | Literal["inf"] = "inf", ) -> Problem: """ Creates a Troglitazone rediscovery problem. diff --git a/src/poli/objective_repository/valsartan_smarts/register.py b/src/poli/objective_repository/valsartan_smarts/register.py index a6334940..20987617 100644 --- a/src/poli/objective_repository/valsartan_smarts/register.py +++ b/src/poli/objective_repository/valsartan_smarts/register.py @@ -83,12 +83,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Valsartan_SMARTS", @@ -139,15 +139,15 @@ class ValsartanSMARTSProblemFactory(AbstractProblemFactory): def create( self, - string_representation: Literal["SMILES", "SELFIES"] = "SMILES", - alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, + string_representation: Literal["SMILES", "SELFIES"] = "SMILES", + alphabet: list[str] | None = None, + max_sequence_length: int | Literal["inf"] = "inf", ) -> Problem: """ Creates a Valsartan SMARTS problem. diff --git a/src/poli/objective_repository/white_noise/register.py b/src/poli/objective_repository/white_noise/register.py index ee433154..30894798 100644 --- a/src/poli/objective_repository/white_noise/register.py +++ b/src/poli/objective_repository/white_noise/register.py @@ -42,10 +42,10 @@ class WhiteNoiseBlackBox(AbstractBlackBox): def __init__( self, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): """ Initializes a WhiteNoiseBlackBox. @@ -72,7 +72,7 @@ def __init__( evaluation_budget=evaluation_budget, ) - def _black_box(self, x: np.ndarray, context: dict = None) -> np.ndarray: + def _black_box(self, x: np.ndarray, context: dict | None = None) -> np.ndarray: """Returns standard Gaussian noise. Parameters @@ -106,11 +106,11 @@ def get_black_box_info(self) -> BlackBoxInformation: class WhiteNoiseProblemFactory(AbstractProblemFactory): def create( self, - seed: int = None, - batch_size: int = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, ) -> Problem: """ diff --git a/src/poli/objective_repository/zaleplon_mpo/register.py b/src/poli/objective_repository/zaleplon_mpo/register.py index 236ffeb2..35e434ba 100644 --- a/src/poli/objective_repository/zaleplon_mpo/register.py +++ b/src/poli/objective_repository/zaleplon_mpo/register.py @@ -82,12 +82,12 @@ def __init__( self, string_representation: Literal["SMILES", "SELFIES"] = "SMILES", alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, + max_sequence_length: int | Literal["inf"] = "inf", force_isolation: bool = False, - batch_size: int = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, ): super().__init__( oracle_name="Zaleplon_MPO", @@ -139,15 +139,15 @@ class ZaleplonMPOProblemFactory(AbstractProblemFactory): def create( self, - string_representation: Literal["SMILES", "SELFIES"] = "SMILES", - alphabet: list[str] | None = None, - max_sequence_length: int = np.inf, - seed: int = None, - batch_size: int = None, + seed: int | None = None, + batch_size: int | None = None, parallelize: bool = False, - num_workers: int = None, - evaluation_budget: int = None, + num_workers: int | None = None, + evaluation_budget: int | None = None, force_isolation: bool = False, + string_representation: Literal["SMILES", "SELFIES"] = "SMILES", + alphabet: list[str] | None = None, + max_sequence_length: int | Literal["inf"] = "inf", ) -> Problem: """ Creates a Zaleplon MPO problem. @@ -160,7 +160,7 @@ def create( The alphabet to be used for the SMILES or SELFIES representation. It is common that the alphabet depends on the dataset used, so it is recommended to pass it as an argument. Default is None. - max_sequence_length : int, optional + max_sequence_length : int | Literal["inf"], optional The maximum length of the sequence. Default is infinity. seed : int, optional Seed for random number generators. If None, no seed is set. diff --git a/src/poli/tests/observers/test_observers.py b/src/poli/tests/observers/test_observers.py index 12c5995f..2e5e597f 100644 --- a/src/poli/tests/observers/test_observers.py +++ b/src/poli/tests/observers/test_observers.py @@ -8,8 +8,11 @@ them in isolated processes using `set_observer`. """ +from __future__ import annotations + import json from pathlib import Path +from typing import cast import numpy as np @@ -25,10 +28,10 @@ class SimpleObserver(AbstractObserver): def initialize_observer( self, problem_setup_info: BlackBoxInformation, - caller_info: object, - seed: int, + caller_info: dict[str, object], + seed: int | None, ) -> object: - experiment_id = caller_info["experiment_id"] + experiment_id = cast(str, caller_info["experiment_id"]) self.experiment_id = experiment_id @@ -98,9 +101,9 @@ def test_simple_observer_logs_properly(): f(np.array([list("MIGUE")])) # Checking whether the results were properly logged - assert f.observer.results == [{"x": [["M", "I", "G", "U", "E"]], "y": [[0.0]]}] - (f.observer.experiment_path / "metadata.json").unlink() - f.observer.experiment_path.rmdir() + assert f.observer.results == [{"x": [["M", "I", "G", "U", "E"]], "y": [[0.0]]}] # type: ignore + (f.observer.experiment_path / "metadata.json").unlink() # type: ignore + f.observer.experiment_path.rmdir() # type: ignore def test_observer_registration_and_external_instancing(): @@ -131,15 +134,15 @@ def test_observer_registration_and_external_instancing(): # the poli__chem environment. f.observer # The same as problem.observer._observer try: - f.observer.unexisting_attribute + f.observer.unexisting_attribute # type: ignore except AttributeError: pass # Cleaning up (and testing whether we can access attributes # of the external observer) - print(f.observer.experiment_path) - (f.observer.experiment_path / "metadata.json").unlink() - f.observer.finish() + print(f.observer.experiment_path) # type: ignore + (f.observer.experiment_path / "metadata.json").unlink() # type: ignore + f.observer.finish() # type: ignore def test_multiple_observer_registration(): @@ -177,8 +180,8 @@ def test_multiple_observer_registration(): # Cleaning up (and testing whether we can access attributes # of the external observer) - (problem_1.observer._observer.experiment_path / "metadata.json").unlink() - (problem_2.observer._observer.experiment_path / "metadata.json").unlink() + (problem_1.observer._observer.experiment_path / "metadata.json").unlink() # type: ignore + (problem_2.observer._observer.experiment_path / "metadata.json").unlink() # type: ignore problem_1.observer._observer.finish() problem_2.observer._observer.finish() diff --git a/src/poli/tests/registry/proteins/test_rasp.py b/src/poli/tests/registry/proteins/test_rasp.py index ce54b14c..5dff4e91 100644 --- a/src/poli/tests/registry/proteins/test_rasp.py +++ b/src/poli/tests/registry/proteins/test_rasp.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import cast import numpy as np import pytest @@ -158,7 +159,9 @@ def test_rasp_penalization_works(): # This is an unfeasible mutation, since joining # all the strings would result in a sequence # that is _not_ the same length as the wildtype. - problematic_x = np.array([["A"] + [""] * (f.info.max_sequence_length - 1)]) + problematic_x = np.array( + [["A"] + [""] * (cast(int, f.info.max_sequence_length) - 1)] + ) assert f(problematic_x) == -100.0 @@ -175,7 +178,9 @@ def test_rasp_penalization_works_on_multiple_inputs(): # This is an unfeasible mutation, since joining # all the strings would result in a sequence # that is _not_ the same length as the wildtype. - problematic_x = np.array([["A"] + [""] * (f.info.max_sequence_length - 1)]) + problematic_x = np.array( + [["A"] + [""] * (cast(int, f.info.max_sequence_length) - 1)] + ) combination = np.vstack([problem.x0, problematic_x]) y = f(combination) assert y[-1] == -100.0 diff --git a/src/poli/tests/registry/proteins/test_rmf.py b/src/poli/tests/registry/proteins/test_rmf.py index e776071a..23203013 100644 --- a/src/poli/tests/registry/proteins/test_rmf.py +++ b/src/poli/tests/registry/proteins/test_rmf.py @@ -88,7 +88,7 @@ def test_rmf_seed_consistent(seed: int): @pytest.mark.poli__rmf @pytest.mark.parametrize("n_mutations", [1, 2, 3]) def test_rmf_num_mutations_expected_val(n_mutations: int): - from scipy.stats import genpareto + from scipy.stats import genpareto # type: ignore[reportMissingImports] SEED = 1 mutation_seq = list(ref_aa_seq) @@ -108,8 +108,8 @@ def test_rmf_num_mutations_expected_val(n_mutations: int): y1 = f(mutation_seq) rnd_state = np.random.default_rng(SEED) - ref_noise_0 = genpareto.rvs(f.kappa, size=1, random_state=rnd_state) - ref_noise_1 = genpareto.rvs(f.kappa, size=1, random_state=rnd_state) + ref_noise_0 = genpareto.rvs(f.kappa, size=1, random_state=rnd_state) # type: ignore + ref_noise_1 = genpareto.rvs(f.kappa, size=1, random_state=rnd_state) # type: ignore # black-box value minus noisy component should be approximately mutational distance if c==1 assert np.isclose(np.round(y0 - ref_noise_0), 0) diff --git a/src/poli/tests/registry/proteins/test_rosetta.py b/src/poli/tests/registry/proteins/test_rosetta.py index ffc1165c..7e434fe9 100644 --- a/src/poli/tests/registry/proteins/test_rosetta.py +++ b/src/poli/tests/registry/proteins/test_rosetta.py @@ -61,7 +61,7 @@ def test_rosetta_wt_zero_ddg(unit): f, x0 = problem.black_box, problem.x0 y0 = f(x0) if unit == "REU": - assert f.inner_function.wt_score == y0 + assert f.inner_function.wt_score == y0 # type: ignore else: assert np.isclose(y0, 0.0) @@ -93,7 +93,7 @@ def test_rosetta_on_3ned_sequence_mutations_correct(): # E10N for i, mutant in enumerate(three_mutations): - assert mutant[:20] == f.inner_function.x_t[i][:20] + assert mutant[:20] == f.inner_function.x_t[i][:20] # type: ignore @pytest.mark.poli__rosetta_energy diff --git a/src/poli/tests/registry/test_force_isolation.py b/src/poli/tests/registry/test_force_isolation.py index ad905135..86a7c9e1 100644 --- a/src/poli/tests/registry/test_force_isolation.py +++ b/src/poli/tests/registry/test_force_isolation.py @@ -5,9 +5,12 @@ import subprocess +import pytest + from poli.core.util.isolation.instancing import get_inner_function +@pytest.mark.isolation def test_force_isolation_on_tdc(): from poli import objective_factory diff --git a/src/poli/tests/registry/test_passing_array_of_strings.py b/src/poli/tests/registry/test_passing_array_of_strings.py index fecbd7b7..893a8a6c 100644 --- a/src/poli/tests/registry/test_passing_array_of_strings.py +++ b/src/poli/tests/registry/test_passing_array_of_strings.py @@ -1,12 +1,11 @@ """This module tests whether giving black boxes an array of b strings is equivalent to giving them an array of [b, L] tokens.""" -from typing import List - import pytest # TODO: parametrize by all non-aligned blackboxes +@pytest.mark.isolation @pytest.mark.parametrize( "black_box_name, example_non_flat_input, example_flat_input, kwargs", [ @@ -104,8 +103,8 @@ ) def test_passing_array_of_strings( black_box_name: str, - example_non_flat_input: List[List[str]], - example_flat_input: List[str], + example_non_flat_input: list[list[str]], + example_flat_input: list[str], kwargs: dict, ): """This test checks whether passing an array of strings [b,] diff --git a/src/poli/tests/registry/toy_continuous_problems/test_embedding_problems_into_higher_dims.py b/src/poli/tests/registry/toy_continuous_problems/test_embedding_problems_into_higher_dims.py index fa63fa5f..73d1b02c 100644 --- a/src/poli/tests/registry/toy_continuous_problems/test_embedding_problems_into_higher_dims.py +++ b/src/poli/tests/registry/toy_continuous_problems/test_embedding_problems_into_higher_dims.py @@ -8,14 +8,19 @@ of the problem is lower than the actual dimensionality. """ +from __future__ import annotations + +from typing import cast + import numpy as np +from poli.objective_repository.toy_continuous_problem.register import ( + ToyContinuousBlackBox, +) + def test_embed_camelback_into_high_dimensions(): from poli import objective_factory - from poli.objective_repository.toy_continuous_problem.register import ( - ToyContinuousProblem, - ) problem = objective_factory.create( name="toy_continuous_problem", @@ -23,7 +28,7 @@ def test_embed_camelback_into_high_dimensions(): n_dimensions=2, embed_in=10, ) - f_camelback: ToyContinuousProblem = problem.black_box + f_camelback = cast(ToyContinuousBlackBox, problem.black_box) dimensions_to_embed_in = f_camelback.function.dimensions_to_embed_in @@ -36,8 +41,8 @@ def test_embed_camelback_into_high_dimensions(): another_x[0, dimensions_to_embed_in] = [0.0, 0.0] assert np.allclose( - f_camelback(one_x), - f_camelback(another_x), + f_camelback(one_x), # type: ignore + f_camelback(another_x), # type: ignore ) # Testing whether the output is different if we are @@ -45,6 +50,10 @@ def test_embed_camelback_into_high_dimensions(): one_x[0, dimensions_to_embed_in] = [1.0, 1.0] assert not np.allclose( - f_camelback(one_x), - f_camelback(another_x), + f_camelback(one_x), # type: ignore + f_camelback(another_x), # type: ignore ) + + +if __name__ == "__main__": + test_embed_camelback_into_high_dimensions() diff --git a/src/poli/tests/test_lambda_black_box.py b/src/poli/tests/test_lambda_black_box.py index 0482864d..fd21d432 100644 --- a/src/poli/tests/test_lambda_black_box.py +++ b/src/poli/tests/test_lambda_black_box.py @@ -1,4 +1,5 @@ import numpy as np +from numpy.typing import NDArray from poli.core.black_box_information import BlackBoxInformation from poli.core.lambda_black_box import LambdaBlackBox @@ -6,7 +7,7 @@ def test_lambda_black_box_works_without_custom_info(): - def f_(x: np.ndarray) -> np.ndarray: + def f_(x: NDArray[np.str_]) -> NDArray[np.float64]: return np.zeros((len(x), 1)) f = LambdaBlackBox( @@ -14,19 +15,19 @@ def f_(x: np.ndarray) -> np.ndarray: ) assert f.info is not None - assert (f(np.ones((10, 10))) == 0.0).all + assert (f(np.array([["a"] for _ in range(10)])) == 0.0).all() assert f.num_evaluations == 10 def test_lambda_black_box_with_custom_info_works(): - def f_(x: np.ndarray) -> np.ndarray: + def f_(x: NDArray[np.str_]) -> NDArray[np.float64]: return np.zeros((len(x), 1)) f = LambdaBlackBox( function=f_, info=BlackBoxInformation( name="zero", - max_sequence_length=np.inf, + max_sequence_length="inf", aligned=False, fixed_length=False, deterministic=True, @@ -38,19 +39,19 @@ def f_(x: np.ndarray) -> np.ndarray: assert f.info is not None assert f.info.name == "zero" assert f.info.deterministic - assert (f(np.ones((10, 10))) == 0.0).all + assert (f(np.array([["a"] for _ in range(10)])) == 0.0).all() assert f.num_evaluations == 10 def test_attaching_observer_to_lambda_black_box_works(): - def f_(x: np.ndarray) -> np.ndarray: + def f_(x: NDArray[np.str_]) -> NDArray[np.float64]: return np.zeros((len(x), 1)) f = LambdaBlackBox( function=f_, info=BlackBoxInformation( name="zero", - max_sequence_length=np.inf, + max_sequence_length="inf", aligned=False, fixed_length=False, deterministic=True, @@ -66,9 +67,9 @@ def f_(x: np.ndarray) -> np.ndarray: f.set_observer(observer) - f(np.ones((10, 10))) + f(np.array([["a"] for _ in range(10)])) - assert (np.array(observer.results[0]["x"]) == 1.0).all() + assert (np.array(observer.results[0]["x"]) == "a").all() assert (np.array(observer.results[0]["y"]) == 0.0).all() diff --git a/tox.ini b/tox.ini index 08760c0c..0bc05b53 100644 --- a/tox.ini +++ b/tox.ini @@ -26,13 +26,13 @@ commands = [testenv:lint] description = check the code style with black deps = - black - isort - ruff + pre-commit + pyright + pytest + -e. commands = - black --check --diff . - isort --profile black --check-only src/ - ruff check + sh -c "export PYTHONPATH=$(pwd)/src" + pre-commit run --all-files [testenv:poli-base-py310] description = run the tests with pytest on the base environment for poli @@ -99,6 +99,8 @@ deps= -r requirements.txt -e. commands= + sh -c "conda tos accept --override-channels --channel pkgs/main" + sh -c "conda tos accept --override-channels --channel pkgs/r" sh -c 'if conda info --envs | grep -q poli__rasp; then echo "poli__rasp already exists"; else conda env create -f ./src/poli/objective_repository/rasp/environment.yml; fi' sh -c "conda run -n poli__rasp python -m pip uninstall -y poli" sh -c "conda run -n poli__rasp python -m pip install -e ." @@ -113,6 +115,8 @@ deps= -r requirements.txt -e.[rmf] commands= + sh -c "conda tos accept --override-channels --channel pkgs/main" + sh -c "conda tos accept --override-channels --channel pkgs/r" pytest {tty:--color=yes} -v -m 'not slow and poli__rmf' {posargs} [testenv:poli-ehrlich-holo-py310] @@ -124,6 +128,8 @@ deps= -r requirements.txt -e.[ehrlich] commands= + sh -c "conda tos accept --override-channels --channel pkgs/main" + sh -c "conda tos accept --override-channels --channel pkgs/r" pytest {tty:--color=yes} -v -m 'not slow and poli__ehrlich_holo' {posargs} [testenv:poli-rosetta_energy-py310] @@ -135,6 +141,8 @@ deps= -r requirements.txt -e. commands= + sh -c "conda tos accept --override-channels --channel pkgs/main" + sh -c "conda tos accept --override-channels --channel pkgs/r" sh -c 'if conda info --envs | grep -q poli__rosetta_energy; then echo "poli__rosetta_energy already exists"; else conda env create -f ./src/poli/objective_repository/rosetta_energy/environment.yml; fi' sh -c "conda run -n poli__rosetta_energy python -m pip uninstall -y poli" sh -c "conda run -n poli__rosetta_energy python -m pip install -e ."