Source code for slurmjobs.grid

'''Parameter grids!

This lets you do basic grid expansion and grid arithmetic.

.. code-block:: python

    g = Grid([
        ('a', [1, 2]),
        ('b', [1, 2]),
    ], name='train')

    # append two configurations
    g = g + LiteralGrid([{'a': 5, 'b': 5}, {'a': 10, 'b': 10}])

    # create a bigger grid from the product of another grid
    g = g * Grid([
        ('c', [5, 6])
    ], name='dataset')

    assert list(g) == [
        {'a': 1, 'b': 1, 'c': 5},
        {'a': 1, 'b': 1, 'c': 6},
        {'a': 1, 'b': 2, 'c': 5},
        {'a': 1, 'b': 2, 'c': 6},
        {'a': 2, 'b': 1, 'c': 5},
        {'a': 2, 'b': 1, 'c': 6},
        {'a': 2, 'b': 2, 'c': 5},
        {'a': 2, 'b': 2, 'c': 6},
        {'a': 5, 'b': 5, 'c': 5},
        {'a': 5, 'b': 5, 'c': 6},
        {'a': 10, 'b': 10, 'c': 5},
        {'a': 10, 'b': 10, 'c': 6},
    ]


'''
import itertools
import collections

# Parameter Grids


[docs]class BaseGrid: '''The base class for all grids. Use this if you want to extend another grid. You just need to implement: * ``__iter__``: This should yield all items generated by the grid * ``__len__``: This should tell you the number of items in the grid. If you cannot determine the length of a grid, then raise a TypeError (as the other grids do). * ``__repr__``: A nice string representation of the grid ''' grid = () def __init__(self, name=None, ignore_job_id_keys=None, **constants) -> None: self.name = name self.ignore_job_id_keys = ignore_job_id_keys or [] self.constants = constants
[docs] def __repr__(self): '''A nice string representation of the grid.''' return '{}({})'.format(self.__class__.__name__, ', '.join(map(repr(self.grid))))
def __str__(self): '''''' # NOTE: not sure if this should be different or if we should switch str/repr return repr(self)
[docs] def __len__(self): '''Get the number of iterations in the grid. Note that any use of generators or functions without a length will cause this to raise a TypeError. ''' raise NotImplemented
[docs] def __iter__(self): '''Yield all combinations from the parameter grid.''' raise NotImplemented
[docs] def __add__(self, other): '''Combine two parameter grids sequentially.''' return GridChain(self, other)
[docs] def __mul__(self, other): '''Create a grid as the combination of two grids.''' return GridCombo(self, other)
[docs] @classmethod def as_grid(cls, grid): '''Ensure that a value is a grid.''' if isinstance(grid, BaseGrid): return grid if isinstance(grid, (list, tuple)): if all(isinstance(g, dict) for g in grid): return LiteralGrid(grid) return Grid(grid)
[docs]class Grid(BaseGrid): '''A parameter grid! To get all combinations from the grid, just do ``list(Grid(...))``. Arguments: grid (list, dict): The parameter grid. Should be either a dict or a list of key values, where the values are a list of values to use in the grid. Examples of valid inputs: .. code-block:: python # simple grid Grid([ ('a', [1, 2]), ('b', [1, 2]) ]) Grid({ 'a': [1, 2], 'b': [1, 2] }) # paired parameters Grid([ ('a', [1, 2]), (('b', 'c'), ([1, 2], [1, 2])) ]) Grid([ ('a', [1, 2]), [{'b': 1, 'c': 1}, {'b': 2, 'c': 2}], ]) # any of these are valid grid specs g = slurmjobs.Grid([ # basic ('a', [1, 2]), # paired (('b', 'c'), ([1, 1, 2, 2], [1, 2, 1, 2])), # literal list of dicts [{'d': i} for i in [1, 2]], # dict generator ({'e': i} for i in [1, 2]), # function lambda: [{'f': i} for i in [1, 2]], # function that returns a generator lambda: ({'g': i} for i in [1, 2]), # basic generator ('h', (x for x in [1, 2])), # basic function ('i', lambda: [x for x in [1, 2]]), ]) keys = 'abcdefghi' assert list(g) == [ dict(zip(keys, vals)) for vals in itertools.product(*([ [1, 2] ]*len(keys))) ] name (str): The name of this grid. Can be used to search for the parameters from this grid. **constants: Extra parameters to add to the grid that don't vary. These will not be included in the job_id name. .. .. code-block:: python .. g = Grid([ .. ('a', [1, 2]), .. ('b', [1, 2]), .. ]) .. assert list(g) == [ .. {'a': 1, 'b': 1}, .. {'a': 2, 'b': 1}, .. {'a': 1, 'b': 2}, .. {'a': 2, 'b': 2}, .. ] .. You can also do pairwise parameter expansion. .. .. code-block:: python .. g = Grid([ .. ('a', [1, 2]), .. (('b', 'c'), ([1, 2], [3, 4])), .. ]) .. assert list(g) == [ .. {'a': 1, 'b': 1, 'c': 3}, .. {'a': 2, 'b': 1, 'c': 3}, .. {'a': 1, 'b': 2, 'c': 4}, .. {'a': 2, 'b': 2, 'c': 4}, .. ] Just a heads up, there is nothing stopping you from passing an infinite generator, meaning that you can make some fancy sampling grid generators, but ``slurmjobs`` will take that and not know when to stop. If you want to use an infinite generator, just wrap it in ``itertools.islice`` which will let you provide a limit. Obviously, ``slurmjobs`` doesn't operate anywhere near the memory scale where you'd need to even use generators in the first place, but I figured why limit the implementation if it can be used for other things too. ''' _is_expandable = False _easy_access_nested = False def __init__(self, __grid, name=None, **constants): self.grid = list(__grid.items()) if isinstance(__grid, dict) else __grid super().__init__(name, **constants) def __repr__(self): return '[\n{}]'.format(''.join(map(' {!r},\n'.format, self.grid)))
[docs] def __len__(self): return prod(self._as_grid_length(g) for g in self.grid)
[docs] def __getitem__(self, index): '''Get the series for a variable name.''' res = self._get_by_key(self.grid, index) if res is not None: return res raise KeyError(index)
def __setitem__(self, index, value): res = self._set_by_key(self.grid, index, value) if res: return raise KeyError(index) def _get_by_key(self, grid, name): for xs in grid: try: len(xs) k, v = xs except (TypeError, ValueError): continue if isinstance(k, dict): continue if self._easy_access_nested and not isinstance(name, (list, tuple)) and isinstance(k, (list, tuple)): res = self._get_by_key(zip(*xs), name) if res is None: continue return res if k == name: return v def _set_by_key(self, grid, name, value): for i, xs in enumerate(grid): try: len(xs) # don't unpack generators k, v = xs except (TypeError, ValueError): continue if isinstance(k, dict): continue if self._easy_access_nested and not isinstance(name, (list, tuple)) and isinstance(k, (list, tuple)): rep = list(zip(k, v)) if self._set_by_key(rep, name, value): grid[i] = list(zip(*rep)) return True continue if k == name: grid[i] = k, value return True if self.is_expandable: grid.append((name, value)) return True return False def _as_grid_length(self, xs): '''Determine the length of a grid product item.''' try: l = len(xs) if l == 2 and not isinstance(xs[0], dict): if isinstance(xs[0], (list, tuple)): return max(len(x) for x in xs[1]) return len(xs[1]) return l except (TypeError, IndexError): raise TypeError("Could not determine accurate length from grid item: {}".format(repr(xs))) def _as_grid_product_iter(self, xs): '''This returns one iteration to use in the product. ''' if callable(xs): xs = xs() # checking the first value try: first = xs[0] # list of dicts? except TypeError: # generator of dicts: ({'a': ...} for _ in ...) firsts, xs = peek(xs) if not firsts: return [] first = firsts[0] except IndexError: # empty list return [] # [{'a': 1, 'b': 1}, {'a': 2, 'b': 2}] if isinstance(first, dict): return xs # asserting it's length 2 key, values = xs # if any of the values are functions, call them values = [v() if callable(v) else v for v in values] # (('a', 'b'), ([1, 2, 3], [3, 4, 5])) if isinstance(key, (list, tuple)): return ( collections.OrderedDict([(k, v) for k, v in zip(key, vs) if v != ...]) for vs in itertools.zip_longest(*values, fillvalue=...)) # ('a', [1, 2, 3]) return ({key: v} for v in values)
[docs] def __iter__(self): grid = [self._as_grid_product_iter(xs) for xs in self.grid] for ds in itertools.product(*grid): # expand grid pairs [('a', 'b'), ([1, 2, 3], [1, 2, 3])] yield GridItem( dict({k: v for d in ds for k, v in d.items()}, **self.constants), [k for d in ds for k in d], self.name, ignore_keys=self.ignore_job_id_keys)
[docs]class LiteralGrid(BaseGrid): '''A parameter grid, specified as a flattened list. This doesn't do any grid expansion, it lets you specify the grid as you want. Arguments: grid (list, dict): The parameter list. Should be a list of dicts, each corresponding to a parameter config. name (str): The name of this grid. Can be used to search for the parameters from this grid. .. code-block:: python g = LiteralGrid([ {'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 2, 'b': 2}, ]) assert list(g) == [ {'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 2, 'b': 2}, ] ''' def __init__(self, __grid, name=None, **constants): self.grid = [__grid] if isinstance(__grid, dict) else __grid super().__init__(name, **constants) def __repr__(self): return '[\n{}]'.format(''.join(map(' {!r},\n'.format, self.grid)))
[docs] def __len__(self): return len(self.grid)
[docs] def __iter__(self): for d in self.grid: # expand grid pairs [('a', 'b'), ([1, 2, 3], [1, 2, 3])] keys = d.grid_keys if isinstance(d, _BaseGridItem) else list(d) yield GridItem( dict(d, **self.constants), keys, self.name, ignore_keys=self.ignore_job_id_keys)
# should we just use dict's internal ordering for the keys? class _BaseGridItem(dict): positional = () # def variant_items(self): # return [(k, self[k]) for k in self.grid_keys] class GridItem(_BaseGridItem): '''Represents a dictionary of arguments, the keys that vary, and a name for the group of args. ''' def __init__(self, grid=None, keys=(), name=None, positional=(), ignore_keys=None): if ignore_keys: keys = [k for k in keys if k not in ignore_keys] self.grid_keys = list(keys or ()) self.name = name self.positional = positional or () super().__init__(() if grid is None else grid) def __getitem__(self, key): return super().__getitem__(key) def find(self, name): return self if self.name == name else None def pop(self, key, *a, **kw): x = super().pop(key, *a, **kw) if key in self.grid_keys: self.grid_keys.remove(key) return x def __delitem__(self, key): super().__delitem__(key) if key in self.grid_keys: self.grid_keys.remove(key) class GridItemBundle(_BaseGridItem): '''Merges GridItems/GridBundles. Merges the dict, keys, and any groups. Can use ``.find(name)`` so search for a subset of items. ''' def __init__(self, *grids, name=None, ignore_keys=None): merged = {} groups = {} keys = [] for d in grids: merged.update(d) keys.extend(d.grid_keys or ()) for ki, di in getattr(d, 'groups', {}).items(): groups[ki] = dict(groups.get(ki, ()), **di) k = d.name if k is not None: groups[k] = dict(groups.get(k, ()), **d) if name is not None: groups[name] = dict(groups.get(name, ()), **merged) self.name = name if ignore_keys: keys = (k for k in keys if k not in ignore_keys) self.grid_keys = unique(keys) self.groups = groups super().__init__(merged) def __getitem__(self, key): if key in self.groups: return self.groups[key] return super().__getitem__(key) def __getattr__(self, key): return self.__getitem__(key) def find(self, name): return self.groups.get(name)
[docs]class GridChain(BaseGrid): '''This handles the addition of two grids (one after the other). You can create this doing ``grid_a + grid_b``. The only reason to use this directly is if you want to give it a name. .. code-block:: python a = Grid(('a', [1, 2]), ('b', [1, 2])) b = Grid(('c', [1, 2])) # functionally equivalent c = a + b c = GridChain(a, b, name='my-a-then-b-grid') c_items = list(a) + list(b) ''' def __init__(self, *grids, name=None): self.grid = grids super().__init__(name) def __repr__(self): return ' + '.join(map(repr, self.grid))
[docs] def __len__(self): return sum(len(g) for g in self.grid)
[docs] def __iter__(self): for g in self.grid: yield from g
[docs]class GridCombo(BaseGrid): '''This handles the multiplication of two grids (combinations). It will create a grid as a product of all provided grids. You can create this doing ``grid_a * grid_b``. The only reason to use this directly is if you want to give it a name or if you want to make a grid product of 3 or more grids. .. code-block:: python a = Grid(('a', [1, 2]), ('b', [1, 2])) b = Grid(('c', [1, 2])) # functionally equivalent c = a * b c = GridCombo(a, b, name='my-a-b-combo-grid') c_items = [ dict(da, **db) for da, db in itertools.product(a, b) ] ''' def __init__(self, *grids, name=None): self.grid = grids super().__init__(name) def __repr__(self): return ' * '.join(map('({!r})'.format, self.grid))
[docs] def __len__(self): return prod(len(g) for g in self.grid)
[docs] def __iter__(self): for gs in itertools.product(*self.grid): yield GridItemBundle(*gs, name=self.name)
[docs]class GridOmission(BaseGrid): '''This handles the subtraction of two grids (combinations). It will yield only dicts from ``grid_a`` that don't appear in ``grid_b``. You can create this doing ``grid_a - grid_b``. The only reason to use this directly is if you want to give it a name. .. code-block:: python a = Grid(('a', [1, 2]), ('b', [1, 2])) b = Grid(('a', [2]), ('b', [1])) # functionally equivalent c = a - b c = GridOmission(a, b, name='my-a-minus-b-grid') omit = list(b) c_items = [da for da in a if da not in omit] ''' def __init__(self, grid, omission, name=None): self.grid = grid self.omission = omission super().__init__(name) def __repr__(self): return ' - '.join(map('({!r})'.format, [self.grid, self.omission]))
[docs] def __len__(self): return sum(1 for d in self)
[docs] def __iter__(self): omit = list(self.omission) for d in self.grid: if d not in omit: yield d
def prod(ns): '''Like ``sum()`` but for products.''' total = 1 for n in ns: total *= n if total == 0: break return total def unique(xs): used = set() return [x for x in xs if not (x in used or used.add(x))] def peek(it, n=1): it = iter(it) first = [x for i, x in zip(range(n), it)] return first, (x for xs in (first, it) for x in xs)