autogoal.utils._helpers
¶
Note
This module contains the optimize
function that allows to apply black-box hyper-parameter
search to an arbitrary Python code.
import inspect
import textwrap
import functools
from typing import Callable
from autogoal.search import PESearch
from autogoal.grammar import generate_cfg
Black-box optimization¶
The following function defines a black-box optimization that can be applied to any function.
def optimize(
fn,
search_strategy=PESearch,
generations=100,
pop_size=10,
allow_duplicates=False,
logger=None,
**kwargs,
):
Note
A general-purpose optimization function.
Simply define any function fn
with suitable parameter annotations
and apply optimize
.
Parameters:
search_strategy
: customize the search strategy. By default aPESearch
will be performed.generations
: max number of generations to run.logger
: instance ofLogger
(or list) to pass to the search strategy.**kwargs
: additional keyword arguments passed to the search strategy constructor.
params_func = _make_params_func(fn)
@functools.wraps(fn)
def eval_func(kwargs):
return fn(**kwargs)
grammar = generate_cfg(params_func)
search = search_strategy(
grammar,
eval_func,
pop_size=pop_size,
allow_duplicates=allow_duplicates,
**kwargs,
)
best, best_fn = search.run(generations, logger=logger)
return best, best_fn
Implementation details¶
To make optimize
work we need to define both a grammar and a callable function
to pass to the search algorithm
class _ParamsDict(dict):
pass
def _make_params_func(fn: Callable):
signature = inspect.signature(fn)
func_name = f"{fn.__name__}_params"
args_names = signature.parameters.keys()
def annotation_repr(ann):
if inspect.isclass(ann) or inspect.isfunction(ann):
return ann.__name__
return repr(ann)
args_line = ",\n ".join(f"{k}={k}" for k in args_names)
params_line = ", ".join(
f"{arg.name}:{annotation_repr(arg.annotation)}"
for arg in signature.parameters.values()
)
func_code = textwrap.dedent(
f"""
def {func_name}({params_line}):
return _ParamsDict(
{args_line}
)"""
)
globals_dict = dict(fn.__globals__)
globals_dict["_ParamsDict"] = _ParamsDict
locals_dict = {}
exec(func_code, globals_dict, locals_dict)
return locals_dict[func_name]