jmstate.types.IndividualParametersFn

class IndividualParametersFn(*args, **kwargs)

Protocol defining the individual parameters function.

This function maps population-level parameters, covariates, and random effects to individual-specific parameters used in a multistate model.
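Assuming this is a standard `typing.Protocol`, any callable with a matching call signature qualifies, and no subclassing is needed. The sketch below re-declares a protocol of the same shape under the hypothetical name `IndividualParametersFnSketch` to illustrate this structural typing. It is not jmstate's actual source, and the helper `additive` is likewise hypothetical.

```python
from typing import Protocol, runtime_checkable

import torch


# Hypothetical re-declaration for illustration only; jmstate's own
# definition may differ in details.
@runtime_checkable
class IndividualParametersFnSketch(Protocol):
    def __call__(
        self, fixed_effects: torch.Tensor, x: torch.Tensor, b: torch.Tensor
    ) -> torch.Tensor: ...


def additive(
    fixed_effects: torch.Tensor, x: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    # Ignores the covariates: each individual's parameters are the
    # population values shifted by that individual's random effects.
    return fixed_effects + b


# A plain function conforms structurally; a static type checker accepts this.
fn: IndividualParametersFnSketch = additive

# The runtime check only verifies that a __call__ attribute exists,
# not that the signature matches.
print(isinstance(additive, IndividualParametersFnSketch))  # True
```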

Tensor input conventions:

  • fixed_effects: fixed population-level parameters.

  • x: covariates matrix of shape \((n, p)\).

  • b: random effects of shape either \((n, q)\) (2D) or \((m, n, q)\) (3D), where \(m\) is the number of MCMC chains.

Tensor output conventions:

  • Last dimension corresponds to the number of parameters \(l\).

  • Second-last dimension corresponds to the number of individuals \(n\).

  • An optional third-last dimension may index MCMC chains, so that several chains are evaluated in a single batched call.
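The following sketch checks these shape conventions with the simple additive function from the Examples section. It assumes that a 3D random-effects tensor puts the chain axis first, \((m, n, q)\), matching the output convention that chains occupy the third-last dimension. It also assumes one random effect per parameter, so \(q = l\). The sizes are arbitrary.

```python
import torch

n, p, q, m = 5, 2, 3, 4  # individuals, covariates, random effects, MCMC chains

fixed_effects = torch.zeros(q)  # assumed: one population value per parameter (l = q)
x = torch.randn(n, p)
b_2d = torch.randn(n, q)        # single-chain random effects
b_3d = torch.randn(m, n, q)     # chain axis first (assumed)


def indiv_params_fn(fixed_effects, x, b):
    # Broadcasting aligns trailing dimensions, so (q,) + (..., n, q)
    # keeps whatever leading (chain) dimension b carries.
    return fixed_effects + b


out_2d = indiv_params_fn(fixed_effects, x, b_2d)
out_3d = indiv_params_fn(fixed_effects, x, b_3d)
print(out_2d.shape)  # torch.Size([5, 3])
print(out_3d.shape)  # torch.Size([4, 5, 3])
```

Writing the function with broadcasting rather than explicit loops lets the same code serve both the single-chain and the batched case.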

Parameters:
  • fixed_effects (torch.Tensor) – Fixed population-level parameters.

  • x (torch.Tensor) – Fixed covariates matrix.

  • b (torch.Tensor) – Random effects tensor.

Returns:

Individual parameters tensor of shape \((n, l)\), or \((n_{\text{chains}}, n, l)\) for parallelized computations.

Return type:

torch.Tensor

Examples

>>> indiv_params_fn = lambda fixed, x, b: fixed + b
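The one-line example above ignores the covariates. Below is a sketch of a covariate-aware function. It assumes a hypothetical packing of `fixed_effects`: \(l\) intercepts followed by a flattened \((p, l)\) coefficient matrix. It also assumes one random effect per parameter (\(q = l\)) and a chain-first 3D layout for `b`. jmstate does not prescribe this layout, and the name `linear_indiv_params` is illustrative.

```python
import torch


def linear_indiv_params(
    fixed_effects: torch.Tensor, x: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    # Hypothetical packing: fixed_effects = [intercepts (l,), coefficients (p * l,)].
    # Assumes one random effect per parameter, i.e. q == l.
    n_params = b.shape[-1]
    n_covs = x.shape[-1]
    intercept = fixed_effects[:n_params]                       # (l,)
    beta = fixed_effects[n_params:].reshape(n_covs, n_params)  # (p, l)
    # (l,) + (n, l) broadcasts to (n, l); adding b of shape (n, l) or
    # (m, n, l) then yields (n, l) or (m, n, l) respectively.
    return intercept + x @ beta + b


n, p, l, m = 6, 2, 3, 4  # individuals, covariates, parameters, chains
fixed_effects = torch.randn(l + p * l)
x = torch.randn(n, p)
print(linear_indiv_params(fixed_effects, x, torch.randn(n, l)).shape)     # torch.Size([6, 3])
print(linear_indiv_params(fixed_effects, x, torch.randn(m, n, l)).shape)  # torch.Size([4, 6, 3])
```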
__init__(*args, **kwargs)