jmstate.types.RegressionFn
- class RegressionFn(*args, **kwargs)
Protocol defining a regression function for multistate models.
This function maps evaluation times and individual-specific parameters to predicted response values. It must accept both 1D and 2D time inputs and individual parameters with 2 or 3 dimensions. It returns a 3D tensor when the individual parameters are 2D, and a 4D tensor when they are 3D (batched across MCMC chains).
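Because `RegressionFn` is a protocol, any callable with a matching signature should qualify, including a plain function. The sketch below assumes the protocol's `__call__` takes `(t, indiv_params)` and returns a tensor; `RegressionFnSketch` and `linear_trajectory` are hypothetical names, and the intercept/slope parameterization is purely illustrative rather than something jmstate prescribes.

```python
from typing import Protocol

import torch


class RegressionFnSketch(Protocol):
    """Hypothetical structural type mirroring the documented call signature."""

    def __call__(self, t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
        ...


def linear_trajectory(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    # Illustrative parameterization (an assumption): column 0 is an intercept,
    # column 1 a slope. Each chunk has shape (..., n, 1).
    intercept, slope = indiv_params.chunk(2, dim=-1)
    # (..., n, 1) broadcasts against t of shape (m,) or (n, m) -> (..., n, m)
    y = intercept + slope * t
    # Append the response dimension d = 1 -> (..., n, m, 1)
    return y.unsqueeze(-1)


# A static type checker accepts a plain function where the protocol is expected.
fn: RegressionFnSketch = linear_trajectory
```

No subclassing is needed: structural typing only checks that the call signature lines up.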
Tensor input conventions:
- t represents the measurement times. It can be either:
  - a 2D tensor of shape \((n, m)\) when each individual has its own time points, or
  - a 1D tensor of shape \((m,)\) when all individuals share the same measurement times.
- indiv_params represents the individual-specific parameters. It is expected to have shape \((n, l)\) or \((n_{\text{chains}}, n, l)\), where \(n\) is the number of individuals, \(l\) is the number of parameters, and \(n_{\text{chains}}\) is the number of MCMC chains.
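The input conventions can be checked before any computation. This is a minimal sketch; `check_inputs` is a hypothetical helper, not part of jmstate, and it assumes that a 2D `t` must have exactly one row per individual.

```python
import torch


def check_inputs(t: torch.Tensor, indiv_params: torch.Tensor) -> tuple[int, int]:
    """Hypothetical validator for the input conventions; returns (n, m)."""
    if indiv_params.dim() not in (2, 3):
        raise ValueError("indiv_params must have shape (n, l) or (n_chains, n, l)")
    n = indiv_params.shape[-2]  # number of individuals
    if t.dim() == 1:
        # Shared measurement times for all individuals: shape (m,)
        m = t.shape[0]
    elif t.dim() == 2:
        # Individual-specific measurement times: shape (n, m)
        if t.shape[0] != n:
            raise ValueError(f"t has {t.shape[0]} rows but there are {n} individuals")
        m = t.shape[1]
    else:
        raise ValueError("t must have shape (m,) or (n, m)")
    return n, m
```

A 2D `t` is not tied to the chain dimension: with 3D `indiv_params`, it broadcasts across all chains.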
Tensor output conventions:
- The last dimension indexes the response variables, of size \(d\).
- The second-last dimension indexes repeated measurements, of size \(m\).
- The third-last dimension indexes individuals, of size \(n\).
- An optional fourth-last dimension indexes MCMC chains, of size \(n_{\text{chains}}\), so that chains can be processed in parallel (batched).
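The output layout can be stated as a shape rule and compared against an actual function. In this sketch, `expected_output_shape` and `two_responses` are hypothetical names, and the two-response model (\(d = 2\)) is an illustrative assumption chosen only to show a last dimension larger than 1.

```python
import torch


def expected_output_shape(t: torch.Tensor, indiv_params: torch.Tensor, d: int) -> torch.Size:
    """Hypothetical helper: the output shape implied by the conventions above."""
    n = indiv_params.shape[-2]                # individuals (third-last output dim)
    m = t.shape[-1]                           # measurements (second-last output dim)
    chains = tuple(indiv_params.shape[:-2])   # () or (n_chains,) (optional fourth-last dim)
    return torch.Size(chains + (n, m, d))


def two_responses(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    # Illustrative d = 2 model: response 0 is a + b * t, response 1 is exp(-b * t).
    a, b = indiv_params.chunk(2, dim=-1)      # each (..., n, 1)
    y0 = a + b * t                            # (..., n, m)
    y1 = torch.exp(-b * t)                    # (..., n, m)
    return torch.stack([y0, y1], dim=-1)      # (..., n, m, 2)
```

Stacking along `dim=-1` keeps the response dimension last, so the same code serves both the 3D and the chain-batched 4D case.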
This protocol is conceptually identical to LinkFn.
- Parameters:
t (torch.Tensor) – Evaluation times of shape \((n, m)\) or \((m,)\).
indiv_params (torch.Tensor) – Individual parameters, as a 2D tensor of shape \((n, l)\) or a 3D tensor of shape \((n_{\text{chains}}, n, l)\).
- Returns:
Predicted response values of shape \((n, m, d)\), or \((n_{\text{chains}}, n, m, d)\) for parallelized computations.
- Return type:
torch.Tensor
Examples
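>>> import torch
>>> def sigmoid(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
...     # Each of scale, offset, slope has shape (..., n, 1)
...     scale, offset, slope = indiv_params.chunk(3, dim=-1)
...     # Fully broadcasted computation; unsqueeze adds the response dimension d = 1
...     return (scale * torch.sigmoid((t - offset) / slope)).unsqueeze(-1)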
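The example can be exercised with concrete tensors to see both output layouts. The parameter values and the choice of 5 chains below are arbitrary illustrations; at \(t = \text{offset}\) the sigmoid equals 0.5, so the prediction there is half the scale.

```python
import torch


def sigmoid(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    # Same model as the example: scale * sigmoid((t - offset) / slope), d = 1
    scale, offset, slope = indiv_params.chunk(3, dim=-1)
    return (scale * torch.sigmoid((t - offset) / slope)).unsqueeze(-1)


t = torch.tensor([0.0, 1.0, 2.0, 3.0])      # shared times, shape (m,) = (4,)
params = torch.tensor([[2.0, 1.0, 0.5],     # (scale, offset, slope) for individual 0
                       [1.0, 2.0, 1.0]])    # (scale, offset, slope) for individual 1
y = sigmoid(t, params)                      # shape (n, m, d) = (2, 4, 1)

# Batched over 5 illustrative MCMC chains: (n_chains, n, l) -> (n_chains, n, m, 1)
chains = params.expand(5, -1, -1)
y_chains = sigmoid(t, chains)               # shape (5, 2, 4, 1)
```

Here every chain holds the same parameters, so each slice `y_chains[k]` matches `y`; in practice each chain would carry its own sampled values.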
- __init__(*args, **kwargs)