nems.modelspec module

Defines modelspec object and helper functions.

class nems.modelspec.ModelSpec(raw=None, phis=None, fit_index=0, cell_index=0, jack_index=0, recording=None, cell_count=1, fit_count=1, jack_count=1)

Bases: object

Defines a model based on a NEMS modelspec.

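A ModelSpec holds the module dictionaries for one or more cells, fits, and jackknifes. The current cell_index, fit_index, and jack_index select which set of modules is active.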

__init__(raw=None, phis=None, fit_index=0, cell_index=0, jack_index=0, recording=None, cell_count=1, fit_count=1, jack_count=1)

Initialize the modelspec.

A modelspec can have multiple fits, each of which contains a different set of phi values.

Parameters
  • raw (list) – Nested list of module dictionaries, equivalent to the old NEMS modelspec. The first level is a list of cells, the second level is a list of fits for each cell, and the third level is a list of jackknifes for each fit. Each jackknife entry is a list of dictionaries, each of which specifies a module, or one step in the model.

  • phis (list) – The free parameters.

  • fit_index (int) – Index of which fit to reference when multiple are present. Defaults to 0.

  • cell_index (int) – Index of which cell to reference when multiple are present. Defaults to 0.

  • jack_index (int) – Index of which jackknife to reference when multiple are present. Defaults to 0.

  • cell_count (int) – Number of cells. Defaults to 1.

  • fit_count (int) – Number of fits. Defaults to 1.

  • jack_count (int) – Number of jackknifes. Defaults to 1.

  • recording – Recording used for evaluation and plotting. Defaults to None.
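The nested layout of raw can be pictured with plain Python lists. This is a minimal sketch under the assumption that raw[cell][fit][jack] holds a list of module dicts; make_raw, current_modules, and the placeholder fn strings are hypothetical and not part of the NEMS API.

```python
import copy
from typing import Any, Dict, List

# Hypothetical picture of the nested layout: raw[cell][fit][jack] -> list of module dicts.
Module = Dict[str, Any]
Raw = List[List[List[List[Module]]]]


def make_raw(modules: List[Module], cell_count: int = 1,
             fit_count: int = 1, jack_count: int = 1) -> Raw:
    """Give every cell/fit/jackknife slot its own deep copy of the module list."""
    return [[[copy.deepcopy(modules) for _ in range(jack_count)]
             for _ in range(fit_count)]
            for _ in range(cell_count)]


def current_modules(raw: Raw, cell_index: int = 0, fit_index: int = 0,
                    jack_index: int = 0) -> List[Module]:
    """Return the module list selected by the three indexes."""
    return raw[cell_index][fit_index][jack_index]


# Illustrative module dicts; the fn strings are placeholders, not real NEMS modules.
modules = [
    {"fn": "example.module_a", "phi": {"gain": 1.0}},
    {"fn": "example.module_b", "phi": {"coefficients": [0.5, 0.25]}},
]
raw = make_raw(modules, fit_count=2, jack_count=3)
selected = current_modules(raw, fit_index=1, jack_index=2)
selected[0]["phi"]["gain"] = 2.0  # only this fit/jackknife slot changes
```

Giving each slot its own deep copy is what lets different fits and jackknifes hold different phi values without interfering with each other.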

__getitem__(key)

Get the given item from the modelspec.

Overloaded to allow access to other elements. Key can be an int, a slice object, 'meta', or 'phi'.

Parameters

key – Index or object to retrieve from the modelspec.

Returns

A single module (int key), a list of modules (slice key), the meta dict ('meta'), or the phi ('phi').

Raises

ValueError – Raised if key is out of bounds or is not one of the above.
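The key dispatch described for __getitem__ can be sketched with a small stand-in class. ModuleView is hypothetical, and storing meta under a "meta" key in the first module is an assumption based on the metadata convention documented for get_modelspec_metadata.

```python
from typing import Any, Dict, List, Union


class ModuleView:
    """Hypothetical stand-in that dispatches keys the way __getitem__ is described."""

    def __init__(self, modules: List[Dict[str, Any]]):
        self.modules = modules

    def __getitem__(self, key: Union[int, slice, str]):
        if isinstance(key, str):
            if key == "meta":
                # Assumption: meta is stored under "meta" in the first module.
                return self.modules[0].get("meta", {})
            if key == "phi":
                # One entry per module; None for modules without phi.
                return [m.get("phi") for m in self.modules]
            raise ValueError(f"unsupported key: {key!r}")
        if isinstance(key, slice):
            return self.modules[key]
        if isinstance(key, int):
            try:
                return self.modules[key]
            except IndexError:
                raise ValueError(f"module index {key} out of bounds") from None
        raise ValueError(f"unsupported key type: {type(key).__name__}")


view = ModuleView([
    {"fn": "example.module_a", "meta": {"name": "demo"}},
    {"fn": "example.module_b", "phi": {"gain": 2.0}},
])
```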

__setitem__(key, val)

Update the raw dict of the modelspec.

Updates the current modelspec raw dict at the current cell_index, fit_index, and jack_index.

Parameters
  • key (int) – Index of the module in the modelspec to update.

  • val – The value to update to.

Returns

Self, updated.

Raises

ValueError – Raised if the value cannot be set.

__iter__()

Reset mod_index to zero at the start of iteration.

Returns

Self, with mod_index reset to zero.

__next__()

Return the module at the current mod_index, then advance mod_index.

Returns

The module at the current mod_index.
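A sketch of the iterator pattern above, using a hypothetical ModuleIterator class with its own mod_index cursor; it is not the ModelSpec implementation.

```python
from typing import Any, Dict, List


class ModuleIterator:
    """Hypothetical sketch of the __iter__/__next__ pattern with a mod_index cursor."""

    def __init__(self, modules: List[Dict[str, Any]]):
        self.modules = modules
        self.mod_index = 0

    def __iter__(self):
        self.mod_index = 0  # every new loop starts at the first module
        return self

    def __next__(self) -> Dict[str, Any]:
        if self.mod_index >= len(self.modules):
            raise StopIteration
        module = self.modules[self.mod_index]
        self.mod_index += 1
        return module


ms = ModuleIterator([{"fn": "example.a"}, {"fn": "example.b"}])
first_pass = [m["fn"] for m in ms]
second_pass = [m["fn"] for m in ms]  # __iter__ resets mod_index, so this repeats
```

In this sketch, because __iter__ returns self, two nested loops over the same object share one mod_index; iterate over a separate list of the modules if nested loops are needed.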

__repr__()

Overloaded repr.

Returns

Repr of the modelspec raw dict.

__str__()

Overloaded str.

Return str

Newline-separated string of the module functions.

__len__()

Overloaded len.

Return int

Number of modules at the current cell_index, fit_index, and jack_index.

copy(fit_index=None, jack_index=None)

Generate a deep copy of the modelspec.

Parameters
  • fit_index (int) –

  • jack_index (int) –

Returns

A deep copy of the modelspec (subset of modules if specified).

get_module(mod_index=None)

Get the requested module.

Returns the raw dict at the mod_index, and current cell_index, fit_index, and jack_index.

Parameters

mod_index (int) – Index of the module to return. If None, defaults to the current mod_index.

Returns

Single module from current fit_index. Does not create a copy.

drop_module(mod_index=None, in_place=False)

Drop a module from the modelspec.

Return a new modelspec with the module dropped, or optionally drop the module in place.

Parameters
  • mod_index (int) – Index of the module to drop.

  • in_place (bool) – If True, drop the module in place; otherwise, return a copy with the module dropped.

Returns

None if in place, otherwise a new modelspec without the dropped module.
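A minimal sketch of the in-place versus copy contract, operating on a plain list of module dicts. drop_module here is a hypothetical stand-alone function; it requires mod_index because the handling of mod_index=None is not described above.

```python
import copy
from typing import Any, Dict, List, Optional


def drop_module(modules: List[Dict[str, Any]], mod_index: int,
                in_place: bool = False) -> Optional[List[Dict[str, Any]]]:
    """Hypothetical sketch of the in-place vs. copy contract described above."""
    if in_place:
        del modules[mod_index]  # mutate the caller's list
        return None
    new_modules = copy.deepcopy(modules)  # leave the original untouched
    del new_modules[mod_index]
    return new_modules


modules = [{"fn": "example.a"}, {"fn": "example.b"}, {"fn": "example.c"}]
trimmed = drop_module(modules, 1)
```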

property modules

All of the modules for the current cell_index, fit_index, and jack_index.

tile_fits(fit_count=1)

Create fit_count sets of fit parameters to allow for multiple fits.

Useful for n-fold cross validation or starting from multiple initial conditions. Values of each phi are copied from the existing first value. Applied in-place.

Parameters

fit_count (int) – Number of tiles to create.

Returns

Self.
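A sketch of the tiling step, assuming the raw[cell][fit][jack] layout described for __init__. tile_fits here is a hypothetical stand-alone function that replaces each cell's fits with fit_count deep copies of its first fit, in place; tile_jacks would apply the same idea one level deeper.

```python
import copy
from typing import Any, Dict, List

Raw = List[List[List[List[Dict[str, Any]]]]]  # raw[cell][fit][jack] -> modules (assumed)


def tile_fits(raw: Raw, fit_count: int = 1) -> Raw:
    """Replace each cell's fits with fit_count deep copies of its first fit, in place."""
    for cell in raw:
        first_fit = cell[0]
        cell[:] = [copy.deepcopy(first_fit) for _ in range(fit_count)]
    return raw


raw = [[[[{"fn": "example.a", "phi": {"gain": 1.0}}]]]]  # 1 cell, 1 fit, 1 jackknife
tile_fits(raw, fit_count=3)
raw[0][2][0][0]["phi"]["gain"] = 5.0  # tiles are independent copies
```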

tile_jacks(jack_count=0)

Create jack_count sets of fit parameters to allow for multiple jackknifes.

Useful for n-fold cross validation or starting from multiple initial conditions. Values of each phi are copied from the existing first value. Applied in-place.

Parameters

jack_count (int) – Number of tiles to create.

Returns

Self.

property cell_count

Number of cells (sets of phi values) in this modelspec.

property fit_count

Number of fits (sets of phi values) in this modelspec.

property jack_count

Number of jackknifes (sets of phi values) in this modelspec.

set_cell(cell_index=None)

Set the cell_index. Done in place.

Parameters

cell_index (int) – The updated cell_index.

Returns

Self.

set_fit(fit_index)

Set the fit_index. Done in place.

Parameters

fit_index (int) – The updated fit_index.

Returns

Self.

set_jack(jack_index=None)

Set the jack_index. Done in place.

Parameters

jack_index (int) – The updated jack_index.

Returns

Self.

fits()

List of modelspecs, one for each fit, for compatibility with some old functions.

property meta

Dict of meta information.

property recording

Recording for the current cell_index.

property modelspecname

Name of the modelspec.

fn()

List of fn for each module.

property phi

The free parameters for the model.

Parameters

fit_index (int) – Which model fit to use (defaults to the current fit_index).

Returns

List of phi dictionaries, or None for modules with no phi.

property phi_mean

Mean of phi across fit_indexes and/or jack_indexes.

Parameters

mod_idx (int) – Which module to use (default all modules).

Returns

List of phi dictionaries, mean of each value.

property phi_sem

SEM of phi across fit_indexes and/or jack_indexes.

Parameters

mod_idx (int) – Which module to use (default all modules).

Returns

List of phi dictionaries, jackknife SEM of each value.
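The mean and SEM across fits or jackknifes can be sketched with NumPy. This assumes the standard jackknife standard error, sqrt((n - 1) / n * sum((x_i - mean)^2)), which may differ in detail from how NEMS computes it; phi_mean_and_sem is hypothetical and handles the phi of a single module.

```python
from typing import Dict, List, Tuple

import numpy as np


def phi_mean_and_sem(phi_sets: List[Dict[str, object]]
                     ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
    """Element-wise mean and jackknife SEM of one module's phi across n fits/jackknifes."""
    n = len(phi_sets)
    mean, sem = {}, {}
    for key in phi_sets[0]:
        values = np.stack([np.asarray(p[key], dtype=float) for p in phi_sets])
        mean[key] = values.mean(axis=0)
        # Jackknife standard error: sqrt((n - 1) / n * sum_i (x_i - mean)^2)
        sem[key] = np.sqrt((n - 1) / n * ((values - mean[key]) ** 2).sum(axis=0))
    return mean, sem


jacks = [{"gain": 1.0, "coefs": [0.0, 1.0]},
         {"gain": 2.0, "coefs": [0.0, 2.0]},
         {"gain": 3.0, "coefs": [0.0, 3.0]}]
mean, sem = phi_mean_and_sem(jacks)
```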

property phi_vector

Vector of phi across fit_indexes.

Parameters

fit_index (int) – Which model fit to use (defaults to the current fit_index).

Returns

Vector of phi values from all modules.
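A sketch of flattening per-module phi dicts into one vector. phi_to_vector is hypothetical, and the ordering (module order, then dict key order) is an assumption, not necessarily the order NEMS uses.

```python
from typing import Dict, List, Optional

import numpy as np


def phi_to_vector(phis: List[Optional[Dict[str, object]]]) -> np.ndarray:
    """Flatten every phi value, in module order and then key order, into a 1-D vector."""
    parts = [np.ravel(np.asarray(value, dtype=float))
             for phi in phis if phi        # skip modules with no phi (None or {})
             for value in phi.values()]
    return np.concatenate(parts) if parts else np.zeros(0)


vec = phi_to_vector([{"coefs": [[1, 2], [3, 4]], "gain": 5}, None, {"offset": [6.0]}])
```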

get_plot_fn(mod_index=None, plot_fn_idx=None, fit_index=None)

Get the plotting function for the specified module.

Parameters
  • mod_index (int) – Which module in the modelspec to get the plotting function for.

  • plot_fn_idx (int) – Which plotting function in the list to get.

  • fit_index (int) – Update the fit index if not None.

Returns

A plotting function.

plot(mod_index=0, plot_fn_idx=None, fit_index=None, rec=None, sig_name='pred', channels=None, ax=None, **kwargs)

Generate the plot for a single module.

Parameters
  • mod_index – Which module in the modelspec to generate the plot for.

  • plot_fn_idx – Which function in the list of plot functions.

  • fit_index – Update the fit index.

  • rec – The recording from which to pull the data.

  • sig_name – Which signal in the recording.

  • channels – Which channel in the signal.

  • ax – Axis on which to plot.

  • kwargs – Optional keyword args.

quickplot(rec=None, epoch=None, occurrence=None, fit_index=None, include_input=True, include_output=True, size_mult=(1.0, 2.0), figsize=None, fig=None, time_range=None, sig_names=None, modidx_set=None)

Generate a summary plot of a subset of the data.

Parameters
  • rec – The recording from which to pull the data.

  • epoch – Name of epoch from which to extract data.

  • occurrence (int) – Which occurrence of the epoch to plot.

  • fit_index (int) – Update the fit index.

  • include_input (bool) – Whether to include default plot of the inputs.

  • include_output (bool) – Whether to include default plot of the outputs.

  • size_mult (tuple) – Scale factors for width and height of figure.

  • figsize (tuple) – Size of figure (tuple of inches).

  • time_range (tuple) – If not None, plot signals from time_range[0] to time_range[1] seconds.

  • sig_names – List of signal name strings (default ['stim']).

  • modidx_set – List of module indexes to plot (default all).

Returns

Matplotlib figure.

append(module)

Append a module to the modelspec.

Parameters

module – A module dict.

pop_module()

Remove the last module from the modelspec.

get_priors(data)

TODO docs.

Parameters

data – TODO docs

Returns

TODO docs

evaluate(rec=None, **kwargs)

Evaluate the model on a recording. Essentially a wrapper for nems.modelspec.evaluate.

Parameters
  • rec – Recording object (defaults to self.recording, which is usually preset to the validation data).

  • start – Passed through kwargs. Start evaluation at module start, assuming rec['pred'] is in the appropriate state to feed into modelspec[start].

  • stop – Passed through kwargs. Stop at this module.

Returns

A copy of the input recording with 'pred' updated to the prediction.

fast_eval_on(rec=None, subset=None)

Quickly evaluates a model on a recording.

Enter fast-eval mode, in which the model is evaluated up through the first module that has a fittable phi. The model is evaluated on rec up through the preceding module, and the result is saved in freeze_rec.

Parameters
  • rec – Recording object to evaluate.

  • subset – Which subset of the data to evaluate.
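The fast-eval idea, caching the output of the fixed modules that precede the first fittable one, can be sketched on a toy numeric model. FastEvaluator, first_fittable_index, toy_apply, and the "fn_kwargs" key are hypothetical, and treating a module with a non-empty phi as "fittable" is an assumption.

```python
from typing import Any, Callable, Dict, List, Optional

Module = Dict[str, Any]


def first_fittable_index(modules: List[Module]) -> int:
    """Index of the first module with a non-empty phi (assumed to mark it fittable)."""
    for i, module in enumerate(modules):
        if module.get("phi"):
            return i
    return len(modules)


class FastEvaluator:
    """Hypothetical sketch: cache the output of the fixed modules before the first fittable one."""

    def __init__(self, modules: List[Module], apply_fn: Callable[[Module, Any], Any]):
        self.modules = modules
        self.apply_fn = apply_fn
        self.freeze_rec: Optional[Any] = None
        self.freeze_stop = 0

    def fast_eval_on(self, rec: Any) -> None:
        self.freeze_stop = first_fittable_index(self.modules)
        out = rec
        for module in self.modules[:self.freeze_stop]:
            out = self.apply_fn(module, out)
        self.freeze_rec = out

    def evaluate(self, rec: Any) -> Any:
        # While frozen, rec is assumed to be the same input passed to fast_eval_on.
        if self.freeze_rec is not None:
            out, start = self.freeze_rec, self.freeze_stop
        else:
            out, start = rec, 0
        for module in self.modules[start:]:
            out = self.apply_fn(module, out)
        return out

    def fast_eval_off(self) -> None:
        self.freeze_rec = None  # release the cached intermediate result
        self.freeze_stop = 0


calls = {"fixed": 0}


def toy_apply(module: Module, x: float) -> float:
    """Toy modules: a fixed offset followed by a fittable gain."""
    if module.get("phi"):
        return x * module["phi"]["gain"]
    calls["fixed"] += 1
    return x + module["fn_kwargs"]["offset"]


model = [{"fn_kwargs": {"offset": 1.0}}, {"phi": {"gain": 2.0}}]
ev = FastEvaluator(model, toy_apply)
ev.fast_eval_on(3.0)
```

Only the fittable tail is recomputed after phi changes, which is where the savings come from when a fitter evaluates the model many times.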

fast_eval_off()

Turn off fast_eval and purge freeze_rec to free up memory.

generate_tensor(data, phi)

Evaluate the module given the input data and phi.

Parameters
  • data (dict) – Dictionary of arrays and/or tensors.

  • phi (list(dict)) – list of dictionaries. Each entry in the list maps to the corresponding module in the model. If a module does not require any input parameters, use a blank dictionary. All elements in phi must be scalars, arrays or tensors.

Returns

dictionary of Signals

get_shortname()

Get a string that is just the module IDs in this modelspec.

Return str

Shortname, the module IDs.

get_longname()

Return a long name for this modelspec suitable for use in saving to disk without a path.

Return str

Longname, more details about the modelspec.

modelspec2tf(tps_per_stim=550, feat_dims=1, data_dims=1, state_dims=0, fs=100, net_seed=1, weight_scale=0.1, use_modelspec_init=True, distr='norm')

Convert a modelspec object to TensorFlow layers.

Maps modelspec modules to TensorFlow layers. Adapted from code by Sam Norman-Haignere: https://github.com/snormanhaignere/cnn/blob/master/cnn.py

Parameters
  • tps_per_stim

  • feat_dims (int) –

  • data_dims (int) –

  • state_dims (int) –

  • fs

  • net_seed

  • weight_scale

  • use_modelspec_init (bool) –

modelspec2tf2(seed=0, use_modelspec_init=True, fs=100, initializer='random_normal', freeze_layers=None, kernel_regularizer=None)

New version of modelspec2tf.

get_dstrf(rec: nems.recording.Recording, index: int, width: int = 30, rebuild_model: bool = False, out_channel: int = 0, method: str = 'jacobian') → numpy.array

Create a TensorFlow model from the modelspec and compute the dstrf.

Parameters
  • rec – The input recording, of shape [channels, time].

  • index – The index at which the dstrf is calculated. Must be within the data.

  • width – The width of the returned dstrf (i.e. time lag from the index). If 0, returns the whole dstrf. Zero padded if out of bounds.

  • rebuild_model – Rebuild the model to avoid using the cached one.

Returns

NumPy array of shape [channels, width].
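The dstrf at a time index can be read as the derivative of the prediction at that index with respect to the input at each channel and time lag. NEMS computes it through a TensorFlow model (method='jacobian'); the sketch below swaps in finite differences on a hypothetical toy linear FIR model, for which the dstrf should recover the filter itself. fir_predict and dstrf_finite_diff are hypothetical, and the sketch does not implement width=0 or out_channel.

```python
import numpy as np


def fir_predict(stim: np.ndarray, h: np.ndarray) -> np.ndarray:
    """Toy linear FIR model: pred[t] = sum_c sum_k h[c, k] * stim[c, t - k]."""
    channels, n_times = stim.shape
    pred = np.zeros(n_times)
    for c in range(channels):
        for k in range(h.shape[1]):
            if k >= n_times:
                break
            pred[k:] += h[c, k] * stim[c, :n_times - k]
    return pred


def dstrf_finite_diff(predict, stim: np.ndarray, index: int, width: int = 30,
                      eps: float = 1e-6) -> np.ndarray:
    """Approximate d pred[index] / d stim[c, index - lag] for lag = 0 .. width - 1."""
    channels, _ = stim.shape
    dstrf = np.zeros((channels, width))  # lags before t = 0 stay zero (zero padding)
    base = predict(stim)[index]
    for c in range(channels):
        for lag in range(width):
            t = index - lag
            if t < 0:
                continue
            bumped = stim.copy()
            bumped[c, t] += eps
            dstrf[c, lag] = (predict(bumped)[index] - base) / eps
    return dstrf


rng = np.random.default_rng(0)
stim = rng.standard_normal((2, 50))  # [channels, time]
h = rng.standard_normal((2, 5))      # toy filter, 5 taps per channel


def model(s: np.ndarray) -> np.ndarray:
    return fir_predict(s, h)


d = dstrf_finite_diff(model, stim, index=20, width=8)
d_edge = dstrf_finite_diff(model, stim, index=2, width=5)
```

For a nonlinear model the dstrf changes with index, which is why it is computed at a specific time point rather than once for the whole recording.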

nems.modelspec.get_modelspec_metadata(modelspec)

Return a dict of the metadata for this modelspec.

Purely by convention, metadata info for the entire modelspec is stored in the first module.

Parameters

modelspec – Modelspec object from which to get metadata.

Return dict

Modelspec meta dict.

nems.modelspec.set_modelspec_metadata(modelspec, key, value)

Set a key/value pair in the modelspec’s metadata.

Purely by convention, metadata info for the entire modelspec is stored in the first module.

Parameters
  • modelspec – Modelspec object whose metadata will be updated.

  • key – Update key.

  • value – Update value.

Returns

The modelspec with updated meta.
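A sketch of the first-module metadata convention on a plain list of module dicts. get_meta and set_meta are hypothetical, and the "meta" key name is an assumption.

```python
from typing import Any, Dict, List

Module = Dict[str, Any]


def get_meta(modules: List[Module]) -> Dict[str, Any]:
    """Read the metadata dict stored (by convention) in the first module."""
    return modules[0].get("meta", {})


def set_meta(modules: List[Module], key: str, value: Any) -> List[Module]:
    """Set one metadata key on the first module, creating the dict if needed."""
    modules[0].setdefault("meta", {})[key] = value
    return modules


ms = [{"fn": "example.a"}, {"fn": "example.b"}]
set_meta(ms, "r_test", 0.42)
```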

nems.modelspec.get_modelspec_shortname(modelspec)

Return a string that is just the module ids in this modelspec.

Parameters

modelspec – Modelspec object from which to get the shortname.

Return str

The modelspec shortname.

nems.modelspec.get_modelspec_longname(modelspec)

Return a LONG name for this modelspec suitable for use in saving to disk without a path.

Parameters

modelspec – Modelspec object from which to get the longname.

Return str

The modelspec longname.

nems.modelspec._modelspec_filename(basepath, number)

Append a number to the end of a filepath.

Parameters
  • basepath – Path to add number to.

  • number – Number to add.

Returns

String of basepath with suffix added.

nems.modelspec.save_modelspec(modelspec, filepath)

Save a modelspec to filepath. Overwrites any existing file.

Parameters
  • modelspec – Modelspec object to save.

  • filepath – Save location.

nems.modelspec.save_modelspecs(directory, modelspecs, basename=None)

Save one or more modelspecs to disk with stereotyped filenames.

Ex:

    directory/basename.0000.json
    directory/basename.0001.json
    directory/basename.0002.json
    ...

Basename will be automatically generated if not provided.

Parameters
  • directory – Save location.

  • modelspecs (list) – List of modelspecs to save.

  • basename – Base name for the saved files. If not provided, the modelspec long name is used.

Returns

The filepath of the last saved modelspec.
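A sketch of the stereotyped naming scheme shown in the example above, using only the standard library. numbered_filename and save_numbered are hypothetical; they assume each modelspec is already JSON-serializable, which a real ModelSpec may not be without conversion.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, List, Optional


def numbered_filename(basepath: str, number: int) -> str:
    """Append a zero-padded number and .json, e.g. 'dir/name' -> 'dir/name.0003.json'."""
    return f"{basepath}.{number:04d}.json"


def save_numbered(directory: str, modelspecs: List[Any], basename: str) -> Optional[str]:
    """Write each JSON-serializable modelspec to its own numbered file; return the last path."""
    out_dir = Path(directory)
    out_dir.mkdir(parents=True, exist_ok=True)
    filepath = None
    for i, ms in enumerate(modelspecs):
        filepath = numbered_filename(str(out_dir / basename), i)
        with open(filepath, "w") as f:  # overwrites any existing file
            json.dump(ms, f, indent=2)
    return filepath


tmp = tempfile.mkdtemp()
last = save_numbered(tmp, [[{"fn": "example.a"}], [{"fn": "example.b"}]], "demo")
```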

nems.modelspec.load_modelspec(uri)

Return a single modelspec loaded from uri.

Parameters

uri – URI of modelspec.

Returns

A new modelspec object loaded from the uri.

nems.modelspec.load_modelspecs(directory, basename, regex=None)

Return a list of modelspecs loaded from directory/basename.*.json.

Parameters
  • directory – Directory to search for modelspecs.

  • basename – Name of modelspecs to match against.

  • regex – Optional regex matching for modelspec names.

Returns

A list of modelspec objects.
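A matching sketch for loading directory/basename.*.json with the standard library. load_numbered is hypothetical, and applying regex with re.search to each filename is an assumption about how the regex argument is used.

```python
import json
import re
import tempfile
from pathlib import Path
from typing import Any, List, Optional


def load_numbered(directory: str, basename: str, regex: Optional[str] = None) -> List[Any]:
    """Load every directory/basename.*.json file, optionally filtered by a regex on the filename."""
    paths = sorted(Path(directory).glob(f"{basename}.*.json"))
    if regex is not None:
        pattern = re.compile(regex)
        paths = [p for p in paths if pattern.search(p.name)]
    loaded = []
    for p in paths:
        with open(p) as f:
            loaded.append(json.load(f))
    return loaded


# Set up a throwaway directory with three numbered files.
tmp = Path(tempfile.mkdtemp())
for i in range(3):
    (tmp / f"demo.{i:04d}.json").write_text(json.dumps([{"fn": f"example.m{i}"}]))
```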

nems.modelspec._lookup_fn_at(fn_path, ignore_table=False)

Private function that returns a function handle found at a given module.

Basically, a way to import a single function, e.g.:

    myfn = _lookup_fn_at('nems.modules.fir.fir_filter')
    myfn(data)
    ...

Parameters
  • fn_path – Path to the function.

  • ignore_table – If True, bypass the cached lookup table and import the function directly.

Returns

Function handle.

nems.modelspec.fit_mode_on(modelspec, rec=None, subset=None)

Turn on norm.recalc for each module when present.

TODO docs. Can this be removed?

Parameters
  • modelspec

  • rec

  • subset

nems.modelspec.fit_mode_off(modelspec)

Turn off norm.recalc for each module when present.

TODO docs. Can this be removed?

Parameters

modelspec

nems.modelspec.eval_ms_layer(data: numpy.ndarray, layer_spec: Union[None, str] = None, state_data: Optional[numpy.ndarray] = None, stop: Union[None, int] = None, modelspec: Optional[nems.modelspec.ModelSpec] = None) → numpy.ndarray

Takes in a numpy array and applies a single ms layer to it.

Parameters
  • data – The input data. Shape of (reps, time, channels).

  • layer_spec – A layer spec for layers of a modelspec.

  • state_data – State gain data, optional. Same shape as data.

  • stop – Index of the layer at which to stop evaluating (exclusive). If not passed, evaluates the whole layer spec.

  • modelspec – Optionally use an existing modelspec. Takes precedence over layer_spec.

Returns

The processed data.

nems.modelspec.evaluate(rec, modelspec, start=None, stop=None)

Given a recording object and a modelspec, return a prediction in a new recording.

Does not alter modelspec’s arguments in any way. Only evaluates modules at indices start through stop-1. A value of None for start will include the beginning of the list, and a value of None for stop will include the end of the list (whereas a value of -1 for stop will not). Evaluates using cell/fit/jack currently selected for modelspec.

Parameters
  • rec – Recording object.

  • modelspec – Modelspec object.

  • start – Start evaluation at module start, assuming rec[‘pred’] is in the appropriate state to feed into modelspec[start].

  • stop – Stop at this module.

Returns

A copy of the input recording with 'pred' updated to the prediction.
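The start/stop rules above behave like Python slicing of the module list. A sketch with plain callables standing in for modules and a number standing in for the recording; evaluate_chain and the toy functions are hypothetical.

```python
from typing import Any, Callable, List, Optional


def evaluate_chain(x: Any, modules: List[Callable[[Any], Any]],
                   start: Optional[int] = None, stop: Optional[int] = None) -> Any:
    """Apply modules[start:stop] in order, mirroring the start/stop rules above."""
    for fn in modules[start:stop]:
        x = fn(x)
    return x


def add_one(x):
    return x + 1


def double(x):
    return x * 2


def square(x):
    return x ** 2


chain = [add_one, double, square]
```

With stop=None all three functions run, while stop=-1 skips the last one; starting at index 2 assumes the input is already in the state that module expects, just as rec['pred'] must be for modelspec[start].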

nems.modelspec.evaluate_tf(rec, modelspec, epoch_name='REFERENCE', **kwargs)

nems.modelspec.summary_stats(modelspecs, mod_key='fn', meta_include=[], stats_keys=[])

Generate summary statistics for a list of modelspecs.

Each modelspec must be of the same length and contain the same modules (though they need not be in the same order).

For example, ten modelspecs composed of the same modules that were fit to ten different datasets can be compared. However, ten modelspecs all with different modules fit to the same data cannot be compared because there is no guarantee that they contain comparable parameter values.

Parameters
  • modelspecs (list) – List of modelspecs

  • mod_key – TODO docs

  • stats_keys – TODO docs remove?

Returns

Nested dictionary of stats:

    {'module.function---parameter': {'mean': M, 'std': S, 'values': [v1, v2, ...]}}

Where M, S and v might be scalars or arrays depending on the typical type for the parameter.
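A sketch of the aggregation described above, keyed by module function and parameter name so that module order does not matter. summary_stats_sketch is hypothetical; the '---' separator, the population standard deviation (ddof=0), and skipping modules without phi are assumptions. A modelspec containing the same fn twice would merge both modules' values under one key in this sketch.

```python
from typing import Any, Dict, List

import numpy as np


def summary_stats_sketch(modelspecs: List[List[Dict[str, Any]]],
                         mod_key: str = "fn") -> Dict[str, Dict[str, Any]]:
    """Group phi values across modelspecs by '<module[mod_key]>---<parameter>'."""
    collected: Dict[str, List[np.ndarray]] = {}
    for ms in modelspecs:
        for module in ms:
            for name, value in (module.get("phi") or {}).items():
                key = f"{module[mod_key]}---{name}"
                collected.setdefault(key, []).append(np.asarray(value, dtype=float))
    stats = {}
    for key, values in collected.items():
        stacked = np.stack(values)  # fails if shapes differ across modelspecs
        stats[key] = {"mean": stacked.mean(axis=0),
                      "std": stacked.std(axis=0),  # population std (ddof=0), an assumption
                      "values": values}
    return stats


fits = [
    [{"fn": "example.fir", "phi": {"coefficients": [1.0, 2.0]}}, {"fn": "example.levelshift"}],
    [{"fn": "example.fir", "phi": {"coefficients": [3.0, 4.0]}}, {"fn": "example.levelshift"}],
]
stats = summary_stats_sketch(fits)
```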

nems.modelspec.get_best_modelspec(modelspecs, metakey='r_test', comparison='greatest')

Get the best modelspec ranked by the given metakey.

Examine the first-module meta information within each modelspec in a list, and return a singleton list containing the modelspec with the greatest value for the specified metakey by default (or the least value optionally).

Parameters
  • modelspecs (list) – Modelspecs to compare.

  • metakey – Key to compare across modelspecs.

  • comparison (str) – 'greatest' or 'least'.

Return list

Modelspec with greatest/least metakey.

nems.modelspec.sort_modelspecs(modelspecs, metakey='r_test', order='descending')

Sort modelspecs by the given metakey.

Sorts modelspecs in order of the given metakey, which should be in the first-module meta entry of each modelspec.

Parameters
  • modelspecs (list) – List of modelspecs to sort.

  • metakey – Key to compare across modelspecs.

  • order (str) – 'descending' or 'ascending'.

Return list

Sorted list of modelspecs.
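A sketch of ranking by a first-module meta key, covering both get_best_modelspec and sort_modelspecs. sort_by_meta, best_by_meta, and the "meta" key are hypothetical.

```python
from typing import Any, Dict, List

Module = Dict[str, Any]


def _meta_value(modelspec: List[Module], metakey: str) -> Any:
    # By convention the meta dict lives in the first module.
    return modelspec[0]["meta"][metakey]


def sort_by_meta(modelspecs: List[List[Module]], metakey: str = "r_test",
                 order: str = "descending") -> List[List[Module]]:
    if order not in ("descending", "ascending"):
        raise ValueError("order must be 'descending' or 'ascending'")
    return sorted(modelspecs, key=lambda ms: _meta_value(ms, metakey),
                  reverse=(order == "descending"))


def best_by_meta(modelspecs: List[List[Module]], metakey: str = "r_test",
                 comparison: str = "greatest") -> List[List[Module]]:
    if comparison == "greatest":
        best = max(modelspecs, key=lambda ms: _meta_value(ms, metakey))
    elif comparison == "least":
        best = min(modelspecs, key=lambda ms: _meta_value(ms, metakey))
    else:
        raise ValueError("comparison must be 'greatest' or 'least'")
    return [best]  # singleton list, as described for get_best_modelspec


specs = [[{"meta": {"r_test": r}}] for r in (0.3, 0.7, 0.5)]
```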

nems.modelspec.try_scalar(x)

Try to convert x to a scalar; if a ValueError is raised, return x unchanged.

Parameters

x – Value to convert to scalar.
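A sketch of try_scalar using NumPy's ndarray.item(), which raises ValueError for arrays that do not hold exactly one element. try_scalar_sketch is hypothetical and may differ in detail from the NEMS version.

```python
from typing import Any

import numpy as np


def try_scalar_sketch(x: Any) -> Any:
    """Return x as a Python scalar if it holds exactly one element; otherwise return x."""
    try:
        return np.asarray(x).item()
    except ValueError:
        # .item() raises ValueError when the array does not have exactly one element
        return x


single = try_scalar_sketch(np.array([[5.0]]))
many = [1, 2, 3]
```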