jmstate.types.LinkFn
- class LinkFn(*args, **kwargs)
Protocol defining a link function for multistate models.
A link function maps evaluation times and individual-specific parameters to transformed outputs, such as transition-specific parameters. Requirements are identical to those of RegressionFn.
Tensor input conventions:
- t represents the measurement times. It can be either:
  - a 2D tensor of shape \((n, m)\) when each individual has its own time points, or
  - a 1D tensor of shape \((m,)\) when all individuals share the same measurement times.
- indiv_params represents the individual-specific parameters. It is expected to have shape \((n, l)\) or \((n_{\text{chains}}, n, l)\), where \(n\) is the number of individuals, \(m\) is the number of measurement times, and \(l\) is the number of parameters.
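To make the two layouts of t concrete, the sketch below (with arbitrary, hypothetical sizes) shows how a parameter column sliced to shape \((n, 1)\) broadcasts against either form of t, and against a leading chain dimension. It relies only on standard PyTorch broadcasting, not on any jmstate helper.

```python
import torch

# Hypothetical sizes: n individuals, m measurement times, l parameters, n_chains chains.
n, m, l, n_chains = 4, 5, 3, 2

t_shared = torch.linspace(0.0, 1.0, m)        # shape (m,): same times for everyone
t_indiv = torch.rand(n, m)                    # shape (n, m): individual-specific times
params = torch.randn(n, l)                    # shape (n, l)
params_chains = torch.randn(n_chains, n, l)   # shape (n_chains, n, l)

# Keeping a trailing singleton dimension on each parameter (shape (..., n, 1))
# lets it broadcast against either time layout.
offset = params[..., :1]                      # shape (n, 1)
offset_chains = params_chains[..., :1]        # shape (n_chains, n, 1)

print((t_shared - offset).shape)              # torch.Size([4, 5])
print((t_indiv - offset).shape)               # torch.Size([4, 5])
print((t_indiv - offset_chains).shape)        # torch.Size([2, 4, 5])
```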
Tensor output conventions:
- The last dimension corresponds to the response variable dimension \(d\).
- The second-to-last dimension corresponds to the repeated measurements \(m\).
- The third-to-last dimension corresponds to the individual index \(n\).
- An optional fourth-to-last dimension indexes MCMC chains, enabling parallel (batched) processing across chains.
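The output conventions can be checked mechanically. The function follows_output_convention below is a hypothetical helper written for illustration, not part of jmstate; it assumes the only valid layouts are \((n, m, d)\) and \((n_{\text{chains}}, n, m, d)\).

```python
import torch


def follows_output_convention(out: torch.Tensor, n: int, m: int, d: int) -> bool:
    """Hypothetical helper (not part of jmstate): True if `out` has shape
    (n, m, d) or (n_chains, n, m, d)."""
    return out.dim() in (3, 4) and tuple(out.shape[-3:]) == (n, m, d)


print(follows_output_convention(torch.zeros(4, 5, 2), n=4, m=5, d=2))     # True
print(follows_output_convention(torch.zeros(3, 4, 5, 2), n=4, m=5, d=2))  # True
print(follows_output_convention(torch.zeros(4, 5), n=4, m=5, d=2))        # False
```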
This protocol is conceptually identical to RegressionFn.
- Parameters:
t (torch.Tensor) – Evaluation times of shape \((n, m)\) or \((m,)\).
indiv_params (torch.Tensor) – Individual parameters of 2D shape \((n, l)\) or 3D shape \((n_{\text{chains}}, n, l)\).
- Returns:
Transformed outputs of shape \((n, m, d)\), or \((n_{\text{chains}}, n, m, d)\) for parallelized computations.
- Return type:
torch.Tensor
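Because LinkFn is declared as a protocol, any callable with a matching signature should qualify. The following sketch assumes the protocol constrains only a `__call__(t, indiv_params) -> torch.Tensor` method; LinkFnSketch and linear_link are illustrative names, not jmstate's actual source.

```python
from typing import Protocol

import torch


class LinkFnSketch(Protocol):
    """Illustrative stand-in for a link-function protocol; not jmstate's source."""

    def __call__(self, t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor: ...


def linear_link(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    # Hypothetical link: intercept + slope * t, with one output dimension (d = 1).
    intercept, slope = indiv_params.chunk(2, dim=-1)  # each of shape (..., n, 1)
    return (intercept + slope * t).unsqueeze(-1)      # shape (..., n, m, 1)


link: LinkFnSketch = linear_link  # a plain function matches the protocol structurally
out = link(torch.linspace(0.0, 1.0, 5), torch.randn(4, 2))
print(out.shape)  # torch.Size([4, 5, 1])
```

Typing the variable as the protocol lets a static type checker verify the signature, while at runtime the function is called like any other.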
Examples
>>> import torch
>>> def sigmoid(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
...     scale, offset, slope = indiv_params.chunk(3, dim=-1)
...     # Fully broadcast computation over chains, individuals, and times
...     return (scale * torch.sigmoid((t - offset) / slope)).unsqueeze(-1)
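As a quick sanity check of the sigmoid example, the script below calls it on each supported input layout and prints the resulting shapes. The sizes and random parameter values are arbitrary; parameters are shifted to \([0.5, 1.5)\) so that slope stays away from zero.

```python
import torch


def sigmoid(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    # Same function as in the doctest example.
    scale, offset, slope = indiv_params.chunk(3, dim=-1)
    return (scale * torch.sigmoid((t - offset) / slope)).unsqueeze(-1)


n, m, n_chains = 4, 6, 2
params = torch.rand(n, 3) + 0.5                   # values in [0.5, 1.5)
params_chains = torch.rand(n_chains, n, 3) + 0.5  # values in [0.5, 1.5)

print(sigmoid(torch.linspace(0.0, 10.0, m), params).shape)  # torch.Size([4, 6, 1])
print(sigmoid(torch.rand(n, m), params).shape)              # torch.Size([4, 6, 1])
print(sigmoid(torch.rand(n, m), params_chains).shape)       # torch.Size([2, 4, 6, 1])
```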
- __init__(*args, **kwargs)