Training Statistics
qucumber.utils.training_statistics.KL(nn_state, target_psi, space, bases=None, **kwargs)
A function for calculating the total KL divergence.
- Parameters
nn_state (qucumber.nn_states.WaveFunctionBase) – The neural network state (i.e. complex wavefunction or positive wavefunction).
target_psi (torch.Tensor or dict(str, torch.Tensor)) – The true wavefunction of the system. Can be a dictionary with each value being the wavefunction represented in a different basis, and the key identifying the basis.
space (torch.Tensor) – The basis elements of the Hilbert space of the system. The ordering of the basis elements must match the ordering of the coefficients given in target_psi.
bases (np.array(dtype=str)) – An array of unique bases. If given, the KL divergence will be computed for each basis and the average will be returned.
**kwargs – Extra keyword arguments that may be passed. Will be ignored.
- Returns
The KL divergence.
- Return type
-
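A minimal sketch of the quantity, assuming the divergence is taken from the target distribution p_target(σ) = |ψ_target(σ)|² to the model distribution p_model(σ) = |ψ_model(σ)|² / Z, where Z sums |ψ_model(σ)|² over every configuration in space, and assuming the basis average described for bases. It uses plain Python lists of complex amplitudes instead of torch.Tensor, assumes the model amplitudes are already expressed in the same basis as the target ones, and the names kl_single_basis and kl_average are hypothetical, not part of QuCumber. The small eps guard against log(0) is likewise an assumption of this sketch.

```python
import math

SQRT_HALF = 1 / math.sqrt(2)


def kl_single_basis(target_psi, model_psi, eps=1e-12):
    """KL(p_target || p_model) for amplitudes expressed in one basis.

    Both arguments are lists of complex amplitudes in the same order as the
    enumerated Hilbert space. target_psi is assumed normalized; model_psi is
    normalized here by its partition function Z.
    """
    p_target = [abs(a) ** 2 for a in target_psi]
    weights = [abs(a) ** 2 for a in model_psi]
    z = sum(weights)  # partition function over the whole space
    p_model = [w / z for w in weights]
    return sum(
        pt * (math.log(pt + eps) - math.log(pm + eps))
        for pt, pm in zip(p_target, p_model)
        if pt > 0  # configurations the target never produces contribute nothing
    )


def kl_average(target_psis, model_psis):
    """Average of the per-basis KL values over every basis key in target_psis."""
    bases = list(target_psis)
    total = sum(kl_single_basis(target_psis[b], model_psis[b]) for b in bases)
    return total / len(bases)


if __name__ == "__main__":
    # Toy per-basis inputs (not claimed to be rotations of one physical state).
    target = {"Z": [1, 0], "X": [SQRT_HALF, SQRT_HALF]}
    model = {"Z": [1, 1], "X": [SQRT_HALF, SQRT_HALF]}
    print(kl_average(target, model))  # ≈ log(2) / 2 ≈ 0.3466
```

Averaging the per-basis values rather than summing them keeps the result on the same scale no matter how many bases are supplied.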
qucumber.utils.training_statistics.NLL(nn_state, samples, space, bases=None, **kwargs)
A function for calculating the negative log-likelihood (NLL).
- Parameters
nn_state (qucumber.nn_states.WaveFunctionBase) – The neural network state (i.e. complex wavefunction or positive wavefunction).
samples (torch.Tensor) – Samples to compute the NLL on.
space (torch.Tensor) – The basis elements of the Hilbert space of the system.
bases (np.array(dtype=str)) – An array of the bases in which the measurements were taken.
**kwargs – Extra keyword arguments that may be passed. Will be ignored.
- Returns
The negative log-likelihood.
- Return type
-
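A minimal sketch of the NLL for samples measured in the computational basis only, assuming p(σ) = |ψ(σ)|² / Z and assuming the space argument serves to compute the normalization Z over the full Hilbert space. Samples taken in other bases would first need the amplitudes rotated into that basis, which this sketch omits. The callable amplitude is a hypothetical stand-in for the neural network state, and hilbert_space and nll_sketch are illustrative names, not QuCumber functions.

```python
import itertools
import math


def hilbert_space(n):
    """All 2**n configurations, ordered like itertools.product((0, 1), repeat=n)."""
    return list(itertools.product((0, 1), repeat=n))


def nll_sketch(samples, space, amplitude):
    """Average negative log-likelihood of computational-basis samples.

    amplitude: hypothetical callable returning the (unnormalized) complex
    amplitude of a configuration, standing in for the neural network state.
    """
    # Normalization Z = sum over the full space of |psi(sigma)|^2.
    z = sum(abs(amplitude(s)) ** 2 for s in space)
    # -(1/N) * sum_i log p(sigma_i), with p(sigma) = |psi(sigma)|^2 / Z.
    return -sum(math.log(abs(amplitude(s)) ** 2 / z) for s in samples) / len(samples)


def uniform(config):
    """Every configuration gets the same amplitude."""
    return 1.0


if __name__ == "__main__":
    space = hilbert_space(2)
    samples = [(0, 0), (1, 1), (0, 1)]
    print(nll_sketch(samples, space, uniform))  # log(4) ≈ 1.3863
```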
qucumber.utils.training_statistics.fidelity(nn_state, target_psi, space, **kwargs)
Calculates the square of the overlap (fidelity) between the reconstructed wavefunction and the true wavefunction (both in the computational basis).
- Parameters
nn_state (qucumber.nn_states.WaveFunctionBase) – The neural network state (i.e. complex wavefunction or positive wavefunction).
target_psi (torch.Tensor) – The true wavefunction of the system.
space (torch.Tensor) – The basis elements of the Hilbert space of the system. The ordering of the basis elements must match the ordering of the coefficients given in target_psi.
**kwargs – Extra keyword arguments that may be passed. Will be ignored.
- Returns
The fidelity.
- Return type
-
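A minimal sketch of the squared overlap F = |⟨ψ_target|ψ_model⟩|², assuming both wavefunctions are plain Python lists of complex amplitudes ordered like space. Both vectors are normalized explicitly, so neither needs to be normalized beforehand; that choice, and the name fidelity_sketch, are assumptions of this sketch rather than the library's implementation.

```python
import math


def fidelity_sketch(target_psi, model_psi):
    """|<target|model>|^2 after normalizing both amplitude lists."""
    # <target|model> = sum over sigma of conj(target(sigma)) * model(sigma)
    overlap = sum(complex(t).conjugate() * m for t, m in zip(target_psi, model_psi))
    norm_target = sum(abs(t) ** 2 for t in target_psi)
    norm_model = sum(abs(m) ** 2 for m in model_psi)
    return abs(overlap) ** 2 / (norm_target * norm_model)


if __name__ == "__main__":
    bell = [1 / math.sqrt(2), 0, 0, 1 / math.sqrt(2)]
    print(fidelity_sketch(bell, [1, 0, 0, 0]))  # ≈ 0.5
    print(fidelity_sketch(bell, [1j * a for a in bell]))  # ≈ 1.0
```

Because only the modulus of the overlap enters, a global phase on either wavefunction leaves F unchanged, as the second call illustrates.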
qucumber.utils.training_statistics.rotate_psi(nn_state, basis, space, unitaries, psi=None)
A function that rotates the reconstructed wavefunction to a different basis.
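A minimal sketch of such a rotation, assuming the basis string holds one label per qubit and that the rotation is the tensor product of the corresponding 2x2 unitaries, applied one qubit at a time to an amplitude list ordered like itertools.product((0, 1), repeat=n). The UNITARIES table (identity for "Z", Hadamard for "X") and the helper names are hypothetical illustrations, not the contents of the library's unitaries argument.

```python
import math

SQRT_HALF = 1 / math.sqrt(2)

# Hypothetical rotation table: identity for "Z", Hadamard for "X".
UNITARIES = {
    "Z": [[1, 0], [0, 1]],
    "X": [[SQRT_HALF, SQRT_HALF], [SQRT_HALF, -SQRT_HALF]],
}


def apply_single_qubit(psi, unitary, qubit, n):
    """Apply a 2x2 unitary to one qubit of an n-qubit amplitude list.

    Qubit 0 is the most significant bit of the index, matching an
    enumeration of the space via itertools.product((0, 1), repeat=n).
    """
    out = list(psi)
    stride = 1 << (n - 1 - qubit)
    for i in range(len(psi)):
        if (i & stride) == 0:  # visit each (bit = 0, bit = 1) pair once
            a0, a1 = psi[i], psi[i | stride]
            out[i] = unitary[0][0] * a0 + unitary[0][1] * a1
            out[i | stride] = unitary[1][0] * a0 + unitary[1][1] * a1
    return out


def rotate_psi_sketch(psi, basis, unitaries=UNITARIES):
    """Rotate psi into basis, one label per qubit, e.g. "XZ"."""
    n = len(basis)
    for qubit, label in enumerate(basis):
        psi = apply_single_qubit(psi, unitaries[label], qubit, n)
    return psi


if __name__ == "__main__":
    zero_zero = [1, 0, 0, 0]  # |00>
    print(rotate_psi_sketch(zero_zero, "XZ"))  # ≈ [0.7071, 0.0, 0.7071, 0.0]
```

Applying each 2x2 factor on its own qubit avoids building the full 2^n × 2^n matrix, so the cost grows as n · 2^n rather than 4^n.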
- Parameters
nn_state (qucumber.nn_states.WaveFunctionBase) – The neural network state (i.e. complex wavefunction or positive wavefunction).
basis (str) – The basis to rotate the wavefunction to.
space (torch.Tensor) – The basis elements of the Hilbert space of the system.
unitaries (dict(str, torch.Tensor)) – A dictionary of (2x2) unitary operators.
psi (torch.Tensor) – A wavefunction that the user can input to override the neural network state’s wavefunction.
- Returns
A wavefunction in a new basis.
- Return type