# Source code for searchgrid

try:
    from collections.abc import Mapping as _Mapping
except ImportError:  # Python 2 fallback
    from collections import Mapping as _Mapping
import itertools as _itertools

from sklearn.model_selection import GridSearchCV as _GridSearchCV
from sklearn.pipeline import Pipeline as _Pipeline


def set_grid(estimator, **grid):
    """Annotate ``estimator`` with the parameter grid to search.

    Any grid previously set on ``estimator`` is replaced, not merged.

    Parameters
    ----------
    estimator : object
        The estimator to annotate; it is modified in place.
    grid : dict (str -> list of values)
        Keyword arguments mapping each parameter name to the candidate
        values to search for it.

    Returns
    -------
    estimator
        The very object passed in, so calls can be chained.
    """
    setattr(estimator, '_param_grid', grid)
    return estimator
def _update_grid(dest, src, prefix=None):
    """Combine two lists of parameter grids by Cartesian product.

    Parameters
    ----------
    dest : list of dict
        Grids to extend.
    src : list of dict or None
        Grids to merge into each grid of ``dest``. If None, ``dest`` is
        returned unchanged; if empty, the result is empty.
    prefix : str, optional
        If given, prepended to every key of every grid in ``src`` (e.g.
        ``'step__'``) before merging.

    Returns
    -------
    list of dict
        One new dict per ``(d1, d2)`` pair drawn from ``dest`` x ``src``:
        a copy of ``d1`` updated with ``d2``, so keys from ``src`` win on
        collision.
    """
    if src is None:
        return dest
    if prefix:
        src = [{prefix + k: v for k, v in d.items()} for d in src]
    out = []
    for d1, d2 in _itertools.product(dest, src):
        out_d = d1.copy()
        out_d.update(d2)
        out.append(out_d)
    return out


def _is_estimator_instance(obj):
    """Return True if ``obj`` has ``get_params`` and is not itself a class.

    Classes are excluded because ``SomeEstimator.get_params`` is a plain
    function there and cannot be called without an instance; scikit-learn's
    ``BaseEstimator.get_params`` applies the same check.
    """
    return hasattr(obj, 'get_params') and not isinstance(obj, type)


def _build_param_grid(estimator):
    """Recursively expand the grids annotated on ``estimator``.

    The starting grid is ``estimator._param_grid`` (as stored by
    ``set_grid``), which may be a mapping or a list of mappings; a missing
    attribute counts as an empty grid. Two kinds of nesting are expanded:

    * parameters of ``estimator`` that are estimator instances contribute
      their own grids under ``'<param>__'``-prefixed keys, unless that
      parameter is already being searched over in the grid;
    * grid candidates that are estimator instances with their own grids
      are split out, so each is searched together with its own sub-grid.

    Parameters
    ----------
    estimator : object with ``get_params``

    Returns
    -------
    list of dict or None
        The expanded grids, or None if the expansion is a single empty grid
        (nothing to search).
    """
    grid = getattr(estimator, '_param_grid', {})
    if isinstance(grid, _Mapping):
        grid = [grid]

    # Handle estimator parameters having their own grids. Keeping only the
    # names without '__' restricts this to top-level parameter names.
    for param_name, value in estimator.get_params().items():
        if '__' in param_name or not _is_estimator_instance(value):
            continue
        value_grid = _build_param_grid(value)
        out = []
        for sub_grid in grid:
            if param_name in sub_grid:
                # The parameter itself is searched over; its candidates'
                # own grids are expanded in the next stage instead.
                out.append(sub_grid)
            else:
                out.extend(_update_grid([sub_grid], value_grid,
                                        param_name + '__'))
        grid = out

    # Handle grid values having their own grids.
    out = []
    for out_d in grid:
        part = [out_d]
        for param_name, values in out_d.items():
            to_update = []
            no_sub_grid = []
            for v in values:
                if _is_estimator_instance(v):
                    sub_grid = _build_param_grid(v)
                    if sub_grid is not None:
                        # Pair this candidate with its own (prefixed) grid.
                        to_update.extend(_update_grid([{param_name: [v]}],
                                                      sub_grid,
                                                      param_name + '__'))
                        continue
                no_sub_grid.append(v)
            if no_sub_grid:
                # Candidates without sub-grids stay together in one grid.
                to_update.append({param_name: no_sub_grid})
            part = _update_grid(part, to_update)
        out.extend(part)
    if out == [{}]:
        return None
    return out
def build_param_grid(estimator):
    """Return the parameter grid annotated on ``estimator`` and its parts.

    Parameters
    ----------
    estimator : scikit-learn compatible estimator
        Expected to have been annotated with :func:`set_grid` (directly or
        via nested estimators / grid candidates).

    Returns
    -------
    dict or list of dict
        ``{}`` when no grid is found, the sole grid dict when exactly one
        grid results, otherwise the list of grids.

    Notes
    -----
    Most often, it is unnecessary for this to be used directly, and
    :func:`make_grid_search` should be used instead.
    """
    grids = _build_param_grid(estimator)
    if grids is None:
        return {}
    return grids[0] if len(grids) == 1 else grids
def _check_estimator(estimator):
    """Validate ``estimator``, wrapping a list of candidates in a Pipeline.

    Parameters
    ----------
    estimator : estimator or list of estimators
        A list is treated as alternative estimators: it is wrapped in a
        single-step Pipeline whose ``'root'`` step has a grid listing every
        candidate.

    Returns
    -------
    estimator
        ``estimator`` itself, or the wrapping Pipeline for a list.

    Raises
    ------
    ValueError
        If ``estimator`` is an empty list, or is not a list and has no
        ``fit`` attribute.
    """
    if isinstance(estimator, list):
        if not estimator:
            raise ValueError('Expected a non-empty list of estimators')
        # The pipeline is initialised with the first candidate; the grid set
        # on 'root' enumerates all of them.
        return set_grid(_Pipeline([('root', estimator[0])]), root=estimator)
    if not hasattr(estimator, 'fit'):
        # Wrap in a 1-tuple so a tuple argument is rendered by %r instead of
        # being unpacked as several format arguments (which raises TypeError).
        raise ValueError('Expected estimator, but %r does not have .fit'
                         % (estimator,))
    return estimator