Callbacks
-
class qucumber.callbacks.CallbackBase
Base class for callbacks.
-
on_batch_end(nn_state, epoch, batch)
Called at the end of each batch.
- Parameters
nn_state (qucumber.nn_states.WaveFunctionBase) – The WaveFunction being trained.
epoch (int) – The current epoch.
batch (int) – The current batch index.
-
on_batch_start(nn_state, epoch, batch)
Called at the start of each batch.
- Parameters
nn_state (qucumber.nn_states.WaveFunctionBase) – The WaveFunction being trained.
epoch (int) – The current epoch.
batch (int) – The current batch index.
-
on_epoch_end(nn_state, epoch)
Called at the end of each epoch.
- Parameters
nn_state (qucumber.nn_states.WaveFunctionBase) – The WaveFunction being trained.
epoch (int) – The current epoch.
-
on_epoch_start(nn_state, epoch)
Called at the start of each epoch.
- Parameters
nn_state (qucumber.nn_states.WaveFunctionBase) – The WaveFunction being trained.
epoch (int) – The current epoch.
-
on_train_end(nn_state)
Called at the end of the training cycle.
- Parameters
nn_state (qucumber.nn_states.WaveFunctionBase) – The WaveFunction being trained.
-
on_train_start(nn_state)
Called at the start of the training cycle.
- Parameters
nn_state (qucumber.nn_states.WaveFunctionBase) – The WaveFunction being trained.
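To make the hook order concrete, here is a minimal sketch of a callback that overrides a few hooks, driven by a toy training loop. Both `CallbackBase` below and `toy_fit` are hypothetical stand-ins written for illustration. They are not QuCumber's code. The 1-based epoch numbering and the exact nesting of the loops are assumptions.

```python
# Hypothetical stand-in for qucumber.callbacks.CallbackBase: hook names and
# arguments mirror the methods documented above; this is not QuCumber's code.
class CallbackBase:
    def on_train_start(self, nn_state):
        pass

    def on_train_end(self, nn_state):
        pass

    def on_epoch_start(self, nn_state, epoch):
        pass

    def on_epoch_end(self, nn_state, epoch):
        pass

    def on_batch_start(self, nn_state, epoch, batch):
        pass

    def on_batch_end(self, nn_state, epoch, batch):
        pass


class RecordingCallback(CallbackBase):
    """Overrides a few hooks and records when they fire."""

    def __init__(self):
        self.events = []

    def on_train_start(self, nn_state):
        self.events.append("train_start")

    def on_epoch_start(self, nn_state, epoch):
        self.events.append(f"epoch_start {epoch}")

    def on_batch_end(self, nn_state, epoch, batch):
        self.events.append(f"batch_end {epoch}.{batch}")

    def on_train_end(self, nn_state):
        self.events.append("train_end")


def toy_fit(nn_state, epochs, num_batches, callbacks):
    """Hypothetical training loop that fires each hook on every callback, in list order."""
    for cb in callbacks:
        cb.on_train_start(nn_state)
    for epoch in range(1, epochs + 1):  # 1-based epochs are an assumption
        for cb in callbacks:
            cb.on_epoch_start(nn_state, epoch)
        for batch in range(num_batches):
            for cb in callbacks:
                cb.on_batch_start(nn_state, epoch, batch)
            # ... a parameter update would happen here ...
            for cb in callbacks:
                cb.on_batch_end(nn_state, epoch, batch)
        for cb in callbacks:
            cb.on_epoch_end(nn_state, epoch)
    for cb in callbacks:
        cb.on_train_end(nn_state)


recorder = RecordingCallback()
toy_fit(nn_state=None, epochs=2, num_batches=2, callbacks=[recorder])
print(recorder.events)
```

Hooks that a subclass does not override, such as on_batch_start and on_epoch_end here, fall back to the base class's no-op methods.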
-
-
class qucumber.callbacks.LambdaCallback(on_train_start=None, on_train_end=None, on_epoch_start=None, on_epoch_end=None, on_batch_start=None, on_batch_end=None)
Class for creating simple callbacks.
This callback is constructed from the passed functions, each of which is called at the corresponding point in training.
- Parameters
on_train_start (callable or None) – A function to be called at the start of the training cycle. Must follow the same signature as CallbackBase.on_train_start.
on_train_end (callable or None) – A function to be called at the end of the training cycle. Must follow the same signature as CallbackBase.on_train_end.
on_epoch_start (callable or None) – A function to be called at the start of every epoch. Must follow the same signature as CallbackBase.on_epoch_start.
on_epoch_end (callable or None) – A function to be called at the end of every epoch. Must follow the same signature as CallbackBase.on_epoch_end.
on_batch_start (callable or None) – A function to be called at the start of every batch. Must follow the same signature as CallbackBase.on_batch_start.
on_batch_end (callable or None) – A function to be called at the end of every batch. Must follow the same signature as CallbackBase.on_batch_end.
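The following sketch illustrates the behaviour described for LambdaCallback: each supplied function becomes the corresponding hook, and omitted hooks do nothing. `LambdaCallbackSketch` and `_or_noop` are hypothetical names used only for this illustration. Treating a `None` hook as a no-op is an assumption based on the `callable or None` parameter types.

```python
def _or_noop(fn):
    """Return `fn`, or a do-nothing function when `fn` is None."""
    return fn if fn is not None else (lambda *args: None)


class LambdaCallbackSketch:
    """Hypothetical stand-in mirroring LambdaCallback's documented constructor."""

    def __init__(self, on_train_start=None, on_train_end=None,
                 on_epoch_start=None, on_epoch_end=None,
                 on_batch_start=None, on_batch_end=None):
        # Each hook takes the same arguments as the CallbackBase method of the same name.
        self.on_train_start = _or_noop(on_train_start)
        self.on_train_end = _or_noop(on_train_end)
        self.on_epoch_start = _or_noop(on_epoch_start)
        self.on_epoch_end = _or_noop(on_epoch_end)
        self.on_batch_start = _or_noop(on_batch_start)
        self.on_batch_end = _or_noop(on_batch_end)


seen_epochs = []
callback = LambdaCallbackSketch(
    on_epoch_end=lambda nn_state, epoch: seen_epochs.append(epoch)
)
for epoch in range(1, 4):
    callback.on_epoch_start(None, epoch)  # not supplied, so this is a no-op
    callback.on_epoch_end(None, epoch)
print(seen_epochs)  # [1, 2, 3]
```

With QuCumber itself, the equivalent hook would be passed to the documented constructor as `LambdaCallback(on_epoch_end=lambda nn_state, epoch: ...)`.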
-
class qucumber.callbacks.ModelSaver(period, folder_path, file_name, save_initial=True, metadata=None, metadata_only=False)
Callback which allows model parameters (along with some metadata) to be saved to disk at regular intervals.
This callback is called at the end of each epoch. If save_initial is True, it will also be called at the start of the training cycle.
- Parameters
period (int) – Frequency of model saving (in epochs).
folder_path (str) – The directory in which to save the files.
file_name (str) – The name of the output files. Should be a format string with one blank, which will be filled with either the epoch number or the word “initial”.
save_initial (bool) – Whether to save the initial parameters (and metadata).
metadata (callable or dict or None) – The metadata to save to disk with the model parameters. Can be either a function or a dictionary. In the case of a function, it must take 2 arguments, the RBM being trained and the current epoch number, and return a dictionary containing the metadata to be saved.
metadata_only (bool) – Whether to save only the metadata to disk.
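The following sketch shows how a `file_name` format string with one blank expands, and what a metadata callable looks like. `planned_file_names` and `example_metadata` are hypothetical helpers. Two details are assumptions that the text above does not state: saves happen on epochs divisible by `period`, and epochs are counted from 1.

```python
def planned_file_names(file_name, period, epochs, save_initial=True):
    """List the file names a ModelSaver-like callback would produce (sketch)."""
    names = []
    if save_initial:
        # The single blank is filled with the word "initial" for the starting parameters.
        names.append(file_name.format("initial"))
    for epoch in range(1, epochs + 1):
        if epoch % period == 0:  # assumed saving schedule
            names.append(file_name.format(epoch))
    return names


def example_metadata(nn_state, epoch):
    # A metadata callable receives the model being trained and the epoch number
    # and returns a dict; these keys are purely illustrative.
    return {"epoch": epoch, "note": "hypothetical metadata"}


print(planned_file_names("model_{}.pt", period=5, epochs=12))
# ['model_initial.pt', 'model_5.pt', 'model_10.pt']
```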
-
class qucumber.callbacks.Logger(period, logger_fn=<built-in function print>, msg_gen=None, **msg_gen_kwargs)
Callback which logs output at regular intervals.
This callback is called at the end of each epoch.
- Parameters
period (int) – Logging frequency (in epochs).
logger_fn (callable) – The function used for logging. Must take 1 string as an argument. Defaults to the standard print function.
msg_gen (callable) – A callable which generates the string to be logged. Must take 2 positional arguments: the RBM being trained and the current epoch. It must also be able to take some keyword arguments.
**msg_gen_kwargs – Keyword arguments which will be passed to msg_gen.
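The following sketch illustrates the Logger contract. `msg_gen` takes the model and the epoch positionally, plus keyword arguments, and `logger_fn` consumes a single string. `example_msg_gen` and `run_logger` are hypothetical names. The rule "log every `period` epochs, counting from 1" is an assumed schedule.

```python
def example_msg_gen(nn_state, epoch, prefix="", **kwargs):
    # Positional: the model being trained and the epoch; extra keyword
    # arguments are accepted, as the msg_gen description requires.
    return f"{prefix}epoch {epoch} finished"


def run_logger(period, epochs, logger_fn=print, msg_gen=example_msg_gen, **msg_gen_kwargs):
    """Hypothetical driver: logs every `period` epochs (an assumed schedule)."""
    for epoch in range(1, epochs + 1):
        if epoch % period == 0:
            logger_fn(msg_gen(None, epoch, **msg_gen_kwargs))


logged = []
run_logger(2, 5, logger_fn=logged.append, prefix="[train] ")
print(logged)  # ['[train] epoch 2 finished', '[train] epoch 4 finished']
```

Passing `list.append` as `logger_fn` captures the messages instead of printing them, which is convenient in tests.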
-
class qucumber.callbacks.EarlyStopping(period, tolerance, patience, evaluator_callback, quantity_name)
Stop training once the model stops improving. The specific criterion for stopping is:

$$\left\vert\frac{M_{t-p} - M_t}{M_{t-p}}\right\vert < \epsilon$$

where $M_t$ is the metric value at the current evaluation (time $t$), $p$ is the “patience” parameter, and $\epsilon$ is the tolerance.
This callback is called at the end of each epoch.
- Parameters
period (int) – Frequency with which the callback checks whether training has converged (in epochs).
tolerance (float) – The maximum relative change required to consider training as having converged.
patience (int) – How many intervals to wait before claiming the training has converged.
evaluator_callback (MetricEvaluator or ObservableEvaluator) – An instance of MetricEvaluator or ObservableEvaluator which computes the metric that we want to check for convergence.
quantity_name (str) – The name of the metric/observable stored in evaluator_callback.
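The stopping rule above compares the current value with the value recorded `patience` evaluations earlier. Training counts as converged when the relative change between them is below `tolerance`. This is a minimal sketch of that check on a hypothetical list of recorded values. `has_converged` is an illustrative name, not part of QuCumber.

```python
def has_converged(history, tolerance, patience):
    """Relative-change test on a list of metric values, one per evaluation (sketch)."""
    if len(history) <= patience:
        return False  # not enough evaluations yet to look back `patience` steps
    past, current = history[-1 - patience], history[-1]
    # Assumes `past` is nonzero, since the change is measured relative to it.
    return abs((past - current) / past) < tolerance


history = [1.0, 0.5, 0.300, 0.2995, 0.2990]
print(has_converged(history[:2], tolerance=0.01, patience=2))  # False (too few values)
print(has_converged(history[:3], tolerance=0.01, patience=2))  # False (0.3 vs 1.0)
print(has_converged(history, tolerance=0.01, patience=2))      # True (0.2990 vs 0.300)
```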
-
class qucumber.callbacks.VarianceBasedEarlyStopping(period, tolerance, patience, evaluator_callback, quantity_name, variance_name)
Stop training once the model stops improving. This is a variation on the EarlyStopping class which takes the variance of the metric into account. The specific criterion for stopping is:

$$\frac{\left\vert\mu_{t-p} - \mu_t\right\vert}{\sigma_{t-p}} < \kappa$$

where $\mu_t$ is the metric value at the current evaluation (time $t$), $p$ is the “patience” parameter, $\sigma_t$ is the variance of the metric, and $\kappa$ is the tolerance.
This callback is called at the end of each epoch.
- Parameters
period (int) – Frequency with which the callback checks whether training has converged (in epochs).
tolerance (float) – The maximum (standardized) change required to consider training as having converged.
patience (int) – How many intervals to wait before claiming the training has converged.
evaluator_callback (MetricEvaluator or ObservableEvaluator) – An instance of MetricEvaluator or ObservableEvaluator which computes the metric/observable that we want to check for convergence.
quantity_name (str) – The name of the metric/observable stored in evaluator_callback.
variance_name (str) – The name of the variance stored in evaluator_callback.
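This sketch shows a standardized stopping check. It assumes the change between the value `patience` evaluations ago and the current value is divided by the spread recorded under `variance_name` at that earlier evaluation, and compared with `tolerance`. The two lists stand in for the quantity and variance series an evaluator would hold. `has_converged_with_spread` is a hypothetical helper.

```python
def has_converged_with_spread(means, spreads, tolerance, patience):
    """Standardized-change test: |mu_{t-p} - mu_t| / sigma_{t-p} < tolerance (sketch)."""
    if len(means) <= patience or len(spreads) <= patience:
        return False  # not enough evaluations recorded yet
    past_mean, past_spread = means[-1 - patience], spreads[-1 - patience]
    # Written as a product to avoid dividing by a zero spread.
    return abs(past_mean - means[-1]) < tolerance * past_spread


means = [-0.40, -0.43, -0.441, -0.442]
spreads = [0.05, 0.03, 0.02, 0.02]
print(has_converged_with_spread(means, spreads, tolerance=0.5, patience=1))  # True
print(has_converged_with_spread(means, spreads, tolerance=0.5, patience=3))  # False
```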
-
class qucumber.callbacks.MetricEvaluator(period, metrics, verbose=False, log=None, **metric_kwargs)
Evaluate and hold on to the results of the given metric(s).
This callback is called at the end of each epoch.
Note
Since callbacks are given to fit as a list, they will be called in a deterministic order. It is therefore recommended that instances of MetricEvaluator be among the first callbacks in the list passed to fit, as one would often use it in conjunction with other callbacks like EarlyStopping which may depend on MetricEvaluator having been called.
- Parameters
period (int) – Frequency with which the callback evaluates the given metric(s).
metrics (dict(str, callable)) – A dictionary of callables where the keys are the names of the metrics and the callables take the WaveFunction being trained as their positional argument, along with some keyword arguments. The metrics are evaluated and put into an internal dictionary structure resembling the structure of metrics.
verbose (bool) – Whether to print metrics to stdout.
log (str) – A filepath to log metric values to in CSV format.
**metric_kwargs – Keyword arguments to be passed to metrics.
-
__getattr__(metric)
Return an array of all recorded values of the given metric.
- Parameters
metric (str) – The metric to retrieve.
- Returns
The past values of the metric.
- Return type
np.array
-
__getitem__(metric)
Alias for __getattr__ to enable subscripting.
-
epochs
Return an array of all epochs that have been recorded.
- Return type
np.array
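The following hypothetical stand-in mirrors what the MetricEvaluator entry describes. Every `period` epochs it evaluates each metric callable on the model with `**metric_kwargs`, stores the results, and exposes them by attribute, by subscript, and through `epochs`. `MetricEvaluatorSketch` and `toy_metric` are illustrative names. Plain lists stand in for the documented `np.array` return values, and the `epoch % period == 0` schedule is an assumption.

```python
class MetricEvaluatorSketch:
    """Hypothetical stand-in for the documented MetricEvaluator behaviour."""

    def __init__(self, period, metrics, **metric_kwargs):
        self.period = period
        self.metrics = metrics
        self.metric_kwargs = metric_kwargs
        self._epochs = []
        self._values = {name: [] for name in metrics}

    def on_epoch_end(self, nn_state, epoch):
        if epoch % self.period == 0:  # assumed evaluation schedule
            self._epochs.append(epoch)
            for name, fn in self.metrics.items():
                self._values[name].append(fn(nn_state, **self.metric_kwargs))

    def __getattr__(self, metric):
        # Only reached when normal lookup fails, i.e. for metric names.
        values = self.__dict__.get("_values", {})
        if metric in values:
            return list(values[metric])
        raise AttributeError(metric)

    def __getitem__(self, metric):
        return self.__getattr__(metric)

    @property
    def epochs(self):
        return list(self._epochs)


def toy_metric(nn_state, scale=1):
    # Hypothetical metric; here nn_state is just a number.
    return nn_state * scale


evaluator = MetricEvaluatorSketch(2, {"toy_metric": toy_metric}, scale=3)
for epoch, state in enumerate([10, 20, 30, 40], start=1):
    evaluator.on_epoch_end(state, epoch)
print(evaluator.epochs)            # [2, 4]
print(evaluator["toy_metric"])     # [60, 120]
print(evaluator.toy_metric)        # [60, 120]
```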
-
class qucumber.callbacks.ObservableEvaluator(period, observables, verbose=False, log=None, **sampling_kwargs)
Evaluate and hold on to the results of the given observable(s).
This callback is called at the end of each epoch.
Note
Since callbacks are given to fit as a list, they will be called in a deterministic order. It is therefore recommended that instances of ObservableEvaluator be among the first callbacks in the list passed to fit, as one would often use it in conjunction with other callbacks like EarlyStopping which may depend on ObservableEvaluator having been called.
- Parameters
period (int) – Frequency with which the callback evaluates the given observable(s).
observables (list(qucumber.observables.ObservableBase)) – A list of Observables. Observable statistics are evaluated by sampling the WaveFunction. Note that observables that have the same name will conflict, and precedence will be given to the one which appears later in the list.
verbose (bool) – Whether to print observable statistics to stdout.
log (str) – A filepath to log observable statistics to in CSV format.
**sampling_kwargs – Keyword arguments to be passed to Observable.statistics, for example num_samples, num_chains, burn_in, and steps.
-
__getattr__(observable)
Return an ObservableStatistics containing recorded statistics of the given observable.
- Parameters
observable (str) – The observable to retrieve.
- Returns
The past values of the observable.
- Return type
ObservableStatistics
-
__getitem__(observable)
Alias for __getattr__ to enable subscripting.
-
__len__()
Return the number of timesteps that observables have been evaluated for.
- Return type
int
epochs
Return an array of all epochs that have been recorded.
- Return type
np.array
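The observables parameter notes that observables sharing a name conflict, and that the later one in the list takes precedence. Building a name-keyed dict shows the same rule. `ToyObservable` is a hypothetical stand-in in which only a `name` attribute matters. Whether ObservableEvaluator resolves conflicts with a dict internally is an assumption. The documentation only states the outcome.

```python
class ToyObservable:
    """Hypothetical stand-in; only the `name` attribute is relevant here."""

    def __init__(self, name, label):
        self.name = name
        self.label = label


observables = [
    ToyObservable("SigmaZ", "first"),
    ToyObservable("SigmaX", "x"),
    ToyObservable("SigmaZ", "second"),
]
# Later entries overwrite earlier ones with the same key, so the last "SigmaZ" wins.
by_name = {obs.name: obs for obs in observables}
print({name: obs.label for name, obs in by_name.items()})
# {'SigmaZ': 'second', 'SigmaX': 'x'}
```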
-
class qucumber.callbacks.LivePlotting(period, evaluator_callback, quantity_name, error_name=None, total_epochs=None, smooth=True)
Plots metrics/observables.
This callback is called at the end of each epoch.
- Parameters
period (int) – Frequency with which the callback updates the plots (in epochs).
evaluator_callback (MetricEvaluator or ObservableEvaluator) – An instance of MetricEvaluator or ObservableEvaluator which computes the metric/observable that we want to plot.
quantity_name (str) – The name of the metric/observable stored in evaluator_callback.
error_name (str) – The name of the error stored in evaluator_callback.
-
class qucumber.callbacks.Timer(verbose=True)
Callback which records the training time.
This callback is always called at the start and end of training. It will run at the end of an epoch or batch if the given model’s stop_training property is set to True.
- Parameters
verbose (bool) – Whether to print the elapsed time at the end of training.
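This is a minimal sketch of the timing idea. It records a start time when training starts and computes the elapsed wall-clock time when training ends, printing it only if `verbose` is set. `TimerSketch` is a hypothetical stand-in. Using `time.perf_counter` is a choice made for this sketch and is not necessarily what QuCumber's Timer uses.

```python
import time


class TimerSketch:
    """Hypothetical stand-in: measures wall-clock time between the train hooks."""

    def __init__(self, verbose=True):
        self.verbose = verbose
        self.start_time = None
        self.elapsed = None

    def on_train_start(self, nn_state):
        self.start_time = time.perf_counter()

    def on_train_end(self, nn_state):
        self.elapsed = time.perf_counter() - self.start_time
        if self.verbose:
            print(f"Training took {self.elapsed:.3f} seconds")


timer = TimerSketch(verbose=False)
timer.on_train_start(None)
time.sleep(0.01)  # stands in for the training work
timer.on_train_end(None)
print(timer.elapsed > 0)  # True
```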