jmstate.types.PrecisionParameters

class PrecisionParameters(flat, dim, precision_type)

nn.Module encapsulating precision matrix parameters.

This class supports three parametrizations of the precision matrix: full, diagonal (‘diag’), and spherical (a scalar multiple of the identity). The default is the full parametrization. Internally, the precision matrix is stored through its log-Cholesky factor. Formally, let \(P = \Sigma^{-1}\) be the precision matrix, where \(\Sigma\) is the covariance matrix, and let \(L\) be the lower-triangular Cholesky factor of \(P\) (so \(P = L L^\top\)) with positive diagonal elements. The log-Cholesky representation \(\tilde{L}\) is defined by:

\[\tilde{L}_{ij} = L_{ij}, \quad i > j\]
\[\tilde{L}_{ii} = \log L_{ii}.\]

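Storing the diagonal on the log scale keeps it positive for any real-valued parameters, so every lower-triangular \(\tilde{L}\) corresponds to a valid positive-definite precision matrix and the parameters can be optimized without constraints. The factorization also improves numerical stability and avoids explicit inversion when computing quadratic forms, since \(x^\top P x = \lVert L^\top x \rVert^2\). The log determinant of the precision matrix is then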

\[\log \det P = 2 \operatorname{Tr}(\tilde{L}).\]
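The map between \(P\) and \(\tilde{L}\), the log-determinant identity, and the inversion-free quadratic form can be checked numerically in a few lines of PyTorch. This is a sketch of the math above, not the class's internal code. The helper names log_cholesky and from_log_cholesky are hypothetical.

```python
import torch

def log_cholesky(P: torch.Tensor) -> torch.Tensor:
    """Return the log-Cholesky factor L~ of a symmetric positive-definite P."""
    L = torch.linalg.cholesky(P)  # lower triangular, positive diagonal
    # Keep the strictly lower entries, replace the diagonal by its logarithm.
    return L.tril(-1) + torch.diag_embed(L.diagonal().log())

def from_log_cholesky(L_tilde: torch.Tensor) -> torch.Tensor:
    """Invert the map: exponentiate the diagonal and form P = L L^T."""
    L = L_tilde.tril(-1) + torch.diag_embed(L_tilde.diagonal().exp())
    return L @ L.T

P = torch.tensor([[4.0, 1.0], [1.0, 3.0]], dtype=torch.float64)
L_tilde = log_cholesky(P)

# log det P = 2 Tr(L~)
print(2 * L_tilde.trace(), torch.logdet(P))

# Quadratic form x^T P x = ||L^T x||^2, computed without inverting any matrix.
x = torch.tensor([1.0, -2.0], dtype=torch.float64)
L = torch.linalg.cholesky(P)
print(x @ P @ x, (L.T @ x).pow(2).sum())
```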

Instances can be created from a precision matrix with the from_precision classmethod, or from a covariance matrix with the from_covariance classmethod, setting precision_type to ‘full’, ‘diag’, or ‘spherical’.

Variables:
  • flat (torch.Tensor) – Flat representation of the precision matrix suitable for optimization.

  • dim (int) – Dimension of the precision matrix.

  • precision_type (str) – Type of parametrization, one of ‘full’, ‘diag’, or ‘spherical’.

Parameters:
  • flat (Tensor)

  • dim (int)

  • precision_type (str)
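Suppose flat stores only the free entries of \(\tilde{L}\): the lower triangle for ‘full’, the log-diagonal for ‘diag’, and a single log scale for ‘spherical’. Under that assumption, its expected length follows from dim and precision_type. The helper flat_size is hypothetical, and the layout is an illustrative assumption rather than documented behavior.

```python
def flat_size(dim: int, precision_type: str) -> int:
    """Number of free parameters under an assumed flat layout."""
    if precision_type == "full":
        return dim * (dim + 1) // 2  # lower triangle of L~, diagonal included
    if precision_type == "diag":
        return dim  # one log-diagonal entry per dimension
    if precision_type == "spherical":
        return 1  # a single log scale shared by all dimensions
    raise ValueError(f"unknown precision_type: {precision_type!r}")

for kind in ("full", "diag", "spherical"):
    print(kind, flat_size(3, kind))
```

The full parametrization grows quadratically with dim, while ‘diag’ and ‘spherical’ trade flexibility for far fewer parameters.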

Examples

>>> import torch
>>> from jmstate.types import PrecisionParameters
>>> random_prec = PrecisionParameters.from_covariance(torch.eye(3), "diag")
>>> noise_prec = PrecisionParameters.from_covariance(torch.eye(2), "spherical")

__init__(flat, dim, precision_type)

Initializes the PrecisionParameters object.

Parameters:
  • flat (torch.Tensor) – The flat representation of the precision matrix.

  • dim (int) – The dimension of the precision matrix.

  • precision_type (str) – The parametrization type, one of ‘full’, ‘diag’, or ‘spherical’.

Raises:

ValueError – If the representation is invalid.
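The docstring does not specify what makes a representation invalid. One plausible set of checks is an unknown precision_type, or a flat tensor whose length does not match dim under the flat layout assumed earlier. These checks can be sketched as a hypothetical standalone function, check_representation. The class's actual validation may differ.

```python
import torch

VALID_TYPES = ("full", "diag", "spherical")

def check_representation(flat: torch.Tensor, dim: int, precision_type: str) -> None:
    """Hypothetical validation; raises ValueError on an inconsistent combination."""
    if precision_type not in VALID_TYPES:
        raise ValueError(f"precision_type must be one of {VALID_TYPES}, got {precision_type!r}")
    # Assumed flat lengths: lower triangle, diagonal, or a single scale.
    expected = {"full": dim * (dim + 1) // 2, "diag": dim, "spherical": 1}[precision_type]
    if flat.ndim != 1 or flat.numel() != expected:
        raise ValueError(f"expected a 1-D flat tensor of length {expected}, got shape {tuple(flat.shape)}")

check_representation(torch.zeros(6), 3, "full")  # consistent, passes silently
try:
    check_representation(torch.zeros(5), 3, "full")
except ValueError as err:
    print(err)
```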

classmethod from_precision(P, precision_type='full')

Creates an instance from a precision matrix using the chosen precision type.

Parameters:
  • P (torch.Tensor) – The square, symmetric positive-definite precision matrix.

  • precision_type (str, optional) – The method, ‘full’, ‘diag’, or ‘spherical’. Defaults to ‘full’.

Returns:

A new PrecisionParameters instance representing P.

Return type:

Self
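The conversion that from_precision has to perform can be sketched under the flat layout assumed above, with the lower triangle of \(\tilde{L}\) in row-major order. The function precision_to_flat is hypothetical. Its handling of ‘diag’ (off-diagonal entries dropped) and ‘spherical’ (scale taken from the mean of the diagonal) consists of assumptions, not documented behavior.

```python
import torch

def precision_to_flat(P: torch.Tensor, precision_type: str = "full") -> torch.Tensor:
    """Hypothetical conversion of a precision matrix P to a flat parameter vector."""
    dim = P.shape[0]
    if precision_type == "full":
        L = torch.linalg.cholesky(P)  # P = L L^T, positive diagonal
        L_tilde = L.tril(-1) + torch.diag_embed(L.diagonal().log())
        rows, cols = torch.tril_indices(dim, dim)  # row-major lower triangle
        return L_tilde[rows, cols]
    if precision_type == "diag":
        # Assumption: off-diagonal entries are dropped; log L_ii = 0.5 * log P_ii.
        return 0.5 * P.diagonal().log()
    if precision_type == "spherical":
        # Assumption: the shared scale is taken from the mean of the diagonal.
        return (0.5 * P.diagonal().mean().log()).reshape(1)
    raise ValueError(f"unknown precision_type: {precision_type!r}")

P = torch.tensor([[4.0, 1.0], [1.0, 3.0]], dtype=torch.float64)
print(precision_to_flat(P, "full"))  # entries: log 2, 0.5, log sqrt(2.75)
```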

classmethod from_covariance(V, precision_type='full')

Creates an instance from a covariance matrix using the chosen precision type.

Parameters:
  • V (torch.Tensor) – The square, symmetric positive-definite covariance matrix.

  • precision_type (str, optional) – The method, ‘full’, ‘diag’, or ‘spherical’. Defaults to ‘full’.

Returns:

A new PrecisionParameters instance representing \(V^{-1}\).

Return type:

Self
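Since \(P = \Sigma^{-1}\), from_covariance can be understood as inverting V and then proceeding as from_precision. The inversion step can use V's Cholesky factor instead of a general-purpose inverse. Below is a sketch of that step, assuming V is symmetric positive definite. The helper covariance_to_precision is hypothetical, and whether the class takes this exact route is not stated here.

```python
import torch

def covariance_to_precision(V: torch.Tensor) -> torch.Tensor:
    """Invert a symmetric positive-definite covariance via its Cholesky factor."""
    C = torch.linalg.cholesky(V)  # V = C C^T, C lower triangular
    return torch.cholesky_inverse(C)  # (C C^T)^{-1} = V^{-1}

V = torch.tensor([[2.0, 0.5], [0.5, 1.0]], dtype=torch.float64)
P = covariance_to_precision(V)
print(P @ V)  # approximately the 2 x 2 identity
```

For the identity covariances in the Examples above, the resulting precision matrix is again the identity.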

property precision: Tensor

Gets the precision matrix.

Returns:

The precision matrix.

Return type:

torch.Tensor
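Reading the precision property amounts to undoing the flat layout and forming \(P = L L^\top\). The function flat_to_precision is a hypothetical sketch of that reconstruction under the same assumed layout (row-major lower triangle for ‘full’, log-diagonal for ‘diag’, one log scale for ‘spherical’). Because every step is differentiable, gradients can flow from P back to flat.

```python
import math
import torch

def flat_to_precision(flat: torch.Tensor, dim: int, precision_type: str) -> torch.Tensor:
    """Hypothetical inverse of the assumed flat layout: rebuild P = L L^T."""
    if precision_type == "full":
        rows, cols = torch.tril_indices(dim, dim)
        L_tilde = torch.zeros(dim, dim, dtype=flat.dtype)
        L_tilde[rows, cols] = flat  # scatter into the lower triangle
        L = L_tilde.tril(-1) + torch.diag_embed(L_tilde.diagonal().exp())
        return L @ L.T
    if precision_type == "diag":
        return torch.diag_embed(torch.exp(2 * flat))  # P_ii = L_ii^2
    if precision_type == "spherical":
        return torch.exp(2 * flat) * torch.eye(dim, dtype=flat.dtype)
    raise ValueError(f"unknown precision_type: {precision_type!r}")

# L = [[2, 0], [0.5, 3]], so P = [[4, 1], [1, 9.25]].
flat = torch.tensor([math.log(2.0), 0.5, math.log(3.0)], dtype=torch.float64)
P = flat_to_precision(flat, 2, "full")
print(P)
```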

property covariance: Tensor

Gets the covariance matrix.

Returns:

The covariance matrix.

Return type:

torch.Tensor
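Because \(P = L L^\top\), the covariance \(\Sigma = P^{-1}\) can be computed directly from the Cholesky factor with torch.cholesky_inverse, without first forming P or calling a general inverse. In the diagonal case it reduces to elementwise exponentials. Below is a sketch assuming the full log-Cholesky layout. The helper log_cholesky_to_covariance is hypothetical, and whether the property uses this route is not documented here.

```python
import torch

def log_cholesky_to_covariance(L_tilde: torch.Tensor) -> torch.Tensor:
    """Return Sigma = (L L^T)^{-1} from the log-Cholesky factor L~ of P."""
    # Undo the log on the diagonal to recover the Cholesky factor L of P.
    L = L_tilde.tril(-1) + torch.diag_embed(L_tilde.diagonal().exp())
    # cholesky_inverse(L) computes (L L^T)^{-1} for a lower-triangular L.
    return torch.cholesky_inverse(L)

L_tilde = torch.tensor([[0.0, 0.0], [0.5, 0.3]], dtype=torch.float64)
Sigma = log_cholesky_to_covariance(L_tilde)

# Check against the precision rebuilt from the same factor.
L = L_tilde.tril(-1) + torch.diag_embed(L_tilde.diagonal().exp())
P = L @ L.T
print(torch.allclose(Sigma @ P, torch.eye(2, dtype=torch.float64)))

# Diagonal case: P_ii = exp(2 * flat_i), so the variances are exp(-2 * flat_i).
diag_flat = torch.tensor([0.0, 0.5])
variances = torch.exp(-2 * diag_flat)
```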