Callbacks

class qucumber.callbacks.CallbackBase

Base class for callbacks.

on_batch_end(nn_state, epoch, batch)

Called at the end of each batch.

Parameters:
  • nn_state (WaveFunction) – The WaveFunction being trained.
  • epoch (int) – The current epoch.
  • batch (int) – The current batch index.
on_batch_start(nn_state, epoch, batch)

Called at the start of each batch.

Parameters:
  • nn_state (WaveFunction) – The WaveFunction being trained.
  • epoch (int) – The current epoch.
  • batch (int) – The current batch index.
on_epoch_end(nn_state, epoch)

Called at the end of each epoch.

Parameters:
  • nn_state (WaveFunction) – The WaveFunction being trained.
  • epoch (int) – The current epoch.
on_epoch_start(nn_state, epoch)

Called at the start of each epoch.

Parameters:
  • nn_state (WaveFunction) – The WaveFunction being trained.
  • epoch (int) – The current epoch.
on_train_end(nn_state)

Called at the end of the training cycle.

Parameters: nn_state (WaveFunction) – The WaveFunction being trained.
on_train_start(nn_state)

Called at the start of the training cycle.

Parameters: nn_state (WaveFunction) – The WaveFunction being trained.
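The hooks above are easiest to picture inside a training loop. The sketch below is illustrative only: it uses a local stand-in for CallbackBase with the same hook names and signatures, a hypothetical run_training loop (not QuCumber's fit), and assumes epochs are numbered from 1. A subclass overrides only the hooks it needs; the others remain no-ops.

```python
# Hypothetical stand-in for qucumber.callbacks.CallbackBase: it only mirrors
# the hook names and signatures documented above.
class CallbackBase:
    def on_train_start(self, nn_state): pass
    def on_train_end(self, nn_state): pass
    def on_epoch_start(self, nn_state, epoch): pass
    def on_epoch_end(self, nn_state, epoch): pass
    def on_batch_start(self, nn_state, epoch, batch): pass
    def on_batch_end(self, nn_state, epoch, batch): pass


class EventRecorder(CallbackBase):
    """Record which hooks fired, overriding only the ones of interest."""

    def __init__(self):
        self.events = []

    def on_train_start(self, nn_state):
        self.events.append("train_start")

    def on_epoch_end(self, nn_state, epoch):
        self.events.append(f"epoch_end {epoch}")

    def on_train_end(self, nn_state):
        self.events.append("train_end")


def run_training(nn_state, callbacks, epochs, batches_per_epoch):
    """Hypothetical loop showing when each hook is invoked (epochs from 1)."""
    for cb in callbacks:
        cb.on_train_start(nn_state)
    for epoch in range(1, epochs + 1):
        for cb in callbacks:
            cb.on_epoch_start(nn_state, epoch)
        for batch in range(batches_per_epoch):
            for cb in callbacks:
                cb.on_batch_start(nn_state, epoch, batch)
            # ... one parameter update on nn_state would happen here ...
            for cb in callbacks:
                cb.on_batch_end(nn_state, epoch, batch)
        for cb in callbacks:
            cb.on_epoch_end(nn_state, epoch)
    for cb in callbacks:
        cb.on_train_end(nn_state)


recorder = EventRecorder()
run_training(nn_state=None, callbacks=[recorder], epochs=2, batches_per_epoch=3)
print(recorder.events)  # ['train_start', 'epoch_end 1', 'epoch_end 2', 'train_end']
```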
class qucumber.callbacks.LambdaCallback(on_train_start=None, on_train_end=None, on_epoch_start=None, on_epoch_end=None, on_batch_start=None, on_batch_end=None)

Class for creating simple callbacks.

This callback is built from the functions passed to its constructor; each one is called at the corresponding point in training.

Parameters:
  • on_train_start (callable or None) – A function to be called at the start of the training cycle. Must follow the same signature as CallbackBase.on_train_start.
  • on_train_end (callable or None) – A function to be called at the end of the training cycle. Must follow the same signature as CallbackBase.on_train_end.
  • on_epoch_start (callable or None) – A function to be called at the start of every epoch. Must follow the same signature as CallbackBase.on_epoch_start.
  • on_epoch_end (callable or None) – A function to be called at the end of every epoch. Must follow the same signature as CallbackBase.on_epoch_end.
  • on_batch_start (callable or None) – A function to be called at the start of every batch. Must follow the same signature as CallbackBase.on_batch_start.
  • on_batch_end (callable or None) – A function to be called at the end of every batch. Must follow the same signature as CallbackBase.on_batch_end.
class qucumber.callbacks.ModelSaver(period, folder_path, file_name, save_initial=True, metadata=None, metadata_only=False)

CallbackBase which allows model parameters (along with some metadata) to be saved to disk at regular intervals.

This CallbackBase is called at the end of each epoch. If save_initial is True, it is also called at the start of the training cycle.

Parameters:
  • period (int) – Frequency of model saving (in epochs).
  • folder_path (str) – The directory in which to save the files.
  • file_name (str) – The name of the output files. Should be a format string with one blank, which will be filled with either the epoch number or the word “initial”.
  • save_initial (bool) – Whether to save the initial parameters (and metadata).
  • metadata (callable or dict or None) – The metadata to save to disk with the model parameters. Can be either a function or a dictionary. If it is a function, it must take 2 arguments, the WaveFunction being trained and the current epoch number, and return a dictionary containing the metadata to be saved.
  • metadata_only (bool) – Whether to save only the metadata to disk.
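The file_name format string and the metadata argument can be illustrated without QuCumber. The helper planned_saves below is hypothetical: it only lists the paths and metadata a ModelSaver-like callback would write, assuming epochs are counted from 1, a save happens whenever the epoch is a multiple of period, and the initial save is tagged with epoch 0. None stands in for the WaveFunction.

```python
import os


def planned_saves(period, num_epochs, folder_path, file_name,
                  save_initial=True, metadata=None):
    """Hypothetical helper: list (path, metadata) pairs that would be saved.

    ASSUMPTIONS: epochs run 1..num_epochs, saves happen when epoch % period
    == 0, and the initial save is labelled with epoch 0.
    """
    def metadata_for(epoch):
        # metadata may be a function of (nn_state, epoch), a dict, or None.
        if callable(metadata):
            return metadata(None, epoch)  # None stands in for the WaveFunction
        return dict(metadata or {})

    saves = []
    if save_initial:
        # The single "{}" blank in file_name is filled with "initial".
        saves.append((os.path.join(folder_path, file_name.format("initial")),
                      metadata_for(0)))
    for epoch in range(1, num_epochs + 1):
        if epoch % period == 0:
            # ...and otherwise with the epoch number.
            saves.append((os.path.join(folder_path, file_name.format(epoch)),
                          metadata_for(epoch)))
    return saves


saves = planned_saves(
    period=2, num_epochs=5, folder_path="runs", file_name="model_{}.pt",
    metadata=lambda nn_state, epoch: {"epoch": epoch},
)
for path, meta in saves:
    print(path, meta)
```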
class qucumber.callbacks.Logger(period, logger_fn=print, msg_gen=None, **msg_gen_kwargs)

CallbackBase which logs output at regular intervals.

This CallbackBase is called at the end of each epoch.

Parameters:
  • period (int) – Logging frequency (in epochs).
  • logger_fn (callable) – The function used for logging. Must take a single string argument. Defaults to the built-in print function.
  • msg_gen (callable) – A callable which generates the string to be logged. Must take 2 positional arguments, the WaveFunction being trained and the current epoch, and must also accept keyword arguments.
  • **msg_gen_kwargs – Keyword arguments which will be passed to msg_gen.
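A sketch of how logger_fn, msg_gen and the extra keyword arguments fit together. Both functions are hypothetical, and run_logger assumes a message is logged whenever the epoch number (counted from 1) is a multiple of period; logger_fn is a list's append method instead of print so the output can be inspected.

```python
def epoch_msg(nn_state, epoch, prefix="", **kwargs):
    """Hypothetical msg_gen: two positional arguments plus keyword arguments."""
    return f"{prefix}epoch {epoch}"


def run_logger(period, num_epochs, logger_fn, msg_gen, **msg_gen_kwargs):
    """Sketch of Logger's end-of-epoch behaviour, ASSUMING it logs whenever
    the epoch number (counted from 1) is a multiple of period."""
    for epoch in range(1, num_epochs + 1):
        if epoch % period == 0:
            # None stands in for the WaveFunction being trained.
            logger_fn(msg_gen(None, epoch, **msg_gen_kwargs))


lines = []
run_logger(3, 7, lines.append, epoch_msg, prefix="[train] ")
print(lines)  # ['[train] epoch 3', '[train] epoch 6']
```

Under the documented signature, the real equivalent would presumably be Logger(3, logger_fn=lines.append, msg_gen=epoch_msg, prefix="[train] "), with prefix collected into **msg_gen_kwargs.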
class qucumber.callbacks.EarlyStopping(period, tolerance, patience, evaluator_callback, quantity_name)

Stop training once the model stops improving. The specific criterion for stopping is:

\left\vert\frac{M_{t-p} - M_t}{M_{t-p}}\right\vert < \epsilon

where M_t is the metric value at the current evaluation (time t), p is the “patience” parameter, and \epsilon is the tolerance.

This CallbackBase is called at the end of each epoch.

Parameters:
  • period (int) – Frequency with which the callback checks whether training has converged (in epochs).
  • tolerance (float) – The threshold \epsilon: training is considered converged once the relative change falls below this value.
  • patience (int) – How many intervals to wait before claiming the training has converged.
  • evaluator_callback (MetricEvaluator or ObservableEvaluator) – An instance of MetricEvaluator or ObservableEvaluator which computes the metric that we want to check for convergence.
  • quantity_name (str) – The name of the metric/observable stored in evaluator_callback.
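A worked version of the stopping criterion, written as a plain function over a list of recorded metric values. It is a sketch under stated assumptions, not EarlyStopping's implementation: the last entry is M_t, the entry p positions earlier is M_{t-p}, and the check returns False until enough values exist. The criterion is undefined when M_{t-p} = 0.

```python
def relative_change_converged(history, patience, tolerance):
    """Check |(M_{t-p} - M_t) / M_{t-p}| < epsilon on a list of metric values.

    history[-1] is M_t (latest evaluation); history[-1 - patience] is M_{t-p}.
    Returns False if fewer than patience + 1 values have been recorded.
    """
    if len(history) <= patience:
        return False
    current = history[-1]
    past = history[-1 - patience]
    return abs((past - current) / past) < tolerance


history = [1.0, 0.5, 0.45, 0.449]
# p = 1: |0.45 - 0.449| / 0.45 is about 0.0022, below 0.01.
print(relative_change_converged(history, patience=1, tolerance=0.01))  # True
# p = 2: |0.5 - 0.449| / 0.5 = 0.102, not below 0.01.
print(relative_change_converged(history, patience=2, tolerance=0.01))  # False
```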
class qucumber.callbacks.VarianceBasedEarlyStopping(period, tolerance, patience, evaluator_callback, quantity_name, variance_name)

Stop training once the model stops improving. This is a variation on the EarlyStopping class which takes the variance of the metric into account. The specific criterion for stopping is:

\left\vert\frac{M_{t-p} - M_t}{\sigma_{t-p}}\right\vert < \kappa

where M_t is the metric value at the current evaluation (time t), p is the “patience” parameter, \sigma_{t-p} is the variance of the metric at time t-p (the quantity named by variance_name), and \kappa is the tolerance.

This CallbackBase is called at the end of each epoch.

Parameters:
  • period (int) – Frequency with which the callback checks whether training has converged (in epochs).
  • tolerance (float) – The threshold \kappa: training is considered converged once the standardized change falls below this value.
  • patience (int) – How many intervals to wait before claiming the training has converged.
  • evaluator_callback (MetricEvaluator or ObservableEvaluator) – An instance of MetricEvaluator or ObservableEvaluator which computes the metric/observable that we want to check for convergence.
  • quantity_name (str) – The name of the metric/observable stored in evaluator_callback.
  • variance_name (str) – The name of the variance stored in evaluator_callback.
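The variance-based criterion can be checked the same way. This is again a hypothetical sketch: values holds the metric history, spreads holds the matching history of the quantity named by variance_name, and both lists are indexed so that the last entry is time t.

```python
def standardized_change_converged(values, spreads, patience, tolerance):
    """Check |(M_{t-p} - M_t) / sigma_{t-p}| < kappa.

    values[-1] is M_t; values[-1 - patience] is M_{t-p}; spreads[-1 - patience]
    is sigma_{t-p}, recorded at the same earlier evaluation as M_{t-p}.
    """
    if len(values) <= patience:
        return False
    past, current = values[-1 - patience], values[-1]
    sigma_past = spreads[-1 - patience]
    return abs((past - current) / sigma_past) < tolerance


means = [2.0, 1.6, 1.55]
spreads = [0.5, 0.4, 0.3]
# p = 1: |1.6 - 1.55| / 0.4 = 0.125, below 0.2.
print(standardized_change_converged(means, spreads, patience=1, tolerance=0.2))  # True
# p = 2: |2.0 - 1.55| / 0.5 = 0.9, not below 0.2.
print(standardized_change_converged(means, spreads, patience=2, tolerance=0.2))  # False
```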
class qucumber.callbacks.MetricEvaluator(period, metrics, verbose=False, log=None, **metric_kwargs)

Evaluate and hold on to the results of the given metric(s).

This CallbackBase is called at the end of each epoch.

Note

Since callbacks are passed to fit as a list, they are called in a deterministic order (the order of the list). It is therefore recommended that instances of MetricEvaluator be among the first callbacks in the list passed to fit, since they are often used together with callbacks such as EarlyStopping that depend on MetricEvaluator having already run.

Parameters:
  • period (int) – Frequency with which the callback evaluates the given metric(s) (in epochs).
  • metrics (dict(str, callable)) – A dictionary of callables, where the keys are the names of the metrics. Each callable takes the WaveFunction being trained as its only positional argument, along with keyword arguments. The results are stored in an internal dictionary that mirrors the structure of metrics.
  • verbose (bool) – Whether to print metrics to stdout.
  • log (str) – A filepath to log metric values to in CSV format.
  • **metric_kwargs – Keyword arguments to be passed to each metric.
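How metrics and **metric_kwargs fit together, sketched with the standard library only. The metric names, the dictionary standing in for the WaveFunction, and evaluate_metrics are all hypothetical. The sketch assumes every keyword argument is forwarded to every metric, which is why each metric accepts **kwargs and ignores what it does not use.

```python
from collections import defaultdict


# Hypothetical metrics: the model is the positional argument; extra keyword
# arguments are accepted and ignored if unused.
def mean_param(model, **kwargs):
    return sum(model["params"]) / len(model["params"])


def scaled_norm(model, scale=1.0, **kwargs):
    return scale * sum(p * p for p in model["params"]) ** 0.5


def evaluate_metrics(model, epoch, period, metrics, store, **metric_kwargs):
    """Sketch of one end-of-epoch evaluation, ASSUMING metrics are computed
    whenever epoch is a multiple of period."""
    if epoch % period != 0:
        return
    store["epochs"].append(epoch)
    for name, fn in metrics.items():
        # Every metric receives the same keyword arguments.
        store[name].append(fn(model, **metric_kwargs))


metrics = {"mean_param": mean_param, "scaled_norm": scaled_norm}
store = defaultdict(list)
model = {"params": [3.0, 4.0]}  # stand-in for the WaveFunction being trained
for epoch in range(1, 5):
    evaluate_metrics(model, epoch, 2, metrics, store, scale=2.0)
print(dict(store))
# {'epochs': [2, 4], 'mean_param': [3.5, 3.5], 'scaled_norm': [10.0, 10.0]}
```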
__getattr__(metric)

Return an array of all recorded values of the given metric.

Parameters: metric (str) – The metric to retrieve.
Returns: The past values of the metric.
Return type: np.array
__len__()

Return the number of timesteps that metrics have been evaluated for.

Return type: int
clear_history()

Delete all metric values the instance is currently storing.

epochs

Return an array of all epochs that have been recorded.

Return type: np.array
get_value(name, index=None)

Retrieve the value of the desired metric from the given timestep.

Parameters:
  • name (str) – The name of the metric to retrieve.
  • index (int or None) – The index/timestep from which to retrieve the metric. Negative indices are supported. If None, the most recent value is returned.
names

The names of the tracked metrics.

Return type: list[str]
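The accessors above behave like a small history container. MetricHistorySketch is a hypothetical stand-in: it returns lists where MetricEvaluator returns np.array, and record is an invented method used only to fill it. The metric names fidelity and KL are illustrative.

```python
class MetricHistorySketch:
    """Hypothetical container mirroring the documented accessors."""

    def __init__(self):
        self._epochs = []
        self._values = {}

    def record(self, epoch, results):
        # results maps metric name -> value for this timestep.
        self._epochs.append(epoch)
        for name, value in results.items():
            self._values.setdefault(name, []).append(value)

    def __getattr__(self, metric):
        # Only reached when normal lookup fails, e.g. `h.fidelity`.
        values = self.__dict__.get("_values", {})
        if metric in values:
            return list(values[metric])
        raise AttributeError(metric)

    def __len__(self):
        return len(self._epochs)

    @property
    def epochs(self):
        return list(self._epochs)

    @property
    def names(self):
        return list(self._values)

    def get_value(self, name, index=None):
        # None means "most recent"; negative indices work as for lists.
        return self._values[name][-1 if index is None else index]

    def clear_history(self):
        self._epochs.clear()
        self._values.clear()


h = MetricHistorySketch()
h.record(10, {"fidelity": 0.80, "KL": 0.30})
h.record(20, {"fidelity": 0.90, "KL": 0.10})
latest_kl = h.get_value("KL")          # most recent value: 0.1
first_kl = h.get_value("KL", index=0)  # first value: 0.3
print(len(h), h.epochs, h.names)       # 2 [10, 20] ['fidelity', 'KL']
print(h.fidelity)                      # [0.8, 0.9]
h.clear_history()
print(len(h))                          # 0
```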
class qucumber.callbacks.ObservableEvaluator(period, observables, verbose=False, log=None, **sampling_kwargs)

Evaluate and hold on to the results of the given observable(s).

This CallbackBase is called at the end of each epoch.

Note

Since callbacks are passed to fit as a list, they are called in a deterministic order (the order of the list). It is therefore recommended that instances of ObservableEvaluator be among the first callbacks in the list passed to fit, since they are often used together with callbacks such as EarlyStopping that depend on ObservableEvaluator having already run.

Parameters:
  • period (int) – Frequency with which the callback evaluates the given observable(s) (in epochs).
  • observables (list(qucumber.observables.Observable)) – A list of Observables. Observable statistics are evaluated by sampling the WaveFunction. Note that observables that have the same name will conflict, and precedence will be given to the right-most observable argument.
  • verbose (bool) – Whether to print observable statistics to stdout.
  • log (str) – A filepath to log observable statistics to in CSV format.
  • **sampling_kwargs – Keyword arguments to be passed to Observable.statistics, e.g. num_samples, num_chains, burn_in, steps.
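The precedence rule for same-named observables matches what happens when a name-to-observable dictionary is built from the list left to right. This is an assumption about the mechanism, shown with a hypothetical FakeObservable class and illustrative names rather than real qucumber.observables classes.

```python
from dataclasses import dataclass


@dataclass
class FakeObservable:
    """Hypothetical stand-in: only the `name` attribute matters here."""
    name: str
    tag: str


observables = [
    FakeObservable("SigmaZ", "first"),
    FakeObservable("SigmaX", "only"),
    FakeObservable("SigmaZ", "last"),
]

# A later entry overwrites an earlier one with the same name,
# so the right-most observable wins.
by_name = {obs.name: obs for obs in observables}
print({name: obs.tag for name, obs in by_name.items()})
# {'SigmaZ': 'last', 'SigmaX': 'only'}
```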
__getattr__(observable)

Return an ObservableStatistics containing recorded statistics of the given observable.

Parameters: observable (str) – The observable to retrieve.
Returns: The past values of the observable.
Return type: ObservableStatistics
__len__()

Return the number of timesteps that observables have been evaluated for.

Return type: int
clear_history()

Delete all statistics the instance is currently storing.

epochs

Return an array of all epochs that have been recorded.

Return type: np.array
get_value(name, index=None)

Retrieve the statistics of the desired observable from the given timestep.

Parameters:
  • name (str) – The name of the observable to retrieve.
  • index (int or None) – The index/timestep from which to retrieve the observable. Negative indices are supported. If None, the most recent value is returned.
names

The names of the tracked observables.

Return type: list[str]
class qucumber.callbacks.Timer(verbose=True)

CallbackBase which records the training time.

This CallbackBase is always called at the start and end of training. It also runs at the end of an epoch or batch if the model’s stop_training property has been set to True.

Parameters: verbose (bool) – Whether to print the elapsed time at the end of training.
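A sketch of the timing behaviour described above, assuming Timer measures wall-clock time between the start and the end of training, and also finishes timing at an epoch or batch end once stop_training is True. TimerSketch, its attribute names and its message are hypothetical; SimpleNamespace stands in for the model.

```python
import time
from types import SimpleNamespace


class TimerSketch:
    """Hypothetical stand-in for Timer; attribute names are illustrative."""

    def __init__(self, verbose=True):
        self.verbose = verbose
        self.start_time = None
        self.elapsed = None

    def on_train_start(self, nn_state):
        self.start_time = time.perf_counter()

    def on_train_end(self, nn_state):
        self.elapsed = time.perf_counter() - self.start_time
        if self.verbose:
            print(f"Training took {self.elapsed:.3f} s")

    def on_epoch_end(self, nn_state, epoch):
        # Training is about to stop early, so record the time now.
        if nn_state.stop_training:
            self.on_train_end(nn_state)

    def on_batch_end(self, nn_state, epoch, batch):
        if nn_state.stop_training:
            self.on_train_end(nn_state)


state = SimpleNamespace(stop_training=False)  # stand-in for the WaveFunction
timer = TimerSketch(verbose=False)
timer.on_train_start(state)
timer.on_epoch_end(state, 1)  # still training: nothing recorded yet
print(timer.elapsed)          # None
state.stop_training = True    # e.g. set by a callback requesting a stop
timer.on_epoch_end(state, 2)
print(timer.elapsed >= 0)     # True
```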