jmstate.types.ModelParameters

class ModelParameters(fixed_effects, random_prec, noise_prec, base_hazards, link_coefs, x_coefs)

nn.Module encapsulating all model parameters for a multistate joint model.

This module contains fixed population-level parameters, covariate effects, link coefficients, random-effects and noise precision parameters, and log base hazard functions. Parameters can be shared by assigning the same nn.Parameter object to multiple fields. Sharing a plain torch.Tensor is not supported: nn.ParameterDict wraps every plain tensor in a new nn.Parameter, so the entries would no longer refer to the same object and the sharing would be lost. Wrap any tensor you intend to share in nn.Parameter before passing it.
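The identity issue can be reproduced with a bare nn.ParameterDict. This is a minimal sketch in plain PyTorch rather than jmstate; the string keys "0_1" and "1_0" are hypothetical stand-ins for the (from_state, to_state) tuples, because nn.ParameterDict only accepts string keys.

```python
import torch
from torch import nn

plain = torch.zeros(3)                 # a plain tensor reused for two transitions
shared = nn.Parameter(torch.zeros(3))  # a tensor wrapped once in nn.Parameter

# ParameterDict wraps each plain tensor in a fresh nn.Parameter on insertion,
# so the two entries end up as two distinct objects.
d_plain = nn.ParameterDict({"0_1": plain, "1_0": plain})

# An existing nn.Parameter is stored as-is, so both keys refer to one object.
d_shared = nn.ParameterDict({"0_1": shared, "1_0": shared})

print(d_plain["0_1"] is d_plain["1_0"])    # False: sharing lost
print(d_shared["0_1"] is d_shared["1_0"])  # True: sharing preserved
```

Passing plain tensors still works; each transition then simply gets its own, independently trained coefficient.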

Fixed population-level parameters:
  • fixed_effects are fixed population-level parameters.

Random effects and noise precision matrices:
  • random_prec and noise_prec are PrecisionParameters objects representing the random effects and residual noise precision matrices respectively.
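For reference, a precision matrix is the inverse of a covariance matrix. Assuming the "diag" and "spherical" options used in the examples carry their usual meaning (the exact parametrization stored by PrecisionParameters is not documented here), the two structures are:

```latex
Q = \Sigma^{-1}, \qquad
\Sigma_{\text{diag}} = \operatorname{diag}(\sigma_1^2, \dots, \sigma_q^2)
  \;\Rightarrow\; Q_{\text{diag}} = \operatorname{diag}(\sigma_1^{-2}, \dots, \sigma_q^{-2}), \qquad
\Sigma_{\text{sph}} = \sigma^2 I_q
  \;\Rightarrow\; Q_{\text{sph}} = \sigma^{-2} I_q .
```

With the identity covariances used in the examples, both precision matrices are again the identity.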

Base hazard functions:
  • base_hazards is a dictionary of LogBaseHazardFn modules keyed by (from_state, to_state) tuples for each transition; optimization can be disabled per hazard via its frozen attribute.
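One common way to implement a per-module freeze flag in PyTorch is to leave the frozen modules' parameters out of the optimizer. The sketch below is hypothetical: ToyLogHazard is a stand-in for a LogBaseHazardFn with a constant log hazard, and filtering on frozen illustrates the idea rather than jmstate's actual mechanism. String keys are used because nn.ModuleDict requires them.

```python
import torch
from torch import nn


class ToyLogHazard(nn.Module):
    """Hypothetical stand-in for a LogBaseHazardFn: log h0(t) = log_rate (constant hazard)."""

    def __init__(self, log_rate: float = 0.0, frozen: bool = False):
        super().__init__()
        self.log_rate = nn.Parameter(torch.tensor(log_rate))
        self.frozen = frozen  # hypothetical flag checked when building the optimizer

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # Same log hazard at every time point.
        return self.log_rate.expand_as(t)


hazards = nn.ModuleDict({"0_1": ToyLogHazard(0.0), "1_0": ToyLogHazard(0.0, frozen=True)})

# Only the parameters of non-frozen hazards are handed to the optimizer.
trainable = [p for h in hazards.values() if not h.frozen for p in h.parameters()]
opt = torch.optim.SGD(trainable, lr=0.1)

t = torch.linspace(0.0, 1.0, 5)
loss = sum(h(t).sum() for h in hazards.values())
opt.zero_grad()
loss.backward()
opt.step()

print(hazards["0_1"].log_rate.item())  # -0.5: gradient is 5, step is 0.1 * 5
print(hazards["1_0"].log_rate.item())  # 0.0: frozen, left untouched
```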

Link and covariate coefficients:
  • link_coefs and x_coefs are dictionaries keyed by (from_state, to_state) tuples, representing linear coefficients for links and covariates, respectively.
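Joint models typically combine these coefficients into a per-transition log hazard of the form log λ_kl(t) = log h0_kl(t) + xᵀβ_kl + g_kl(t)ᵀα_kl. The sketch below assumes that form (it is not stated in this reference), uses plain dicts keyed by transition, and treats link_values as a hypothetical, already-evaluated link output of dimension 3, matching the shapes in the examples.

```python
import torch

x = torch.tensor([1.0, 2.0])                  # two covariates for one individual
link_values = torch.tensor([0.5, -1.0, 2.0])  # hypothetical link output g(t), dimension 3

log_base = {(0, 1): torch.tensor(-1.0), (1, 0): torch.tensor(-2.0)}           # log h0 at time t
x_coefs = {(0, 1): torch.tensor([0.1, 0.2]), (1, 0): torch.zeros(2)}          # beta per transition
link_coefs = {(0, 1): torch.tensor([1.0, 0.0, 0.5]), (1, 0): torch.zeros(3)}  # alpha per transition

# Per-transition log hazard: log h0 + x . beta + g(t) . alpha
log_hazard = {
    k: log_base[k] + x @ x_coefs[k] + link_values @ link_coefs[k]
    for k in log_base
}
print({k: round(v.item(), 4) for k, v in log_hazard.items()})
```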

Variables:
  • fixed_effects (torch.Tensor) – Fixed population-level parameters.

  • random_prec (PrecisionParameters) – Precision parameters for random effects.

  • noise_prec (PrecisionParameters) – Precision parameters for residual noise.

  • base_hazards (nn.ModuleDict) – Log base hazard functions per transition.

  • link_coefs (nn.ParameterDict) – Linear parameters for the link functions.

  • x_coefs (nn.ParameterDict) – Covariate parameters for each transition.


Examples

>>> import torch
>>> from torch import nn
>>> fixed_effects = torch.zeros(3)
>>> random_prec = PrecisionParameters.from_covariance(torch.eye(3), "diag")
>>> noise_prec = PrecisionParameters.from_covariance(torch.eye(2), "spherical")
>>> # hazard_01 and hazard_10 are LogBaseHazardFn instances, assumed defined beforehand
>>> base_hazards = {(0, 1): hazard_01, (1, 0): hazard_10}
>>> link_coefs = {(0, 1): torch.zeros(3), (1, 0): torch.zeros(3)}
>>> x_coefs = {(0, 1): torch.zeros(2), (1, 0): torch.zeros(2)}
>>> params = ModelParameters(
...     fixed_effects,
...     random_prec,
...     noise_prec,
...     base_hazards,
...     link_coefs,
...     x_coefs,
... )
>>> # Shared parameters
>>> shared_coef = nn.Parameter(torch.zeros(3))  # Mandatory nn.Parameter
>>> shared_link_coefs = {(0, 1): shared_coef, (1, 0): shared_coef}
>>> shared_params = ModelParameters(
...     fixed_effects,
...     random_prec,
...     noise_prec,
...     base_hazards,
...     shared_link_coefs,
...     x_coefs,
... )
__init__(fixed_effects, random_prec, noise_prec, base_hazards, link_coefs, x_coefs)

Initialize the ModelParameters object.

Parameters:
  • fixed_effects (torch.Tensor) – Fixed population-level parameters.

  • random_prec (PrecisionParameters) – Precision parameters for random effects.

  • noise_prec (PrecisionParameters) – Precision parameters for residual noise.

  • base_hazards (dict[tuple[Any, Any], LogBaseHazardFn]) – Log base hazard functions.

  • link_coefs (dict[tuple[Any, Any], torch.Tensor]) – Linear parameters for the link functions.

  • x_coefs (dict[tuple[Any, Any], torch.Tensor]) – Covariate parameters for each transition.

Raises:

ValueError – If any of the tensors contains NaN or infinite values.
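The kind of validation described above can be sketched with torch.isfinite. This is a minimal sketch; check_finite is a hypothetical helper, not part of jmstate, and the library's own check may be organized differently.

```python
import torch


def check_finite(tensors: dict) -> None:
    """Hypothetical helper: raise ValueError if any named tensor holds NaN or +/-inf."""
    for name, tensor in tensors.items():
        if not torch.isfinite(tensor).all():
            raise ValueError(f"{name} contains NaN or infinite values")


check_finite({"fixed_effects": torch.zeros(3)})  # passes silently

try:
    check_finite({"x_coefs[(0, 1)]": torch.tensor([0.0, float("nan")])})
except ValueError as err:
    print(err)  # x_coefs[(0, 1)] contains NaN or infinite values
```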

numel()

Return the number of unique parameters.

Returns:

The number of unique parameters.

Return type:

int
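Counting unique parameters amounts to deduplicating by object identity. Assuming numel() counts scalar entries in the way torch.Tensor.numel does, this plain-PyTorch sketch (not jmstate's implementation) shows the difference: Module.parameters() yields each nn.Parameter once, even when it is registered under several names.

```python
import torch
from torch import nn

shared = nn.Parameter(torch.zeros(3))
container = nn.ParameterDict(
    {"0_1": shared, "1_0": shared, "x": nn.Parameter(torch.zeros(2))}
)

# parameters() skips duplicates by identity, so the shared tensor counts once.
unique_numel = sum(p.numel() for p in container.parameters())
# Iterating over values() visits every key, so the shared tensor counts twice.
naive_numel = sum(p.numel() for p in container.values())

print(unique_numel, naive_numel)  # 5 8
```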