jmstate.types.ModelDesign
- class ModelDesign(indiv_params_fn, regression_fn, link_fns)
Dataclass encapsulating the design of a multistate joint model.
This class defines the parametric and functional design of the model, including individual-specific parameters, regression structures, and link functions for state transitions. All functions should support broadcasting as fully as possible so that computations are vectorized efficiently. Where broadcasting is not possible, torch.vmap can be used to parallelize safely.
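When a trajectory is easiest to write for one individual at a time, torch.vmap can lift it to the batched shapes described below. The sketch assumes a hypothetical step-function trajectory (`step_single`, `regression_fn` are illustrative names, not jmstate API). This particular function could also be broadcast directly, which makes it easy to check the vmap result against a broadcasted version.

```python
import torch


def step_single(t: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
    """Per-individual trajectory: t has shape (m,), params has shape (3,).

    Hypothetical step function (level before / after a change point); it stands
    in for logic that is hard to write in a fully broadcasted way.
    """
    before, after, change_point = params.unbind(-1)
    return torch.where(t < change_point, before, after).unsqueeze(-1)  # (m, 1)


def regression_fn(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    """Lift step_single to batched shapes with torch.vmap.

    t: (m,) shared by all individuals, or (n, m) per individual.
    indiv_params: (n, 3), or (c, n, 3) with an extra leading dimension.
    Returns (n, m, 1) or (c, n, m, 1).
    """
    # Map over individuals; time is mapped too only when it is per-individual (2D).
    fn = torch.vmap(step_single, in_dims=(0 if t.dim() == 2 else None, 0))
    if indiv_params.dim() == 3:
        # Map over the extra leading dimension; time is shared across it.
        fn = torch.vmap(fn, in_dims=(None, 0))
    return fn(t, indiv_params)
```

The per-individual function stays simple, and the two nested vmap calls reproduce the 2D/3D conventions without rewriting its body.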
Functions provided to the MCMC sampler will automatically be wrapped with torch.no_grad(). If gradient computation is required regardless of the sampling context, wrap the function explicitly with torch.enable_grad().
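As an illustration of the torch.enable_grad() case, the sketch below assumes a hypothetical link function, `slope_link`, that obtains the time derivative of a sigmoid trajectory through autograd. Without the decorator, a call made under torch.no_grad() (as in an MCMC step) would fail, because the output would carry no graph to differentiate. The derivative could also be written analytically; autograd is used here only to show when enable_grad is needed.

```python
import torch


@torch.enable_grad()
def slope_link(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    """Hypothetical link function: d/dt of scale * sigmoid((t - offset) / slope).

    Gradients are re-enabled locally, so this works even when the caller runs
    under torch.no_grad().
    """
    scale, offset, slope = indiv_params.chunk(3, dim=-1)  # each (..., n, 1)
    # Give every (individual, time) pair its own leaf so that each output
    # element depends on exactly one entry of t_leaf.
    t_full, _ = torch.broadcast_tensors(t, offset)
    t_leaf = t_full.detach().clone().requires_grad_(True)
    y = scale * torch.sigmoid((t_leaf - offset) / slope)
    # Keep the graph only if the parameters themselves need gradients.
    (dy_dt,) = torch.autograd.grad(
        y.sum(), t_leaf, create_graph=indiv_params.requires_grad
    )
    return dy_dt.unsqueeze(-1)  # (..., n, m, 1)
```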
All functions must be well-defined on closed intervals of their input domain and differentiable almost everywhere to ensure compatibility with gradient-based procedures.
- Individual Parameters:
indiv_params_fn is a function that computes individual parameters. Given fixed_effects (fixed population-level parameters), x (covariates matrix of shape \((n, p)\)), and b (random effects, either 2D or 3D), it returns a tensor of individual parameters with the same number of dimensions as b: 2D, or 3D when b carries an extra leading dimension (for example, one per MCMC chain). This function defines the mapping from population-level parameters and covariates to individual-specific parameters.
- Regression:
regression_fn is a function that maps time points and individual parameters to the expected observations. It must accept 1D or 2D time inputs and 2D or 3D individual parameters. The output tensor must have at least three dimensions: the last dimension corresponds to the response variable, the second-last to repeated measurements, the third-last to individuals, and an optional fourth-last dimension for parallelization across MCMC chains.
- Link Functions:
link_fns is a mapping from transition keys (tuples of (state_from, state_to)) to link functions. Each link function shares the same requirements as regression_fn and defines the transformation from regression outputs to transition-specific parameters.
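The following is a minimal sketch of the three kinds of functions working together, written only against the shape conventions described above. The layout of `fixed_effects` (intercepts followed by a flattened coefficient matrix), the two responses, the state names "healthy", "ill", "dead", and the choice of links (current value, individual rate) are hypothetical illustrations, not part of jmstate.

```python
import torch


def indiv_params_fn(
    fixed_effects: torch.Tensor, x: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    """Individual parameters: intercept + x @ beta + b.

    Assumed layout (illustrative only): fixed_effects = [intercept (q,), beta (p*q,)].
    b has shape (n, q) or (c, n, q); the result has the same shape as b.
    """
    q, p = b.shape[-1], x.shape[-1]
    intercept = fixed_effects[:q]
    beta = fixed_effects[q:].reshape(p, q)
    return intercept + x @ beta + b  # (n, q) broadcasts against b


def regression_fn(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    """Two hypothetical responses: a linear marker and its exponential.

    t: (m,) or (n, m); indiv_params: (n, 2) or (c, n, 2).
    Returns (n, m, 2) or (c, n, m, 2): individuals, measurements, responses.
    """
    level, rate = indiv_params.chunk(2, dim=-1)  # each (..., n, 1)
    linear = level + rate * t  # broadcasts to (..., n, m)
    return torch.stack([linear, torch.exp(linear)], dim=-1)


def value_link(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    """Current value of the first response; the response axis is kept."""
    return regression_fn(t, indiv_params)[..., :1]  # (..., n, m, 1)


def slope_link(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    """Individual rate, repeated over the time points."""
    _, rate = indiv_params.chunk(2, dim=-1)  # (..., n, 1)
    return (rate * torch.ones_like(t)).unsqueeze(-1)  # (..., n, m, 1)


# Hypothetical state names; keys are (state_from, state_to) tuples.
link_fns = {("healthy", "ill"): value_link, ("ill", "dead"): slope_link}
```

Under these assumptions the three objects could be passed as ModelDesign(indiv_params_fn, regression_fn, link_fns). Every function relies only on broadcasting, so the same code serves 2D and 3D individual parameters and 1D or 2D time inputs.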
- Variables:
indiv_params_fn (IndividualParametersFn) – Function that computes individual parameters; see Individual Parameters above.
regression_fn (RegressionFn) – Regression function mapping time points and individual parameters to the expected observations; see Regression above.
link_fns (Mapping[tuple[Any, Any], LinkFn]) – Mapping from transition keys (tuples of (state_from, state_to)) to link functions; see Link Functions above.
- Parameters:
indiv_params_fn (IndividualParametersFn)
regression_fn (RegressionFn)
link_fns (Mapping[tuple[Any, Any], LinkFn])
Examples
>>> import torch
>>> from jmstate.types import ModelDesign
>>> def sigmoid(t: torch.Tensor, indiv_params: torch.Tensor):
...     scale, offset, slope = indiv_params.chunk(3, dim=-1)
...     # Fully broadcasted
...     return (scale * torch.sigmoid((t - offset) / slope)).unsqueeze(-1)
>>> fixed_plus_b = lambda fixed, x, b: fixed + b
>>> # The same function serves as regression function and as link function
>>> link_fns = {("alive", "dead"): sigmoid}
>>> design = ModelDesign(fixed_plus_b, sigmoid, link_fns)
- __init__(indiv_params_fn, regression_fn, link_fns)
- Parameters:
indiv_params_fn (IndividualParametersFn)
regression_fn (RegressionFn)
link_fns (Mapping[tuple[Any, Any], LinkFn])
- Return type:
None