jmstate.types.PrecisionParameters
- class PrecisionParameters(flat, dim, precision_type)
An nn.Module encapsulating the parameters of a precision matrix.
This class supports three parametrizations of the precision matrix: full, diagonal, and spherical (a scalar multiple of the identity). The default is the full parametrization. Internally, precision matrices are represented using the log-Cholesky parametrization of the inverse covariance (precision) matrix. Formally, let \(P = \Sigma^{-1}\) be the precision matrix and \(L\) its lower-triangular Cholesky factor, so that \(P = L L^\top\), with positive diagonal elements. The log-Cholesky representation \(\tilde{L}\) is defined by:
\[\tilde{L}_{ij} = L_{ij}, \quad i > j,\]
\[\tilde{L}_{ii} = \log L_{ii}.\]
Because \(L_{ii} = \exp(\tilde{L}_{ii})\) is positive for any real \(\tilde{L}_{ii}\), the representation is unconstrained, which suits gradient-based optimization and improves numerical stability. It also avoids explicit inversion when computing quadratic forms, since \(x^\top P x = \lVert L^\top x \rVert^2\). The log determinant of the precision matrix is then
\[\log \det P = 2 \sum_i \log L_{ii} = 2 \operatorname{Tr}(\tilde{L}).\]
Instances can be created from a precision matrix with the from_precision classmethod, or from a covariance matrix with the from_covariance classmethod, with precision_type set to ‘full’, ‘diag’, or ‘spherical’.
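The identities above can be checked numerically with plain PyTorch. The sketch below does not call jmstate; the helper names log_cholesky and from_log_cholesky are illustrative only. It builds the log-Cholesky representation of a random symmetric positive-definite matrix, verifies that the log determinant equals twice the trace of that representation, and evaluates a quadratic form without inverting anything.

```python
import torch


def log_cholesky(P: torch.Tensor) -> torch.Tensor:
    """Illustrative helper: log-Cholesky representation of an SPD matrix P."""
    L = torch.linalg.cholesky(P)  # lower-triangular, positive diagonal, P = L @ L.T
    # Keep the strictly-lower entries and replace the diagonal by its logarithm.
    return L.tril(-1) + torch.diag_embed(L.diagonal().log())


def from_log_cholesky(L_tilde: torch.Tensor) -> torch.Tensor:
    """Illustrative helper: recover the Cholesky factor L from its log-Cholesky form."""
    return L_tilde.tril(-1) + torch.diag_embed(L_tilde.diagonal().exp())


torch.manual_seed(0)
A = torch.randn(3, 3, dtype=torch.float64)
P = A @ A.T + 3 * torch.eye(3, dtype=torch.float64)  # a random SPD precision matrix

L_tilde = log_cholesky(P)
L = from_log_cholesky(L_tilde)

# log det P = 2 Tr(L~)
logdet_from_trace = 2 * torch.trace(L_tilde)

# x^T P x = ||L^T x||^2, computed from the factor alone
x = torch.randn(3, dtype=torch.float64)
quad = (L.T @ x).pow(2).sum()
```

The map is invertible: exponentiating the diagonal of \(\tilde{L}\) gives back \(L\), and hence \(P = L L^\top\).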
- Variables:
flat (torch.Tensor) – Flat representation of the precision matrix suitable for optimization.
dim (int) – Dimension of the precision matrix.
precision_type (str) – Type of parametrization, one of ‘full’, ‘diag’, or ‘spherical’.
- Parameters:
flat (Tensor)
dim (int)
precision_type (str)
Examples
>>> import torch
>>> from jmstate.types import PrecisionParameters
>>> random_prec = PrecisionParameters.from_covariance(torch.eye(3), "diag")
>>> noise_prec = PrecisionParameters.from_covariance(torch.eye(2), "spherical")
- __init__(flat, dim, precision_type)
Initializes the PrecisionParameters object.
- Parameters:
flat (torch.Tensor) – The flat representation of the precision matrix.
dim (int) – The dimension of the precision matrix.
precision_type (str) – The parametrization of the precision matrix: ‘full’, ‘diag’, or ‘spherical’.
- Raises:
ValueError – If flat, dim, and precision_type do not form a valid representation.
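The layout of flat is not specified here. A natural choice, assumed in the sketch below, is: the lower triangle of \(\tilde{L}\) in row-major order for ‘full’ (dim·(dim+1)/2 values), its diagonal for ‘diag’ (dim values), and a single shared diagonal value for ‘spherical’. The helpers flat_size and unflatten are hypothetical, and the actual class may order or validate elements differently; the sketch only illustrates what a dimension check of the kind that raises ValueError could look like.

```python
import torch


def flat_size(dim: int, precision_type: str) -> int:
    """Hypothetical number of free parameters for each precision type."""
    sizes = {
        "full": dim * (dim + 1) // 2,  # lower triangle of L~, diagonal included
        "diag": dim,                   # log-diagonal only
        "spherical": 1,                # one log-diagonal value shared by all entries
    }
    if precision_type not in sizes:
        raise ValueError(f"unknown precision_type: {precision_type!r}")
    return sizes[precision_type]


def unflatten(flat: torch.Tensor, dim: int, precision_type: str) -> torch.Tensor:
    """Hypothetical inverse of the assumed layout: rebuild the dim x dim matrix L~."""
    if flat.ndim != 1 or flat.numel() != flat_size(dim, precision_type):
        raise ValueError("flat does not match dim and precision_type")
    if precision_type == "full":
        L_tilde = flat.new_zeros((dim, dim))
        rows, cols = torch.tril_indices(dim, dim)  # row-major lower triangle
        L_tilde[rows, cols] = flat
        return L_tilde
    if precision_type == "diag":
        return torch.diag(flat)
    # spherical: repeat the single value along the diagonal
    return torch.diag(flat.repeat(dim))
```

With this layout, ‘spherical’ stores one number regardless of dim, which is why it is the cheapest option for noise terms.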
- classmethod from_precision(P, precision_type='full')
Creates an instance from a precision matrix using the chosen precision type.
- Parameters:
P (torch.Tensor) – The square precision matrix.
precision_type (str, optional) – The parametrization: ‘full’, ‘diag’, or ‘spherical’. Defaults to ‘full’.
- Returns:
The corresponding PrecisionParameters instance.
- Return type:
Self
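For ‘full’, mapping a precision matrix to log-Cholesky values follows directly from the definition of \(\tilde{L}\). This page does not say how ‘diag’ and ‘spherical’ treat off-diagonal entries of P, so the sketch below assumes they use only the diagonal (for a diagonal P, \(L_{ii} = \sqrt{P_{ii}}\), hence \(\tilde{L}_{ii} = \tfrac{1}{2}\log P_{ii}\)) and that ‘spherical’ averages the log-diagonal into one value. The helper precision_to_flat and both assumptions are hypothetical, not the library's confirmed behavior.

```python
import torch


def precision_to_flat(P: torch.Tensor, precision_type: str = "full") -> torch.Tensor:
    """Hypothetical sketch: map a precision matrix P to a flat log-Cholesky vector."""
    n = P.shape[0]
    if precision_type == "full":
        L = torch.linalg.cholesky(P)  # P = L @ L.T
        L_tilde = L.tril(-1) + torch.diag_embed(L.diagonal().log())
        rows, cols = torch.tril_indices(n, n)  # row-major lower triangle
        return L_tilde[rows, cols]
    # Assumption: 'diag' and 'spherical' use only the diagonal of P.
    # For a diagonal P, L_ii = sqrt(P_ii), so log L_ii = 0.5 * log P_ii.
    log_diag = 0.5 * P.diagonal().log()
    if precision_type == "diag":
        return log_diag
    if precision_type == "spherical":
        # Assumption: one shared value, here the mean of the log-diagonal.
        return log_diag.mean().reshape(1)
    raise ValueError(f"unknown precision_type: {precision_type!r}")


P = torch.diag(torch.tensor([4.0, 9.0], dtype=torch.float64))
flat_full = precision_to_flat(P, "full")       # [log 2, 0, log 3]
flat_diag = precision_to_flat(P, "diag")       # [log 2, log 3]
flat_sph = precision_to_flat(P, "spherical")   # [(log 2 + log 3) / 2]
```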
- classmethod from_covariance(V, precision_type='full')
Creates an instance from a covariance matrix using the chosen precision type.
- Parameters:
V (torch.Tensor) – The square covariance matrix.
precision_type (str, optional) – The parametrization: ‘full’, ‘diag’, or ‘spherical’. Defaults to ‘full’.
- Returns:
The corresponding PrecisionParameters instance.
- Return type:
Self
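Starting from a covariance matrix V amounts to working with \(P = V^{-1}\). The sketch below assumes from_covariance(V, t) is equivalent to applying from_precision to \(V^{-1}\); that equivalence is an assumption, not something this page states. It inverts V through its Cholesky factor with torch.cholesky_inverse rather than a general-purpose inverse, and checks that the identity covariance used in the class examples maps to \(\tilde{L} = 0\). The helper names are illustrative.

```python
import torch


def covariance_to_precision(V: torch.Tensor) -> torch.Tensor:
    """Illustrative helper: invert an SPD covariance matrix via its Cholesky factor."""
    C = torch.linalg.cholesky(V)      # V = C @ C.T, C lower-triangular
    return torch.cholesky_inverse(C)  # returns V^{-1} computed from the factor C


def log_cholesky(P: torch.Tensor) -> torch.Tensor:
    """Illustrative helper: log-Cholesky representation of an SPD matrix P."""
    L = torch.linalg.cholesky(P)
    return L.tril(-1) + torch.diag_embed(L.diagonal().log())


V = torch.tensor([[2.0, 0.5], [0.5, 1.0]], dtype=torch.float64)
P = covariance_to_precision(V)
L_tilde = log_cholesky(P)  # 2 Tr(L~) = log det P = -log det V

# Identity covariance (as in the class examples): P = I, so every entry of L~ is 0.
L_tilde_eye = log_cholesky(covariance_to_precision(torch.eye(3, dtype=torch.float64)))
```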
- property precision: Tensor
Gets the precision matrix.
- Returns:
The precision matrix.
- Return type:
torch.Tensor
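Given the log-Cholesky representation, the precision matrix follows by exponentiating the diagonal and forming \(P = L L^\top\). The sketch below (helper name illustrative, not part of jmstate) shows this for an arbitrary lower-triangular \(\tilde{L}\) and for a diagonal one, where the result reduces to \(P = \operatorname{diag}(\exp(2\tilde{L}_{ii}))\).

```python
import torch


def precision_from_log_cholesky(L_tilde: torch.Tensor) -> torch.Tensor:
    """Illustrative helper: rebuild P = L L^T from its log-Cholesky form."""
    # Undo the log on the diagonal to recover the Cholesky factor L.
    L = L_tilde.tril(-1) + torch.diag_embed(L_tilde.diagonal().exp())
    return L @ L.T


# Full case: any real lower-triangular L~ yields an SPD precision matrix.
L_tilde = torch.tensor([[0.1, 0.0], [-0.3, 0.2]], dtype=torch.float64)
P = precision_from_log_cholesky(L_tilde)

# Diagonal case: with L~ = diag(d), P = diag(exp(2 d)).
d = torch.tensor([0.0, 0.5, -1.0], dtype=torch.float64)
P_diag = precision_from_log_cholesky(torch.diag(d))
```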
- property covariance: Tensor
Gets the covariance matrix.
- Returns:
The covariance matrix.
- Return type:
torch.Tensor
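The covariance is \(\Sigma = P^{-1} = (L L^\top)^{-1}\). Since \(L\) is already available, one way to compute it is torch.cholesky_inverse, which inverts a matrix from its lower Cholesky factor without a general-purpose inverse. This is a minimal sketch of that approach; whether the class computes the covariance this way is an assumption, and the helper names are illustrative.

```python
import torch


def cholesky_factor(L_tilde: torch.Tensor) -> torch.Tensor:
    """Illustrative helper: recover L from its log-Cholesky form."""
    return L_tilde.tril(-1) + torch.diag_embed(L_tilde.diagonal().exp())


def covariance_from_log_cholesky(L_tilde: torch.Tensor) -> torch.Tensor:
    """Illustrative helper: Sigma = (L L^T)^{-1} computed from the factor L."""
    return torch.cholesky_inverse(cholesky_factor(L_tilde))


L_tilde = torch.tensor(
    [[0.1, 0.0, 0.0], [-0.3, 0.2, 0.0], [0.4, 0.1, -0.5]], dtype=torch.float64
)
Sigma = covariance_from_log_cholesky(L_tilde)
```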