jmstate.types.LinkFn

class LinkFn(*args, **kwargs)

Protocol defining a link function for multistate models.

A link function maps evaluation times and individual-specific parameters to transformed outputs, such as transition-specific parameters. Requirements are identical to those of RegressionFn.
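Because LinkFn is a protocol, any callable with a compatible signature satisfies it structurally; no subclassing is needed. The following is a minimal sketch of a protocol with the documented call signature, written with `typing.Protocol`. The names `LinkFnSketch` and `linear_link` are hypothetical, and the actual definition in `jmstate.types` may differ in details (for example, whether it is runtime-checkable).

```python
from typing import Protocol, runtime_checkable

import torch


# Hypothetical stand-in for jmstate.types.LinkFn; the real definition may
# differ, but the call signature follows the documented parameters.
@runtime_checkable
class LinkFnSketch(Protocol):
    def __call__(self, t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
        ...


def linear_link(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    # Hypothetical parameter layout: column 0 is an intercept, column 1 a slope.
    intercept, slope = indiv_params[..., :1], indiv_params[..., 1:2]
    # (..., n, 1) combined with t of shape (n, m) or (m,) broadcasts to (..., n, m);
    # the trailing unsqueeze adds the response dimension d = 1.
    return (intercept + slope * t).unsqueeze(-1)


# A plain function satisfies the protocol structurally.
link: LinkFnSketch = linear_link
```

With a `(3, 2)` parameter tensor and five shared time points, `linear_link` returns a `(3, 5, 1)` tensor.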

Tensor input conventions:

  • t represents the measurement times. It can be either:
    • a 2D tensor of shape \((n, m)\) when each individual has individual-specific time points,

    • a 1D tensor of shape \((m,)\) when all individuals share the same measurement times.

  • indiv_params represents the individual-specific parameters. It is expected to have shape \((n, l)\) or \((n_{\text{chains}}, n, l)\), where \(n\) is the number of individuals and \(l\) is the number of parameters.
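The input conventions can be checked with a small sketch built on a hypothetical `exp_decay` link (amplitude times exponential decay, not part of jmstate). Shared times of shape \((m,)\) broadcast against the \((n, 1)\) parameter columns and give the same result as per-individual times of shape \((n, m)\). A leading chain axis on indiv_params carries through to the output.

```python
import torch


def exp_decay(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    # Hypothetical link: amplitude * exp(-rate * t), one output dimension (d = 1).
    amplitude, rate = indiv_params.chunk(2, dim=-1)  # each (..., n, 1)
    return (amplitude * torch.exp(-rate * t)).unsqueeze(-1)


n, m, l, n_chains = 4, 6, 2, 3
t_shared = torch.linspace(0.0, 5.0, m)      # (m,): same times for every individual
t_indiv = t_shared.expand(n, m)             # (n, m): per-individual time grid
params_2d = torch.rand(n, l)                # (n, l)
params_3d = torch.rand(n_chains, n, l)      # (n_chains, n, l)

# Shared times broadcast against the (n, 1) parameter columns...
out_shared = exp_decay(t_shared, params_2d)   # (n, m, 1)
# ...and match the result of repeating the times for each individual.
out_indiv = exp_decay(t_indiv, params_2d)     # (n, m, 1)
# A leading chain dimension on indiv_params propagates to the output.
out_chains = exp_decay(t_indiv, params_3d)    # (n_chains, n, m, 1)
```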

Tensor output conventions:

  • Last dimension corresponds to the response variable dimension \(d\).

  • Second-last dimension corresponds to repeated measurements \(m\).

  • Third-last dimension corresponds to individual index \(n\).

  • An optional fourth-last dimension \(n_{\text{chains}}\) indexes MCMC chains, enabling batched evaluation across chains.
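The output layout can be validated with a small checker. The sketch below uses a hypothetical helper `check_output_shape` and a hypothetical `two_output_link` with \(d = 2\) responses. Neither is part of jmstate; they only illustrate that the trailing three axes are always \((n, m, d)\), with an optional leading chain axis.

```python
import torch


def check_output_shape(out: torch.Tensor, n: int, m: int, d: int) -> None:
    """Hypothetical helper that validates the documented output layout.

    Accepts (n, m, d) or, with a leading MCMC-chain axis, (n_chains, n, m, d).
    """
    if out.ndim not in (3, 4):
        raise ValueError(f"expected 3 or 4 dimensions, got {out.ndim}")
    # The trailing axes are always (individual, measurement, response).
    if tuple(out.shape[-3:]) != (n, m, d):
        raise ValueError(f"trailing shape {tuple(out.shape[-3:])} != {(n, m, d)}")


def two_output_link(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    # Hypothetical link with d = 2 responses: a constant level and a linear trend.
    level, slope = indiv_params.chunk(2, dim=-1)          # (..., n, 1) each
    # broadcast_tensors expands level to the same (..., n, m) shape as slope * t.
    level_b, trend = torch.broadcast_tensors(level, slope * t)
    return torch.stack([level_b, trend], dim=-1)          # (..., n, m, 2)


t = torch.linspace(0.0, 1.0, 6)                           # (m,) with m = 6
check_output_shape(two_output_link(t, torch.rand(4, 2)), n=4, m=6, d=2)
check_output_shape(two_output_link(t, torch.rand(3, 4, 2)), n=4, m=6, d=2)
```

Stacking along the last axis keeps the response dimension last, as the conventions require. A checker like this also catches the common mistake of forgetting the final `unsqueeze(-1)` when \(d = 1\).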

This protocol is conceptually identical to RegressionFn.

Parameters:
  • t (torch.Tensor) – Evaluation times of shape \((n, m)\) or \((m,)\).

  • indiv_params (torch.Tensor) – Individual parameters, either 2D of shape \((n, l)\) or 3D of shape \((n_{\text{chains}}, n, l)\).

Returns:

Transformed outputs of shape \((n, m, d)\), or \((n_{\text{chains}}, n, m, d)\) when parallelized across MCMC chains.

Return type:

torch.Tensor

Examples

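>>> import torch
>>> def sigmoid(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
...     # Columns of indiv_params: scale, offset, slope
...     scale, offset, slope = indiv_params.chunk(3, dim=-1)
...     # Fully broadcast: (..., n, 1) against t of shape (n, m) or (m,)
...     return (scale * torch.sigmoid((t - offset) / slope)).unsqueeze(-1)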
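To see the conventions in action, the sigmoid example can be called with shared times and with a stack of per-chain parameters. The parameter values and the Gaussian jitter used to mimic MCMC samples are made up for illustration.

```python
import torch


def sigmoid(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    # Same function as in the example: columns are (scale, offset, slope).
    scale, offset, slope = indiv_params.chunk(3, dim=-1)
    return (scale * torch.sigmoid((t - offset) / slope)).unsqueeze(-1)


torch.manual_seed(0)
n, m, n_chains = 5, 11, 2
t = torch.linspace(0.0, 10.0, m)                      # shared times, shape (m,)
# Illustrative values: scale = 2, offset = 5, slope = 1 for every individual.
params = torch.tensor([2.0, 5.0, 1.0]).repeat(n, 1)   # (n, 3)
# Small per-chain jitter stands in for MCMC samples of the parameters.
chain_params = params + 0.1 * torch.randn(n_chains, n, 3)

y = sigmoid(t, params)               # (n, m, 1)
y_chains = sigmoid(t, chain_params)  # (n_chains, n, m, 1)
```

At \(t = \text{offset} = 5\) the logistic term equals 1/2, so every individual's output there is scale / 2 = 1.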