diff --git a/.gitignore b/.gitignore index 43bd95c2b..5034957b4 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,7 @@ deploy_key temp_*.* .python-version .nox +.venv ### Visual Studio Code ### !.vscode/settings.json @@ -37,4 +38,4 @@ temp_*.* .LSOverride .vscode -.idea \ No newline at end of file +.idea diff --git a/kernel_tuner/__init__.py b/kernel_tuner/__init__.py index b64d69813..3f575faa0 100644 --- a/kernel_tuner/__init__.py +++ b/kernel_tuner/__init__.py @@ -1,5 +1,5 @@ from kernel_tuner.integration import store_results, create_device_targets -from kernel_tuner.interface import tune_kernel, run_kernel +from kernel_tuner.interface import tune_kernel, tune_cache, run_kernel from importlib.metadata import version diff --git a/kernel_tuner/core.py b/kernel_tuner/core.py index 655779337..1cd47d297 100644 --- a/kernel_tuner/core.py +++ b/kernel_tuner/core.py @@ -480,11 +480,14 @@ def benchmark(self, func, gpu_args, instance, verbose, objective, skip_nvml_sett print( f"skipping config {util.get_instance_string(instance.params)} reason: too many resources requested for launch" ) - result[objective] = util.RuntimeFailedConfig() + result['__error__'] = util.RuntimeFailedConfig() else: logging.debug("benchmark encountered runtime failure: " + str(e)) print("Error while benchmarking:", instance.name) raise e + + assert util.check_result_type(result), "The error in a result MUST be an actual error." + return result def check_kernel_output( @@ -571,7 +574,7 @@ def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options, instance = self.create_kernel_instance(kernel_source, kernel_options, params, verbose) if isinstance(instance, util.ErrorConfig): - result[to.objective] = util.InvalidConfig() + result['__error__'] = util.InvalidConfig() else: # Preprocess the argument list. This is required to deal with `MixedPrecisionArray`s gpu_args = _preprocess_gpu_arguments(gpu_args, params) @@ -581,7 +584,7 @@ def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options, start_compilation = time.perf_counter() func = self.compile_kernel(instance, verbose) if not func: - result[to.objective] = util.CompilationFailedConfig() + result['__error__'] = util.CompilationFailedConfig() else: # add shared memory arguments to compiled module if kernel_options.smem_args is not None: @@ -635,6 +638,8 @@ def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options, result["verification_time"] = last_verification_time or 0 result["benchmark_time"] = last_benchmark_time or 0 + assert util.check_result_type(result), "The error in a result MUST be an actual error." 
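Reviewer note: the change above stores failures under a reserved result['__error__'] key instead of overwriting the objective value, and the new util.check_result_type enforces that convention. A minimal sketch of what that means for result dicts, assuming this patch is applied (the values are made up):

from kernel_tuner import util

ok = {"time": 0.42, "compile_time": 12.3}                                  # successful configuration
failed = {"__error__": util.RuntimeFailedConfig(), "compile_time": 12.3}   # failed configuration

assert util.check_result_type(ok)        # no '__error__' key, always well-formed
assert util.check_result_type(failed)    # '__error__' holds an actual ErrorConfig

broken = {"__error__": "oops"}           # a plain string is not an ErrorConfig
assert not util.check_result_type(broken)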
+ return result def compile_kernel(self, instance, verbose): diff --git a/kernel_tuner/file_utils.py b/kernel_tuner/file_utils.py index e5d3dcb90..9d7b7042c 100644 --- a/kernel_tuner/file_utils.py +++ b/kernel_tuner/file_utils.py @@ -32,20 +32,20 @@ def output_file_schema(target): return current_version, json_string -def get_configuration_validity(objective) -> str: +def get_configuration_validity(error) -> str: """Convert internal Kernel Tuner error to string.""" errorstring: str - if not isinstance(objective, util.ErrorConfig): + if not isinstance(error, util.ErrorConfig): errorstring = "correct" else: - if isinstance(objective, util.CompilationFailedConfig): + if isinstance(error, util.CompilationFailedConfig): errorstring = "compile" - elif isinstance(objective, util.RuntimeFailedConfig): + elif isinstance(error, util.RuntimeFailedConfig): errorstring = "runtime" - elif isinstance(objective, util.InvalidConfig): + elif isinstance(error, util.InvalidConfig): errorstring = "constraints" else: - raise ValueError(f"Unkown objective type {type(objective)}, value {objective}") + raise ValueError(f"Unkown error type {type(error)}, value {error}") return errorstring @@ -110,7 +110,8 @@ def store_output_file(output_filename: str, results, tune_params, objective="tim out["times"] = timings # encode the validity of the configuration - out["invalidity"] = get_configuration_validity(result[objective]) + # out["invalidity"] = get_configuration_validity(result[objective]) + out["invalidity"] = get_configuration_validity(result['__error__']) # Kernel Tuner does not support producing results of configs that fail the correctness check # therefore correctness is always 1 @@ -127,7 +128,10 @@ def store_output_file(output_filename: str, results, tune_params, objective="tim # In Kernel Tuner we currently support only one objective at a time, this can be a user-defined # metric that combines scores from multiple different quantities into a single value to support # multi-objective tuning however. - out["objectives"] = [objective] + # NOTE(maric): With PyMOO integrated we do support multi-objective tuning without scalarization + objectives = [objective] if isinstance(objective, str) else list(objective) + assert isinstance(objectives, list) + out["objectives"] = objectives # append to output output_data.append(out) diff --git a/kernel_tuner/interface.py b/kernel_tuner/interface.py index 97ae22848..c13a3d0a1 100644 --- a/kernel_tuner/interface.py +++ b/kernel_tuner/interface.py @@ -57,6 +57,7 @@ pso, random_sample, simulated_annealing, + pymoo_minimize, ) strategy_map = { @@ -75,6 +76,8 @@ "simulated_annealing": simulated_annealing, "firefly_algorithm": firefly_algorithm, "bayes_opt": bayes_opt, + "nsga2": pymoo_minimize, + "nsga3": pymoo_minimize, } @@ -425,7 +428,7 @@ def __deepcopy__(self, _): """Optimization objective to sort results on, consisting of a string that also occurs in results as a metric or observed quantity, default 'time'. 
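Reviewer note: a short sketch of how the renamed get_configuration_validity() maps the '__error__' entry of a result onto the strings written to the output file, assuming this patch is applied:

from kernel_tuner import util
from kernel_tuner.file_utils import get_configuration_validity

print(get_configuration_validity(1.23))                             # "correct" (not an ErrorConfig)
print(get_configuration_validity(util.CompilationFailedConfig()))   # "compile"
print(get_configuration_validity(util.RuntimeFailedConfig()))       # "runtime"
print(get_configuration_validity(util.InvalidConfig()))             # "constraints"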
Please see :ref:`objectives`.""", - "string", + "str | list[str]", ), ), ( @@ -433,7 +436,7 @@ def __deepcopy__(self, _): ( """boolean that specifies whether the objective should be maximized (True) or minimized (False), default False.""", - "bool", + "bool | list[bool]", ), ), ( @@ -464,6 +467,7 @@ def __deepcopy__(self, _): ("metrics", ("specifies user-defined metrics, please see :ref:`metrics`.", "dict")), ("simulation_mode", ("Simulate an auto-tuning search from an existing cachefile", "bool")), ("observers", ("""A list of Observers to use during tuning, please see :ref:`observers`.""", "list")), + ("seed", ("""The random seed.""", "int")), ] ) @@ -577,6 +581,8 @@ def tune_kernel( observers=None, objective=None, objective_higher_is_better=None, + objectives=None, + seed=None, ): start_overhead_time = perf_counter() if log: @@ -586,8 +592,20 @@ def tune_kernel( _check_user_input(kernel_name, kernelsource, arguments, block_size_names) - # default objective if none is specified - objective, objective_higher_is_better = get_objective_defaults(objective, objective_higher_is_better) + if objectives: + if isinstance(objectives, dict): + objective = list(objectives.keys()) + objective_higher_is_better = list(objectives.values()) + else: + raise ValueError("objectives should be a dict of (objective, higher_is_better) pairs") + else: + objective, objective_higher_is_better = get_objective_defaults(objective, objective_higher_is_better) + objective = [objective] + objective_higher_is_better = [objective_higher_is_better] + + assert isinstance(objective, list) + assert isinstance(objective_higher_is_better, list) + assert len(list(objective)) == len(list(objective_higher_is_better)) # check for forbidden names in tune parameters util.check_tune_params_list(tune_params, observers, simulation_mode=simulation_mode) @@ -682,13 +700,33 @@ def tune_kernel( # finished iterating over search space if results: # checks if results is not empty - best_config = util.get_best_config(results, objective, objective_higher_is_better) - # add the best configuration to env - env['best_config'] = best_config - if not device_options.quiet: - units = getattr(runner, "units", None) - print("best performing configuration:") - util.print_config_output(tune_params, best_config, device_options.quiet, metrics, units) + if len(list(objective)) == 1: + objective = objective[0] + objective_higher_is_better = objective_higher_is_better[0] + best_config = util.get_best_config(results, objective, objective_higher_is_better) + # add the best configuration to env + env['best_config'] = best_config + if not device_options.quiet: + units = getattr(runner, "units", None) + print(f"\nBEST PERFORMING CONFIGURATION FOR OBJECTIVE {objective}:") + keys = list(tune_params.keys()) + keys += [objective] + if metrics: + keys += list(metrics.keys()) + print(util.get_config_string(best_config, keys, units)) + else: + pareto_front = util.get_pareto_results(results, objective, objective_higher_is_better) + # add the best configuration to env + env['best_config'] = pareto_front + if not device_options.quiet: + units = getattr(runner, "units", None) + keys = list(tune_params.keys()) + keys += list(objective) + if metrics: + keys += list(metrics.keys) + print(f"\nBEST PERFORMING CONFIGURATIONS FOR OBJECTIVES: {objective}:") + for best_config in pareto_front: + print(util.get_config_string(best_config, keys, units)) elif not device_options.quiet: print("no results to report") @@ -703,6 +741,28 @@ def tune_kernel( tune_kernel.__doc__ = 
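Reviewer note: a usage sketch of the new multi-objective entry point. `objectives` maps each objective name to its higher_is_better flag, and the new "nsga2"/"nsga3" strategy names resolve to the pymoo strategy added in this patch. The kernel, the GFLOPS metric, and the option values below are illustrative only, and running it requires a CUDA device:

import numpy as np
from kernel_tuner import tune_kernel

kernel_string = """
__global__ void vector_add(float *c, float *a, float *b, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        c[i] = a[i] + b[i];
    }
}
"""

size = 10_000_000
a = np.random.randn(size).astype(np.float32)
b = np.random.randn(size).astype(np.float32)
c = np.zeros_like(a)
args = [c, a, b, np.int32(size)]
tune_params = {"block_size_x": [32, 64, 128, 256, 512, 1024]}

results, env = tune_kernel(
    "vector_add", kernel_string, size, args, tune_params,
    objectives={"time": False, "GFLOPS": True},                     # minimize time, maximize GFLOPS
    metrics={"GFLOPS": lambda p: (size / 1e9) / (p["time"] / 1e3)}, # illustrative metric
    strategy="nsga2",
    strategy_options={"pop_size": 20, "max_fevals": 200},
    seed=42,
)
# with more than one objective, env["best_config"] holds the Pareto front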
_tune_kernel_docstring + +def tune_cache(*, + cache_path, + restrictions = None, + **kwargs, +): + cache = util.read_cache(cache_path, open_cache=False) + tune_args = util.infer_args_from_cache(cache) + _restrictions = [util.infer_restrictions_from_cache(cache)] + + # Add the user provided restrictions + if restrictions: + if isinstance(restrictions, list): + _restrictions.extend(restrictions) + else: + raise ValueError("The restrictions must be a list()") + + tune_args.update(kwargs) + + return tune_kernel(**tune_args, cache=cache_path, restrictions=_restrictions, simulation_mode=True) + + _run_kernel_docstring = """Compile and run a single kernel Compiles and runs a single kernel once, given a specific instance of the kernels tuning parameters. diff --git a/kernel_tuner/runners/sequential.py b/kernel_tuner/runners/sequential.py index aeebd5116..dae34a3c8 100644 --- a/kernel_tuner/runners/sequential.py +++ b/kernel_tuner/runners/sequential.py @@ -5,6 +5,7 @@ from kernel_tuner.core import DeviceInterface from kernel_tuner.runners.runner import Runner +import kernel_tuner.util as util from kernel_tuner.util import ErrorConfig, print_config_output, process_metrics, store_cache @@ -44,8 +45,15 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, ob #move data to the GPU self.gpu_args = self.dev.ready_argument_list(kernel_options.arguments) + # It is the task of the cost function to increment there counters + self.config_eval_count = 0 + self.infeasable_config_eval_count = 0 + def get_environment(self, tuning_options): - return self.dev.get_environment() + env = self.dev.get_environment() + env["config_eval_count"] = self.config_eval_count + env["infeasable_config_eval_count"] = self.infeasable_config_eval_count + return env def run(self, parameter_space, tuning_options): """Iterate through the entire parameter space using a single Python process. 
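Reviewer note: a sketch of how the new tune_cache() helper could be used. It infers tune_params, problem_size and a validity restriction from an existing cache file and re-tunes it in simulation mode. The file name is hypothetical and both objectives are assumed to be stored per configuration in that cache:

from kernel_tuner import tune_cache

results, env = tune_cache(
    cache_path="convolution_A100.json",            # hypothetical existing cache file
    objectives={"time": False, "GFLOP/s": True},   # assumed present in every cache entry
    strategy="nsga2",
    strategy_options={"pop_size": 20, "max_fevals": 500},
)

# the runners now also report how many (in)feasible configurations were evaluated
print(env["config_eval_count"], env["infeasable_config_eval_count"])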
@@ -90,17 +98,19 @@ def run(self, parameter_space, tuning_options): result = self.dev.compile_and_benchmark(self.kernel_source, self.gpu_args, params, self.kernel_options, tuning_options) + assert util.check_result_type(result) + params.update(result) - if tuning_options.objective in result and isinstance(result[tuning_options.objective], ErrorConfig): + if '__error__' in result: logging.debug('kernel configuration was skipped silently due to compile or runtime failure') # only compute metrics on configs that have not errored - if tuning_options.metrics and not isinstance(params.get(tuning_options.objective), ErrorConfig): + if tuning_options.metrics and '__error__' not in params: params = process_metrics(params, tuning_options.metrics) # get the framework time by estimating based on other times - total_time = 1000 * ((perf_counter() - self.start_time) - warmup_time) + total_time = 1000 * ((perf_counter() - self.start_time) - warmup_time) params['strategy_time'] = self.last_strategy_time params['framework_time'] = max(total_time - (params['compile_time'] + params['verification_time'] + params['benchmark_time'] + params['strategy_time']), 0) params['timestamp'] = str(datetime.now(timezone.utc)) @@ -113,6 +123,8 @@ def run(self, parameter_space, tuning_options): # add configuration to cache store_cache(x_int, params, tuning_options) + assert util.check_result_type(params) + # all visited configurations are added to results to provide a trace for optimization strategies results.append(params) diff --git a/kernel_tuner/runners/simulation.py b/kernel_tuner/runners/simulation.py index 22c7c667c..cd181288a 100644 --- a/kernel_tuner/runners/simulation.py +++ b/kernel_tuner/runners/simulation.py @@ -47,7 +47,8 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, ob :type iterations: int """ self.quiet = device_options.quiet - self.dev = SimulationDevice(1024, dict(device_name="Simulation"), self.quiet) + # NOTE(maric): had to increase max_threas so the default restraints would pass + self.dev = SimulationDevice(1_000_000_000, dict(device_name="Simulation"), self.quiet) self.kernel_source = kernel_source self.simulation_mode = True @@ -58,10 +59,16 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, ob self.last_strategy_time = 0 self.units = {} + # It is the task of the cost function to increment there counters + self.config_eval_count = 0 + self.infeasable_config_eval_count = 0 + def get_environment(self, tuning_options): env = self.dev.get_environment() env["simulation"] = True env["simulated_time"] = tuning_options.simulated_time + env["config_eval_count"] = self.config_eval_count + env["infeasable_config_eval_count"] = self.infeasable_config_eval_count return env def run(self, parameter_space, tuning_options): diff --git a/kernel_tuner/strategies/common.py b/kernel_tuner/strategies/common.py index d01eae937..1901476f2 100644 --- a/kernel_tuner/strategies/common.py +++ b/kernel_tuner/strategies/common.py @@ -72,6 +72,8 @@ def __call__(self, x, check_restrictions=True): # check if max_fevals is reached or time limit is exceeded util.check_stop_criterion(self.tuning_options) + self.runner.config_eval_count += 1 + # snap values in x to nearest actual value for each parameter, unscale x if needed if self.snap: if self.scaling: @@ -92,9 +94,12 @@ def __call__(self, x, check_restrictions=True): legal = util.check_restrictions(self.searchspace.restrictions, params_dict, self.tuning_options.verbose) if not legal: result = params_dict - 
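Reviewer note: the runner now keys its error handling on the reserved '__error__' entry. A minimal stand-in for the metrics gating in the hunk above (process_metrics itself is unchanged; the metric is made up):

from kernel_tuner import util

def maybe_process_metrics(params, metrics):
    # user metrics are only computed for configurations that did not error
    if metrics and '__error__' not in params:
        for name, func in metrics.items():
            params[name] = func(params)
    return params

metrics = {"GB/s": lambda p: (2 * 4 * 10_000_000 / 1e9) / (p["time"] / 1e3)}  # illustrative
ok = maybe_process_metrics({"time": 0.5}, metrics)
failed = maybe_process_metrics({"__error__": util.RuntimeFailedConfig()}, metrics)
assert "GB/s" in ok and "GB/s" not in failed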
result[self.tuning_options.objective] = util.InvalidConfig() + result['__error__'] = util.InvalidConfig() + self.runner.infeasable_config_eval_count += 1 if legal: + assert ('__error__' not in result), "A legal config MUST NOT have an error result." + # compile and benchmark this instance res = self.runner.run([params], self.tuning_options) result = res[0] @@ -108,11 +113,17 @@ def __call__(self, x, check_restrictions=True): # upon returning from this function control will be given back to the strategy, so reset the start time self.runner.last_strategy_start_time = perf_counter() - # get numerical return value, taking optimization direction into account - return_value = result[self.tuning_options.objective] or sys.float_info.max - return_value = return_value if not self.tuning_options.objective_higher_is_better else -return_value + # get the cost of the result + cost_vec = util.get_result_cost( + result, + self.tuning_options.objective, + self.tuning_options.objective_higher_is_better + ) - return return_value + if len(cost_vec) == 1: + return cost_vec[0] + else: + return cost_vec def get_bounds_x0_eps(self): """Compute bounds, x0 (the initial guess), and eps.""" diff --git a/kernel_tuner/strategies/pymoo_minimize.py b/kernel_tuner/strategies/pymoo_minimize.py new file mode 100644 index 000000000..0ac530941 --- /dev/null +++ b/kernel_tuner/strategies/pymoo_minimize.py @@ -0,0 +1,274 @@ +"""The Pymoo strategy that uses a minimizer method for searching through the parameter space.""" + +from typing import assert_never +import numpy as np + +from pymoo.algorithms.moo.nsga2 import NSGA2 +from pymoo.algorithms.moo.nsga3 import NSGA3 +from pymoo.core.algorithm import Algorithm +from pymoo.core.problem import ElementwiseProblem +from pymoo.core.duplicate import ElementwiseDuplicateElimination +from pymoo.core.termination import NoTermination, Termination +from pymoo.core.sampling import Sampling +from pymoo.core.mutation import Mutation +from pymoo.core.repair import Repair +from pymoo.operators.crossover.pntx import TwoPointCrossover + +from kernel_tuner import util +from kernel_tuner.runners.runner import Runner +from kernel_tuner.searchspace import Searchspace +from kernel_tuner.strategies.common import ( + CostFunc, + get_strategy_docstring, +) + +from enum import StrEnum + +class SupportedAlgos(StrEnum): + NSGA2 = "nsga2" + NSGA3 = "nsga3" + +supported_algos = [ algo.value for algo in SupportedAlgos ] + +supported_crossover_opers = [ + # "uniform-crossover", + # "single-point-crossover", + "two-point-crossover", +] + +_options = { + "pop_size": ("Initial population size", 20), + "crossover_operator": ("The crossover operator", "two-point-crossover"), + "crossover_prob": ("Crossover probability", 1.0), + "mutation_prob": ("Mutation probability", 0.1), + "ref_dirs_list": ("The list of reference directions on the unit hyperplane in the objective space to guide NSGA-III, see https://pymoo.org/misc/reference_directions.html for more information.", []), +} + +_option_defaults = { key: option_pair[1] for key, option_pair in _options.items() } + + +def tune( + searchspace: Searchspace, + runner: Runner, + tuning_options, +): + algo_name: str = tuning_options.strategy + strategy_options = tuning_options.strategy_options + + algo_name = algo_name.lower() + if algo_name not in SupportedAlgos: + raise ValueError(f"\"{algo_name}\" is not supported. 
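Reviewer note: the scalar-or-vector return of CostFunc above comes from the new util.get_result_cost(). A small sketch of its contract, assuming this patch is applied (the GFLOPS name and values are made up):

import sys
from kernel_tuner import util

objectives = ["time", "GFLOPS"]
higher_is_better = [False, True]

ok = {"time": 0.42, "GFLOPS": 812.0}
failed = {"__error__": util.InvalidConfig()}

# maximized objectives are negated so everything can be minimized
assert util.get_result_cost(ok, objectives, higher_is_better) == [0.42, -812.0]
# errored results get the worst possible cost in every dimension
assert util.get_result_cost(failed, objectives, higher_is_better) == [sys.float_info.max] * 2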
The supported algorithms are: {supported_algos}\n") + else: + algo_name = SupportedAlgos(algo_name) + + pop_size = strategy_options.get("pop_size", _option_defaults["pop_size"]) + crossover_prob = strategy_options.get("crossover_prob", _option_defaults["crossover_prob"]) + mutation_prob = strategy_options.get("mutation_prob", _option_defaults["mutation_prob"]) + ref_dirs_list = strategy_options.get("ref_dirs_list", _option_defaults["ref_dirs_list"]) + + if algo_name == "nsga3" and len(ref_dirs_list) == 0: + raise ValueError("NSGA-III requires reference directions to be specified, but they are missing.") + + cost_func = CostFunc(searchspace, tuning_options, runner, scaling=False) + + problem = TuningProblem( + cost_func = cost_func, + n_var = len(tuning_options.tune_params), + n_obj = len(tuning_options.objective), + ) + + sampling = TuningSearchspaceRandomSampling(searchspace) + crossover = TwoPointCrossover(prob = crossover_prob) + mutation = TuningParamConfigNeighborhoodMutation(prob = mutation_prob, searchspace = searchspace) + repair = TuningParamConfigRepair() + eliminate_duplicates = TuningParamConfigDuplicateElimination() + + # algorithm_type = get_algorithm(method) + algo: Algorithm + match algo_name: + case SupportedAlgos.NSGA2: + algo = NSGA2( + pop_size = pop_size, + sampling = sampling, + crossover = crossover, + mutation = mutation, + repair = repair, + eliminate_duplicates = eliminate_duplicates, + ) + case SupportedAlgos.NSGA3: + algo = NSGA3( + pop_size = pop_size, + ref_dirs = ref_dirs_list, + sampling = sampling, + crossover = crossover, + mutation = mutation, + repair = repair, + eliminate_duplicates = eliminate_duplicates, + ) + case _ as unreachable: + assert_never(unreachable) + + # TODO: + # - CostFunc throws exception when done, so isn't really needed + termination = None + if "max_fevals" in tuning_options.strategy_options or "time_limit" in tuning_options.strategy_options: + termination = NoTermination() + + try: + algo.setup( + problem, + termination = termination, + verbose = tuning_options.verbose, + progress = tuning_options.verbose, + seed = tuning_options.seed, + ) + + while algo.has_next(): + algo.next() + + except util.StopCriterionReached as e: + if tuning_options.verbose: + print(f"Stopped because of {e}") + + results = cost_func.results + + if results and tuning_options.verbose: + print(f"{results.message=}") + + return results + + +tune.__doc__ = get_strategy_docstring("Pymoo minimize", _options) + + +class TuningProblem(ElementwiseProblem): + def __init__( + self, + cost_func: CostFunc, + n_var: int, + n_obj: int, + **kwargs, + ): + super().__init__( + n_var = n_var, + n_obj = n_obj, + **kwargs, + ) + self.cost_func = cost_func + self.searchspace = cost_func.searchspace + self.tuning_options = cost_func.tuning_options + + def _evaluate( self, x, out, *args, **kwargs, ): + # A copy of `x` is made to make sure sharing does not happen + F = self.cost_func(tuple(x)) + out["F"] = F + + def _calc_pareto_front( self, *args, **kwargs, ): + # Can only compute the pareto front if we are in simulation mode. 
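Reviewer note: a standalone toy illustration (not Kernel Tuner code) of the manually stepped pymoo pattern used here: the algorithm runs with NoTermination and is driven via setup()/has_next()/next(), so an exception raised inside the evaluation (a stand-in for util.StopCriterionReached) can end the run at any point:

from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.core.problem import ElementwiseProblem
from pymoo.core.termination import NoTermination


class BudgetExceeded(Exception):
    pass


class ToyProblem(ElementwiseProblem):
    def __init__(self, max_evals=200):
        super().__init__(n_var=2, n_obj=2, xl=0.0, xu=1.0)
        self.evals = 0
        self.max_evals = max_evals

    def _evaluate(self, x, out, *args, **kwargs):
        self.evals += 1
        if self.evals > self.max_evals:          # plays the role of check_stop_criterion()
            raise BudgetExceeded("max_fevals reached")
        out["F"] = [x[0], (1.0 - x[0]) + x[1]]   # two competing objectives


algorithm = NSGA2(pop_size=20)
algorithm.setup(ToyProblem(), termination=NoTermination(), seed=1, verbose=False)
try:
    while algorithm.has_next():
        algorithm.next()
except BudgetExceeded as e:
    print("stopped:", e)

print(algorithm.opt.get("F"))                    # non-dominated set found before the budget ran out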
+ if not self.tuning_options.simulation_mode: + return None + + objectives = self.tuning_options.objective + higher_is_better = self.tuning_options.objective_higher_is_better + pareto_results = util.get_pareto_results( + list(self.tuning_options.cache.values()), + objectives, + higher_is_better, + ) + + pareto_front_list = list() + for res in pareto_results: + cost = util.get_result_cost(res, objectives, higher_is_better) + pareto_front_list.append(cost) + + return np.array(pareto_front_list, dtype=float) + + +class TuningTermination(Termination): + def __init__( self, tuning_options, ): + super().__init__() + self.tuning_options = tuning_options + self.reason = None + + def _update( + self, + algorithm, + ): + try: + util.check_stop_criterion(self.tuning_options) + print(f"progress: {len(self.tuning_options.unique_results) / self.tuning_options.max_fevals}") + return 0.0 + except util.StopCriterionReached as e: + self.terminate() + self.reason = e + return 1.0 + + +class TuningSearchspaceRandomSampling(Sampling): + def __init__( self, searchspace, ): + super().__init__() + self.searchspace = searchspace + + def _do( self, problem, n_samples: int, **kwargs, ): + sample = self.searchspace.get_random_sample(n_samples) + return np.array(sample, dtype=object) + + +class TuningParamConfigNeighborhoodMutation(Mutation): + def __init__( + self, + prob, + searchspace: Searchspace, + **kwargs + ): + super().__init__( + prob = prob, + # prob_var = None, + **kwargs, + ) + self.searchspace = searchspace + + def _do( + self, + problem: TuningProblem, + X: np.ndarray, + **kwargs, + ): + for X_index in range(X.shape[0]): + params_config_tuple = tuple(X[X_index]) + neighbors_indices = self.searchspace.get_neighbors_indices_no_cache(params_config_tuple, neighbor_method="Hamming") + if len(neighbors_indices) > 0: + neighbor_index = neighbors_indices[np.random.choice(len(neighbors_indices))] + neighbor = self.searchspace.get_param_configs_at_indices([neighbor_index])[0] + X[X_index] = np.array(neighbor, dtype=object) + + return X + + +class TuningParamConfigRepair(Repair): + + def _do( + self, + problem: TuningProblem, + X: np.ndarray, + **kwargs, + ): + for X_index in range(X.shape[0]): + params_config_tuple = tuple(X[X_index]) + if problem.searchspace.is_param_config_valid(params_config_tuple): + continue + for neighbor_method in ["strictly-adjacent", "adjacent", "Hamming"]: + neighbors_indices = problem.searchspace.get_neighbors_indices_no_cache(params_config_tuple, neighbor_method) + if len(neighbors_indices) > 0: + neighbor_index = neighbors_indices[np.random.choice(len(neighbors_indices))] + neighbor = problem.searchspace.get_param_configs_at_indices([neighbor_index])[0] + X[X_index] = np.array(neighbor, dtype=object) + break + + return X + + +class TuningParamConfigDuplicateElimination(ElementwiseDuplicateElimination): + + def is_equal(self, a, b): + return np.all(a.X == b.X) diff --git a/kernel_tuner/util.py b/kernel_tuner/util.py index 710b59e0d..fc2c941ed 100644 --- a/kernel_tuner/util.py +++ b/kernel_tuner/util.py @@ -43,6 +43,8 @@ from kernel_tuner.observers.nvml import NVMLObserver +from pymoo.util.nds.find_non_dominated import find_non_dominated + # number of special values to insert when a configuration cannot be measured @@ -79,6 +81,32 @@ def default(self, obj): return super(NpEncoder, self).default(obj) +def get_result_cost( + result: dict, + objectives: list[str], + objective_higher_is_better: list[bool] +) -> list[float]: + """Returns the cost of a result, taking the objective directions 
into account.""" + # return the highest cost for invalid results + if '__error__' in result: + return [sys.float_info.max] * len(objectives) + + cost_vec = list() + for objective, is_maximizer in zip(objectives, objective_higher_is_better): + objective_value = result[objective] + cost = -objective_value if is_maximizer else objective_value + cost_vec.append(cost) + + return cost_vec + + +def check_result_type(r): + """Check if the result has the right format.""" + if '__error__' in r: + return isinstance(r['__error__'], ErrorConfig) + return True + + class TorchPlaceHolder: def __init__(self): self.Tensor = Exception # using Exception here as a type that will never be among kernel arguments @@ -191,10 +219,20 @@ def check_argument_list(kernel_name, kernel_string, args): def check_stop_criterion(to): """Checks if max_fevals is reached or time limit is exceeded.""" - if "max_fevals" in to and len(to.unique_results) >= to.max_fevals: - raise StopCriterionReached("max_fevals reached") - if "time_limit" in to and (((time.perf_counter() - to.start_time) + (to.simulated_time * 1e-3)) > to.time_limit): - raise StopCriterionReached("time limit exceeded") + if "max_fevals" in to: + if to.verbose: + print(f"Progress: {len(to.unique_results)/to.max_fevals}") + if len(to.unique_results) >= to.max_fevals: + raise StopCriterionReached("max_fevals reached") + if "time_limit" in to: + # if to.verbose: + # print(f"Progress: {((time.perf_counter() - to.start_time) + (to.simulated_time * 1e-3)) / to.time_limit}") + # if (((time.perf_counter() - to.start_time) + (to.simulated_time * 1e-3)) > to.time_limit): + # raise StopCriterionReached("time limit exceeded") + if to.verbose: + print(f"Progress: {((time.perf_counter() - to.start_time)) / to.time_limit}") + if (((time.perf_counter() - to.start_time)) > to.time_limit): + raise StopCriterionReached("time limit exceeded") def check_tune_params_list(tune_params, observers, simulation_mode=False): @@ -244,8 +282,11 @@ def check_block_size_params_names_list(block_size_names, tune_params): def check_restriction(restrict, params: dict) -> bool: """Check whether a configuration meets a search space restriction.""" + # if it's a function python-constraint it can be called directly + if isinstance(restrict, FunctionConstraint): + return restrict._func(*params.values()) # if it's a python-constraint, convert to function and execute - if isinstance(restrict, Constraint): + elif isinstance(restrict, Constraint): restrict = convert_constraint_restriction(restrict) return restrict(list(params.values())) # if it's a string, fill in the parameters and evaluate @@ -393,11 +434,40 @@ def get_best_config(results, objective, objective_higher_is_better=False): ignore_val = sys.float_info.max if not objective_higher_is_better else -sys.float_info.max best_config = func( results, - key=lambda x: x[objective] if isinstance(x[objective], float) else ignore_val, + key=lambda x: x[objective] if '__error__' not in x and isinstance(x[objective], float) else ignore_val, ) return best_config +def get_pareto_results( + results: list[dict], + objectives: list[str], + objective_higher_is_better: list[bool], + mark_optima=True +): + assert isinstance(results, list) + assert isinstance(objectives, list) + + n_rows = len(results) + n_cols = len(objectives) + Y = np.empty((n_rows, n_cols), dtype=float) + for row_idx, result in enumerate(results): + if "__error__" in result: + Y[row_idx, :] = sys.float_info.max + continue + for col_idx, (objective_name, higher_is_better) in enumerate(zip(objectives, 
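Reviewer note: FunctionConstraint is itself a Constraint, so the new direct-call branch in check_restriction() has to be tested before the generic conversion. A small sketch with a made-up restriction:

from constraint import FunctionConstraint
from kernel_tuner import util

def fits_in_shared_memory(block_size_x, tile_size):   # illustrative restriction
    return block_size_x * tile_size <= 1024

restrict = FunctionConstraint(fits_in_shared_memory)

# the new branch calls restrict._func(*params.values()) directly
assert util.check_restriction(restrict, {"block_size_x": 256, "tile_size": 4}) is True
assert util.check_restriction(restrict, {"block_size_x": 512, "tile_size": 4}) is False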
objective_higher_is_better)): + y = result[objective_name] + # negate for maximizers to optimize through minimization + Y[row_idx, col_idx] = -y if higher_is_better else y + + pf_indices = find_non_dominated(Y) + pf = [results[idx] for idx in pf_indices] + if mark_optima: + for p in pf: + p["optimal"] = True + return pf + + def get_config_string(params, keys=None, units=None): """Return a compact string representation of a measurement.""" @@ -1158,7 +1228,8 @@ def process_cache(cache, kernel_options, tuning_options, runner): # if file exists else: - cached_data = read_cache(cache) + # cached_data = read_cache(cache) + cached_data = read_cache(cache, open_cache=(not runner.simulation_mode)) # if in simulation mode, use the device name from the cache file as the runner device name if runner.simulation_mode: @@ -1288,3 +1359,31 @@ def cuda_error_check(error): if error != nvrtc.nvrtcResult.NVRTC_SUCCESS: _, desc = nvrtc.nvrtcGetErrorString(error) raise RuntimeError(f"NVRTC error: {desc.decode()}") + + +def infer_restrictions_from_cache(cache: dict): + param_names = cache["tune_params_keys"] + valid_param_config_set = set( + tuple(result[param_name] for param_name in param_names) + for result in cache['cache'].values() + if '__error__' not in result + ) + + def restrictions_func(*param_values) -> bool: + nonlocal valid_param_config_set + return param_values in valid_param_config_set + + return FunctionConstraint(restrictions_func) + + +def infer_args_from_cache(cache: dict) -> dict: + inferred_args = dict( + kernel_name = cache['kernel_name'], + kernel_source = "", + problem_size = tuple(cache['problem_size']), + arguments = [], + tune_params = cache['tune_params'], + # restrictions = infer_restrictions_from_cache(cache), + ) + + return inferred_args diff --git a/noxfile.py b/noxfile.py index e32bbb588..fe26ef1d4 100644 --- a/noxfile.py +++ b/noxfile.py @@ -15,7 +15,7 @@ # set the test parameters verbose = False -python_versions_to_test = ["3.9", "3.10", "3.11", "3.12"] +python_versions_to_test = ["3.10", "3.11", "3.12"] nox.options.stop_on_first_error = True nox.options.error_on_missing_interpreters = True nox.options.default_venv_backend = 'virtualenv' @@ -38,7 +38,7 @@ def create_settings(session: Session) -> None: venvbackend = nox.options.default_venv_backend envdir = "" # conversion from old notenv.txt - if noxenv_file_path.exists(): + if noxenv_file_path.exists(): venvbackend = noxenv_file_path.read_text().strip() noxenv_file_path.unlink() # write the settings @@ -91,7 +91,7 @@ def check_development_environment(session: Session) -> None: # packages = re.findall(r"• Installing .* | • Updating .*", output, flags=re.MULTILINE) # assert packages is not None session.warn(f""" - Your development environment is out of date ({installs} installs, {updates} updates). + Your development environment is out of date ({installs} installs, {updates} updates). Update with 'poetry install --sync', using '--with' and '-E' for optional dependencies, extras respectively. Note: {removals} packages are not in the specification (i.e. installed manually) and may be removed. 
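Reviewer note: a small usage sketch of get_pareto_results() with made-up results. Maximized objectives are negated, errored entries get the worst value in every dimension, and pymoo's non-dominated filter selects the front, which is also tagged with "optimal":

from kernel_tuner import util

results = [
    {"time": 1.0, "GFLOPS": 900.0},        # fast, high throughput
    {"time": 2.0, "GFLOPS": 950.0},        # slower, but higher throughput
    {"time": 3.0, "GFLOPS": 800.0},        # dominated by the first entry
    {"__error__": util.InvalidConfig()},   # errored, never Pareto-optimal
]

front = util.get_pareto_results(results, ["time", "GFLOPS"], [False, True])
assert len(front) == 2                     # the first two entries are mutually non-dominated
assert all(r.get("optimal") for r in front)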
To preview changes, run 'poetry install --sync --dry-run' (with optional dependencies and extras).""") diff --git a/pyproject.toml b/pyproject.toml index 48034bf15..2791b929c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ generate-setup-file = false # ATTENTION: if anything is changed here, run `poetry update` [tool.poetry.dependencies] -python = ">=3.9,<3.13" # NOTE when changing the supported Python versions, also change the test versions in the noxfile +python = ">=3.10,<3.13" # NOTE when changing the supported Python versions, also change the test versions in the noxfile numpy = "^1.26.0" # Python 3.12 requires numpy at least 1.26 scipy = ">=1.11.0" packaging = "*" # required by file_utils @@ -84,6 +84,7 @@ hip-python = { version = "*", optional = true } # Tutorial (for the notebooks used in the examples) jupyter = { version = "^1.0.0", optional = true } matplotlib = { version = "^3.5.0", optional = true } +pymoo = "^0.6.1.6" [tool.poetry.extras] cuda = ["pycuda", "nvidia-ml-py", "pynvml"] diff --git a/test/strategies/test_common.py b/test/strategies/test_common.py index 29ead8615..945290494 100644 --- a/test/strategies/test_common.py +++ b/test/strategies/test_common.py @@ -19,6 +19,8 @@ def fake_runner(): runner = Mock() runner.last_strategy_start_time = perf_counter() runner.run.return_value = [fake_result] + runner.config_eval_count = 0 + runner.infeasable_config_eval_count = 0 return runner @@ -29,7 +31,7 @@ def test_cost_func(): x = [1, 4] tuning_options = Options(scaling=False, snap=False, tune_params=tune_params, restrictions=None, strategy_options={}, cache={}, unique_results={}, - objective="time", objective_higher_is_better=False, metrics=None) + objective=["time"], objective_higher_is_better=[False], metrics=None) runner = fake_runner() time = CostFunc(Searchspace(tune_params, None, 1024), tuning_options, runner)(x) @@ -41,7 +43,7 @@ def restrictions(_): tuning_options = Options(scaling=False, snap=False, tune_params=tune_params, restrictions=restrictions, strategy_options={}, verbose=True, cache={}, unique_results={}, - objective="time", objective_higher_is_better=False, metrics=None) + objective=["time"], objective_higher_is_better=[False], metrics=None) time = CostFunc(Searchspace(tune_params, restrictions, 1024), tuning_options, runner)(x) assert time == sys.float_info.max
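Reviewer note: a possible companion test (not part of this patch) for the multi-objective path of CostFunc, reusing the fixtures and imports already in this file; "energy" is a hypothetical second quantity added to the fake result for illustration:

def test_cost_func_multi_objective():
    x = [1, 4]
    tune_params = {"x": [1, 2, 3], "y": [4, 5, 6]}
    runner = fake_runner()
    runner.run.return_value = [dict(fake_result, energy=7)]   # hypothetical second objective
    tuning_options = Options(scaling=False, snap=False, tune_params=tune_params, restrictions=None,
                             strategy_options={}, cache={}, unique_results={},
                             objective=["time", "energy"], objective_higher_is_better=[False, False],
                             metrics=None)
    cost = CostFunc(Searchspace(tune_params, None, 1024), tuning_options, runner)(x)
    # with more than one objective the cost function returns a cost vector
    assert cost == [fake_result["time"], 7]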