Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
exclude: '.*\.pdb$'
Expand All @@ -20,4 +20,12 @@ repos:
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
args: [ --fix ]
- repo: local
hooks:
- id: pyright
name: pyright
entry: pyright
language: system
require_serial: true
types: [python]
13 changes: 11 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ ehrlichholo = [
"pytorch-holo",
]
tdc = [
"pytdc",
"pytdc==1.1.14",
]
dockstring = [
"dockstring"
Expand All @@ -65,7 +65,7 @@ rosetta_energy = [
"biopython",
"pyrosetta-installer",
]
dev = ["black", "tox", "pytest", "bump-my-version"]
dev = ["black", "tox", "pytest", "bump-my-version", "pre-commit", "pyright"]
docs = ["sphinx", "furo"]

[project.urls]
Expand All @@ -83,9 +83,15 @@ markers = [
"poli__rosetta_energy: marks tests that run in poli__rosetta_energy",
"poli__ehrlich_holo: marks tests that run in poli__ehrlich_holo environment",
"poli__dms: marks tests that run in poli__dms environment",
"isolation: marks tests that require isolation of the black box function",
"unmarked: All other tests, which usually run in the base environment",
]

[tool.pyright]
include = ["src/poli"]
exclude = ["src/poli/core/util/proteins/rasp/inner_rasp", "src/poli/objective_repository/gfp_cbas", "examples", "src/poli/tests"]
reportIncompatibleMethodOverride = "none"

[tool.isort]
profile = "black"

Expand Down Expand Up @@ -148,4 +154,7 @@ replace = 'version: {new_version}'
[dependency-groups]
dev = [
"pre-commit>=4.2.0",
"pyright>=1.1.403",
"pytest>=8.4.0",
"ruff>=0.12.3",
]
2 changes: 1 addition & 1 deletion src/poli/benchmarks/guacamol.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _initialize_problem(self, index: int) -> Problem:
problem_factory = self.problem_factories[index]

problem = problem_factory.create(
string_representation=self.string_representation,
string_representation=self.string_representation, # type: ignore
seed=self.seed,
batch_size=self.batch_size,
parallelize=self.parallelize,
Expand Down
2 changes: 1 addition & 1 deletion src/poli/benchmarks/pmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
batch_size: Union[int, None] = None,
parallelize: bool = False,
num_workers: Union[int, None] = None,
evaluation_budget: int = None,
evaluation_budget: int | None = None,
) -> None:
super().__init__(
string_representation=string_representation,
Expand Down
25 changes: 16 additions & 9 deletions src/poli/benchmarks/toy_continuous_functions_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
https://www.sfu.ca/~ssurjano/optimization.html.
"""

from typing import List, Union
from typing import Sequence, Union, cast

from poli.core.abstract_benchmark import AbstractBenchmark
from poli.core.problem import Problem
from poli.objective_repository import ToyContinuousProblemFactory
from poli.objective_repository.toy_continuous_problem.toy_continuous_problem import (
POSSIBLE_FUNCTIONS,
POSSIBLE_FUNCTIONS_TYPE,
SIX_DIMENSIONAL_PROBLEMS,
TWO_DIMENSIONAL_PROBLEMS,
)
Expand Down Expand Up @@ -49,12 +50,12 @@ def __init__(
self,
n_dimensions: int = 2,
embed_in: Union[int, None] = None,
dimensions_to_embed_in: Union[List[int], None] = None,
dimensions_to_embed_in: Union[list[int], None] = None,
seed: Union[int, None] = None,
batch_size: Union[int, None] = None,
parallelize: bool = False,
num_workers: Union[int, None] = None,
evaluation_budget: Union[int, List[int]] = None,
evaluation_budget: int | None = None,
) -> None:
super().__init__(
seed=seed,
Expand All @@ -66,7 +67,7 @@ def __init__(
self.n_dimensions = n_dimensions
self.embed_in = embed_in
self.dimensions_to_embed_in = dimensions_to_embed_in
self.function_names = list(
self.function_names: Sequence[POSSIBLE_FUNCTIONS_TYPE] = list( # type: ignore
(
set(POSSIBLE_FUNCTIONS)
- set(TWO_DIMENSIONAL_PROBLEMS)
Expand All @@ -78,7 +79,9 @@ def __init__(
)

def _initialize_problem(self, index: int) -> Problem:
problem_factory: ToyContinuousProblemFactory = self.problem_factories[index]
problem_factory: ToyContinuousProblemFactory = cast(
ToyContinuousProblemFactory, self.problem_factories[index]
)

problem = problem_factory.create(
function_name=self.function_names[index],
Expand Down Expand Up @@ -121,7 +124,7 @@ def __init__(
batch_size: Union[int, None] = None,
parallelize: bool = False,
num_workers: Union[int, None] = None,
evaluation_budget: Union[int, List[int]] = None,
evaluation_budget: int | None = None,
) -> None:
super().__init__(
seed,
Expand All @@ -134,7 +137,9 @@ def __init__(
self.problem_factories = [ToyContinuousProblemFactory()] * len(self.embed_in)

def _initialize_problem(self, index: int) -> Problem:
problem_factory: ToyContinuousProblemFactory = self.problem_factories[index]
problem_factory: ToyContinuousProblemFactory = cast(
ToyContinuousProblemFactory, self.problem_factories[index]
)

problem = problem_factory.create(
function_name="branin_2d",
Expand Down Expand Up @@ -174,7 +179,7 @@ def __init__(
batch_size: Union[int, None] = None,
parallelize: bool = False,
num_workers: Union[int, None] = None,
evaluation_budget: Union[int, List[int]] = None,
evaluation_budget: int | None = None,
) -> None:
super().__init__(
seed,
Expand All @@ -187,7 +192,9 @@ def __init__(
self.problem_factories = [ToyContinuousProblemFactory()] * len(self.embed_in)

def _initialize_problem(self, index: int) -> Problem:
problem_factory: ToyContinuousProblemFactory = self.problem_factories[index]
problem_factory: ToyContinuousProblemFactory = cast(
ToyContinuousProblemFactory, self.problem_factories[index]
)

if index == 0:
problem = problem_factory.create(
Expand Down
8 changes: 4 additions & 4 deletions src/poli/core/abstract_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from typing import List, Union
from typing import Union

from poli.core.abstract_problem_factory import AbstractProblemFactory
from poli.core.problem import Problem


class AbstractBenchmark:
problem_factories: List[AbstractProblemFactory]
problem_factories: list[AbstractProblemFactory]
index: int = 0

def __init__(
Expand All @@ -16,7 +16,7 @@ def __init__(
batch_size: Union[int, None] = None,
parallelize: bool = False,
num_workers: Union[int, None] = None,
evaluation_budget: int = None,
evaluation_budget: int | None = None,
) -> None:
self.seed = seed
self.batch_size = batch_size
Expand Down Expand Up @@ -46,7 +46,7 @@ def info(self) -> str:
raise NotImplementedError

@property
def problem_names(self) -> List[str]:
def problem_names(self) -> list[str]:
return [
problem_factory.__module__.replace(
"poli.objective_repository.", ""
Expand Down
32 changes: 17 additions & 15 deletions src/poli/core/abstract_black_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from __future__ import annotations

from multiprocessing import Pool, cpu_count
from typing import cast
from warnings import warn

import numpy as np
from numpy.typing import NDArray

from poli.core.black_box_information import BlackBoxInformation
from poli.core.exceptions import BudgetExhaustedException
Expand All @@ -22,17 +24,17 @@ class AbstractBlackBox:

Parameters
----------
batch_size : int, optional
batch_size : int | None, optional
The batch size for evaluating the black box function. Default is None.
parallelize : bool, optional
Flag indicating whether to evaluate the black box function in parallel.
Default is False.
num_workers : int, optional
num_workers : int | None, optional
The number of workers to use for parallel evaluation. Default is None,
which uses half of the available CPU cores.
evaluation_budget : int, optional
evaluation_budget : int | None, optional
The maximum number of evaluations allowed for the black box function.
Default is None).
Default is None, which means an infinite budget.

Attributes
----------
Expand All @@ -44,7 +46,7 @@ class AbstractBlackBox:
Flag indicating whether to evaluate the black box function in parallel.
num_workers : int
The number of workers to use for parallel evaluation.
batch_size : int or None
batch_size : int | None
The batch size for evaluating the black box function.

Methods
Expand Down Expand Up @@ -84,13 +86,13 @@ def __init__(

Parameters
----------
batch_size : int, optional
batch_size : int | None, optional
The batch size for parallel execution, by default None.
parallelize : bool, optional
Flag indicating whether to parallelize the execution, by default False.
num_workers : int, optional
num_workers : int | None, optional
The number of workers for parallel execution, by default we use half the available CPUs.
evaluation_budget : int, optional
evaluation_budget : int | None, optional
The maximum number of evaluations allowed for the black box function, by default it is None, which means no limit.
"""
self.observer = None
Expand Down Expand Up @@ -145,13 +147,13 @@ def set_observer(self, observer: AbstractObserver):
)
self.observer = observer

def set_observer_info(self, observer_info: object):
def set_observer_info(self, observer_info: dict[str, object] | None):
"""
Set the observer information after initialization.

Parameters
----------
observer_info : object
observer_info : dict[str, object]
The information given by the observer after initialization.
"""
self.observer_info = observer_info
Expand All @@ -160,7 +162,7 @@ def reset_evaluation_budget(self):
"""Resets the evaluation budget by setting the number of evaluations made to 0."""
self.num_evaluations = 0

def __call__(self, x: np.array, context=None):
def __call__(self, x: NDArray[np.str_], context=None):
"""Calls the black box function.

The purpose of this function is to enforce that inputs are equal across
Expand Down Expand Up @@ -340,7 +342,7 @@ def terminate(self) -> None:
Terminate the black box optimization problem.
"""
if hasattr(self, "inner_function"):
self.inner_function.terminate()
self.inner_function.terminate() # type: ignore
# if self.observer is not None:
# # NOTE: terminating a problem should gracefully end the observer process -> write the last state.
# self.observer.finish()
Expand Down Expand Up @@ -387,13 +389,13 @@ def __init__(self, f: AbstractBlackBox):
batch_size=f.batch_size,
parallelize=f.parallelize,
num_workers=f.num_workers,
evaluation_budget=f.evaluation_budget,
evaluation_budget=cast(int | None, f.evaluation_budget),
)

def __call__(self, x, context=None):
def __call__(self, x: NDArray[np.str_], context=None):
return -self.f.__call__(x, context)

def _black_box(self, x, context=None):
def _black_box(self, x: NDArray[np.str_], context=None):
return self.f._black_box(x, context)

def __str__(self) -> str:
Expand Down
8 changes: 4 additions & 4 deletions src/poli/core/abstract_problem_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ class AbstractProblemFactory(metaclass=MetaProblemFactory):

def create(
self,
seed: int = None,
batch_size: int = None,
seed: int | None = None,
batch_size: int | None = None,
parallelize: bool = False,
num_workers: int = None,
evaluation_budget: int = None,
num_workers: int | None = None,
evaluation_budget: int | None = None,
force_isolation: bool = False,
) -> Problem:
"""
Expand Down
4 changes: 2 additions & 2 deletions src/poli/core/benchmark_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(
fixed_length: bool,
deterministic: bool,
alphabet: list,
log_transform_recommended: bool = None,
log_transform_recommended: bool | None = None,
discrete: bool = True,
fidelity: Union[Literal["high", "low"], None] = None,
padding_token: str = "",
Expand Down Expand Up @@ -111,7 +111,7 @@ def get_alphabet(self) -> list:
"""
return self.alphabet

def log_transform_recommended(self) -> bool:
def is_log_transform_recommended(self) -> bool | None:
"""
Returns whether the black-box recommends log-transforming the targets.

Expand Down
Loading