jmstate.types.RegressionFn

class RegressionFn(*args, **kwargs)

Protocol defining a regression function for multistate models.

This function maps evaluation times and individual-specific parameters to predicted response values. It must accept both 1D and 2D time inputs and individual parameters with 2 or 3 dimensions. It returns a 3D tensor when the individual parameters are 2D, and a 4D tensor when they carry a leading MCMC chain dimension.
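Because this is a Protocol (structural typing), any callable with a matching call signature should satisfy it without subclassing. The sketch below illustrates this with a hypothetical stand-in, `RegressionFnLike`, that only mirrors the documented signature. It is not jmstate's actual definition. The sketch also includes a hypothetical `constant` model that conforms to it.

```python
from typing import Protocol, runtime_checkable

import torch


@runtime_checkable
class RegressionFnLike(Protocol):
    # Hypothetical stand-in mirroring the documented call signature;
    # not the actual jmstate.types.RegressionFn source.
    def __call__(self, t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor: ...


def constant(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    # Hypothetical model: each individual's first parameter, repeated at every time.
    level = indiv_params[..., :1]  # (n, 1) or (n_chains, n, 1)
    # Adding zeros shaped like t broadcasts level over the m time points.
    return (level + torch.zeros_like(t)).unsqueeze(-1)  # (n, m, 1) or (n_chains, n, m, 1)


f: RegressionFnLike = constant  # a plain function is accepted, no subclassing needed
```

A runtime `isinstance` check against such a protocol only verifies that the object is callable. It does not check argument names or tensor shapes.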

Tensor input conventions:

  • t represents the measurement times. It can be either:
    • a 2D tensor of shape \((n, m)\) when each individual has individual-specific time points,

    • a 1D tensor of shape \((m,)\) when all individuals share the same measurement times.

  • indiv_params represents the individual-specific parameters. It is expected to have shape \((n, l)\) or \((n_{\text{chains}}, n, l)\), where \(n\) is the number of individuals, \(l\) is the number of parameters, and \(n_{\text{chains}}\) is the number of MCMC chains.
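The following sketch shows why one input convention covers both time layouts. If each parameter is kept as a column with a trailing singleton dimension, PyTorch broadcasting produces an \((n, m)\) result whether t is 1D or 2D. Adding a leading chain dimension then yields \((n_{\text{chains}}, n, m)\). The sizes and variable names here are arbitrary illustrations.

```python
import torch

n, m = 4, 5
scale = torch.rand(n, 1) + 0.5  # one parameter per individual, kept as an (n, 1) column

t_shared = torch.linspace(0.0, 1.0, m)     # (m,): same times for every individual
t_indiv = torch.rand(n, m).cumsum(dim=-1)  # (n, m): individual-specific increasing times

# An (n, 1) column broadcasts against both time layouts and gives (n, m).
y_shared = scale * t_shared
y_indiv = scale * t_indiv

# With a leading chain dimension the column is (n_chains, n, 1) and the result gains that axis.
n_chains = 3
scale_chains = torch.rand(n_chains, n, 1)
y_chains = scale_chains * t_indiv  # (n_chains, n, m)
```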

Tensor output conventions:

  • The last dimension corresponds to the response variable dimension \(d\).

  • The second-last dimension corresponds to the repeated measurements \(m\).

  • The third-last dimension corresponds to the individual index \(n\).

  • An optional fourth-last dimension indexes MCMC chains, so that several chains can be evaluated in one batched call. It is present when indiv_params is 3D.
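The output conventions above can be made concrete as a shape check. `check_regression_output` below is a hypothetical helper and is not part of jmstate. It is a sketch that assumes the output's leading dimensions are fully determined by `indiv_params` (for \(n_{\text{chains}}\) and \(n\)) and by the last dimension of `t` (for \(m\)).

```python
import torch


def check_regression_output(
    y: torch.Tensor, t: torch.Tensor, indiv_params: torch.Tensor
) -> None:
    """Hypothetical helper: raise if y does not follow the documented output layout."""
    n = indiv_params.shape[-2]  # number of individuals
    m = t.shape[-1]  # number of measurement times
    if indiv_params.dim() == 3:
        # Batched over chains: expected (n_chains, n, m, d)
        expected_lead = (indiv_params.shape[0], n, m)
    else:
        # Single parameter set: expected (n, m, d)
        expected_lead = (n, m)
    # y must have exactly one extra trailing axis, the response dimension d.
    if y.dim() != len(expected_lead) + 1 or tuple(y.shape[:-1]) != expected_lead:
        raise ValueError(f"expected shape {expected_lead} + (d,), got {tuple(y.shape)}")
```

The response dimension \(d\) is left free here, because it depends on the model rather than on the inputs.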

This protocol is conceptually identical to LinkFn.

Parameters:
  • t (torch.Tensor) – Evaluation times of shape \((n, m)\) or \((m,)\).

  • indiv_params (torch.Tensor) – Individual parameters, either a 2D tensor of shape \((n, l)\) or a 3D tensor of shape \((n_{\text{chains}}, n, l)\).

Returns:

Predicted response values of shape \((n, m, d)\), or \((n_{\text{chains}}, n, m, d)\) for parallelized computations.

Return type:

torch.Tensor

Examples

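>>> import torch
>>> def sigmoid(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
...     # Each chunk has shape (n, 1) or (n_chains, n, 1)
...     scale, offset, slope = indiv_params.chunk(3, dim=-1)
...     # Fully broadcasted computation: (..., n, m), then unsqueeze adds d = 1
...     return (scale * torch.sigmoid((t - offset) / slope)).unsqueeze(-1)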
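For a multivariate response (\(d > 1\)), one option is to compute each output on the \((\ldots, n, m)\) grid and stack the results along a new last axis. The function below is a hypothetical sketch with two linear markers. It assumes a parameter layout of \(l = 4\) columns ordered as (intercept 1, slope 1, intercept 2, slope 2). That layout is an illustration, not a jmstate convention.

```python
import torch


def linear_two_markers(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    # Hypothetical layout: l = 4 columns (intercept_1, slope_1, intercept_2, slope_2).
    a1, b1, a2, b2 = indiv_params.chunk(4, dim=-1)  # each (n, 1) or (n_chains, n, 1)
    y1 = a1 + b1 * t  # (..., n, m), broadcast against (m,) or (n, m) times
    y2 = a2 + b2 * t  # (..., n, m)
    # Stack along a new last axis so the response dimension is d = 2.
    return torch.stack([y1, y2], dim=-1)  # (..., n, m, 2)
```

Using `torch.stack(..., dim=-1)` instead of `unsqueeze(-1)` keeps the same broadcasting behaviour while producing one channel per marker.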

__init__(*args, **kwargs)