jmstate.types.IndividualParametersFn
- class IndividualParametersFn(*args, **kwargs)
Protocol defining the individual parameters function.
This function maps population-level parameters, covariates, and random effects to individual-specific parameters used in a multistate model.
Tensor input conventions:
fixed_effects: fixed population-level parameters.
x: covariates matrix of shape \((n, p)\).
b: random effects of shape either \((n, q)\) (2D) or \((n, m, q)\) (3D).
Tensor output conventions:
The last dimension corresponds to the number of parameters \(l\).
The second-last dimension corresponds to the number of individuals \(n\).
An optional third-last dimension may be used to parallelize across MCMC chains (batched processing).
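To illustrate these conventions, the sketch below defines a hypothetical function, `linear_indiv_params`, in which each individual parameter is a population intercept plus a linear covariate effect plus a random effect. Several choices here are assumptions, not part of the documented API: `fixed_effects` is laid out as a \((p + 1, l)\) tensor (row 0 is the intercept, rows 1 to \(p\) are covariate coefficients), \(q = l\), and batched random effects carry the chain dimension first, \((n_{\text{chains}}, n, l)\), so that the result matches the output layout. The input convention above writes 3D random effects as \((n, m, q)\), so check the library for the exact batched layout. Because the function relies only on broadcasting over trailing dimensions, the same code handles both the 2D and the batched case.

```python
import torch


def linear_indiv_params(
    fixed_effects: torch.Tensor, x: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    """Hypothetical individual parameters function (illustrative only).

    Assumed shapes (not taken from the library):
        fixed_effects: (p + 1, l); row 0 is an intercept, rows 1..p are coefficients.
        x: (n, p) covariates.
        b: (n, l) random effects, or (n_chains, n, l) when batched over chains.
    Returns a tensor of shape (n, l) or (n_chains, n, l).
    """
    intercept = fixed_effects[0]        # (l,)
    coefs = fixed_effects[1:]           # (p, l)
    population = intercept + x @ coefs  # (n, l): same for every chain
    # Broadcasting prepends the chain dimension when b is 3D.
    return population + b


n, p, l, n_chains = 5, 2, 3, 4
fixed_effects = torch.randn(p + 1, l)
x = torch.randn(n, p)

out_2d = linear_indiv_params(fixed_effects, x, torch.randn(n, l))
out_3d = linear_indiv_params(fixed_effects, x, torch.randn(n_chains, n, l))
print(out_2d.shape)  # torch.Size([5, 3])
print(out_3d.shape)  # torch.Size([4, 5, 3])
```

Writing the function with broadcasting rather than explicit loops is what lets a single implementation serve both the per-individual and the multi-chain layouts.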
- Parameters:
fixed_effects (torch.Tensor) – Fixed population-level parameters.
x (torch.Tensor) – Fixed covariates matrix.
b (torch.Tensor) – Random effects tensor.
- Returns:
- Individual parameters tensor of shape \((n, l)\), or \((n_{\text{chains}}, n, l)\) for parallelized computations.
- Return type:
torch.Tensor
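Because this is a protocol, any callable with the matching signature should satisfy it structurally, without subclassing. The sketch below restates the documented signature as a `typing.Protocol` under the hypothetical name `IndividualParametersFnSketch`; it is an assumption about how such a definition could look, not the library's source. With `from __future__ import annotations`, the `torch.Tensor` annotations are never evaluated, so torch is imported only for static type checkers.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Protocol, runtime_checkable

if TYPE_CHECKING:  # torch is only needed by static type checkers here
    import torch


@runtime_checkable
class IndividualParametersFnSketch(Protocol):
    """Hypothetical restatement of the documented call signature."""

    def __call__(
        self, fixed_effects: torch.Tensor, x: torch.Tensor, b: torch.Tensor
    ) -> torch.Tensor: ...


def add_random_effects(fixed_effects, x, b):
    # Ignores the covariates, like the lambda in the Examples section.
    return fixed_effects + b


# Plain functions expose __call__, so they match the protocol structurally.
print(isinstance(add_random_effects, IndividualParametersFnSketch))  # True
print(isinstance(42, IndividualParametersFnSketch))                  # False
```

Note that a `runtime_checkable` `isinstance` check only confirms that `__call__` exists; argument names and tensor shapes are checked only by a static type checker (for names) or not at all (for shapes).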
Examples
>>> indiv_params_fn = lambda fixed, x, b: fixed + b
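Applying the example lambda to concrete tensors makes its broadcasting explicit. The sizes below are arbitrary, and the sketch assumes `fixed` has shape \((l,)\) and that \(q = l\), so each random effect shifts exactly one parameter; the covariates `x` are accepted but unused.

```python
import torch

indiv_params_fn = lambda fixed, x, b: fixed + b  # the example above

n, p, l = 4, 2, 3
fixed = torch.tensor([0.1, -0.5, 2.0])  # (l,) population-level values (assumed layout)
x = torch.zeros(n, p)                   # (n, p) covariates, ignored by this function
b = torch.randn(n, l)                   # (n, q) random effects with q == l

params = indiv_params_fn(fixed, x, b)   # fixed broadcasts across the n rows
print(params.shape)  # torch.Size([4, 3])
```

The result has the documented \((n, l)\) layout: one row of parameters per individual.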
- __init__(*args, **kwargs)