jmstate.types.ModelData
- class ModelData(x, t, y, trajectories, c)
Dataclass containing learnable multistate joint model data.
It can be used with scikit-learn’s train-test splits or cross-validation utilities.
- Covariates:
x is a matrix of covariates of shape \((n, p)\), where \(n\) denotes the number of individuals and \(p\) the number of covariates.
- Measurement Times:
- t represents the measurement times. It can be either:
a 2D tensor of shape \((n, m)\) when each individual has its own measurement times, or
a 1D tensor of shape \((m,)\) when all individuals share the same measurement times.
Padding with NaNs is optional. However, t must not contain NaN values at positions where y is observed (i.e., where y is not NaN).
- Observations:
y is expected to be a 3D tensor of shape \((n, m, d)\), where \(n\) is the number of individuals, \(m\) is the maximum number of measurements per individual, and \(d\) is the dimension of the observation space \(\mathbb{R}^d\). Padding is performed with NaNs.
- Trajectories:
trajectories contains the individual-level multistate trajectories. It is a list[list[tuple[float, Any]]] in which each inner list is one individual's trajectory and each tuple is a (time, state) pair.
- Censoring Times:
c contains the right-censoring times and is expected to be a column vector of shape \((n, 1)\). Each value must be greater than or equal to the corresponding individual's maximum transition time.
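As an illustration of the shapes described above, the following is a minimal sketch that builds the five inputs with PyTorch for three individuals. The sizes, the numeric values, and the state labels ("healthy", "ill", "dead") are hypothetical and chosen only for the example; starting each trajectory at time 0 is likewise an assumption, not a documented requirement.

```python
import torch

# Hypothetical sizes: individuals, covariates, max measurements, observation dim
n, p, m, d = 3, 2, 4, 1

# Covariates, shape (n, p)
x = torch.randn(n, p)

# Measurement grid shared by all individuals, shape (m,)
t = torch.tensor([0.0, 1.0, 2.0, 3.0])

# Observations, shape (n, m, d); NaN marks padded (unobserved) positions
y = torch.full((n, m, d), float("nan"))
y[0, :4, 0] = torch.tensor([1.0, 1.2, 1.1, 0.9])
y[1, :2, 0] = torch.tensor([0.5, 0.7])
y[2, :3, 0] = torch.tensor([2.0, 1.8, 1.7])

# One trajectory per individual: (time, state) pairs sorted by time
trajectories = [
    [(0.0, "healthy"), (2.5, "ill")],
    [(0.0, "healthy")],
    [(0.0, "healthy"), (1.0, "ill"), (2.0, "dead")],
]

# Right-censoring times, column vector of shape (n, 1);
# each value is at least the last transition time of the matching trajectory
c = torch.tensor([[3.0], [1.5], [2.0]])
```

Assuming jmstate is installed, these objects could then be passed as ModelData(x, t, y, trajectories, c), matching the constructor signature shown on this page.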
The data can be completed using the prepare method, which formats it for manual likelihood evaluation and MCMC procedures. This usage is intended for advanced users familiar with the codebase. The method is called automatically by the fit and predict routines and does not require explicit user intervention in standard workflows.
- Raises:
ValueError – If some trajectory is empty.
ValueError – If some trajectory is not sorted.
ValueError – If some trajectory is not compatible with the censoring times.
ValueError – If any input other than y contains inf or NaN values.
ValueError – If the size is not consistent between inputs.
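To make the trajectory-related conditions concrete, here is a rough pure-Python sketch of comparable checks. The helper name check_trajectories and the Trajectory alias are hypothetical and this is not the library's implementation. Two assumptions are made: "sorted" is read as non-decreasing in time, and c is given as a plain list of floats to keep the example self-contained.

```python
from typing import Any

# Hypothetical alias for one individual's list of (time, state) pairs
Trajectory = list[tuple[float, Any]]


def check_trajectories(trajectories: list[Trajectory], c: list[float]) -> None:
    """Illustrative checks mirroring the documented error conditions."""
    # Sizes must agree: one censoring time per trajectory
    if len(trajectories) != len(c):
        raise ValueError("trajectories and c must have the same length")
    for i, (traj, c_i) in enumerate(zip(trajectories, c)):
        # Every individual needs at least one (time, state) pair
        if not traj:
            raise ValueError(f"trajectory {i} is empty")
        times = [time for time, _ in traj]
        # Assumption: sorted means non-decreasing times
        if any(a > b for a, b in zip(times, times[1:])):
            raise ValueError(f"trajectory {i} is not sorted by time")
        # The censoring time cannot precede the last transition
        if times[-1] > c_i:
            raise ValueError(f"trajectory {i} exceeds its censoring time")
```

Running such a check before building the dataclass can give a more specific error message than a generic constructor failure.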
- Variables:
x (torch.Tensor) – Fixed covariate matrix of shape (n, p), where n is the number of individuals and p the number of covariates.
t (torch.Tensor) – Measurement times. Either a 1D tensor of shape (m,) when times are shared across individuals, or a 2D tensor of shape (n, m) when individuals have distinct time grids. Padding with NaNs may be used when required.
y (torch.Tensor) – Longitudinal measurements of shape (n, m, d), where n is the number of individuals, m the maximum number of measurements per individual, and d the observation dimension. Padding is performed with NaNs when necessary.
trajectories (list[Trajectory]) – List of individual trajectories. Each Trajectory consists of a sequence of (time, state) tuples.
c (torch.Tensor) – Censoring times provided as a column vector. Each value must be greater than or equal to the corresponding maximum trajectory time.
valid_mask (torch.Tensor) – Boolean mask indicating valid (non-padded) measurements.
n_valid (torch.Tensor) – Number of valid measurements per individual.
valid_t (torch.Tensor) – Filtered tensor containing only valid measurement times.
valid_y (torch.Tensor) – Filtered tensor containing only valid measurements.
buckets (dict[tuple[Any, Any], tuple[torch.Tensor, ...]]) – Grouped trajectory data structures used for likelihood computation.
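The derived attributes valid_mask, n_valid, valid_t and valid_y can be pictured with the sketch below. It only shows one plausible way to obtain such quantities from NaN-padded inputs. The rule that a measurement is valid when at least one of its coordinates is observed, and the flattened layout of the filtered tensors, are both assumptions; the library's actual computation may differ.

```python
import torch

nan = float("nan")

# Observations with NaN padding, shape (n=2, m=3, d=1)
y = torch.tensor([
    [[1.0], [1.2], [nan]],
    [[0.5], [nan], [nan]],
])
# Shared measurement grid, shape (m,)
t = torch.tensor([0.0, 1.0, 2.0])

# Assumption: a measurement is valid if not every coordinate is NaN
valid_mask = ~torch.isnan(y).all(dim=-1)  # boolean, shape (n, m)

# Number of valid measurements per individual, shape (n,)
n_valid = valid_mask.sum(dim=1)

# Broadcast the shared grid to (n, m) so it can be filtered like y
t_full = t.expand(y.shape[0], -1)

# Boolean indexing keeps valid entries, concatenated across individuals
valid_t = t_full[valid_mask]  # shape (total_valid,)
valid_y = y[valid_mask]       # shape (total_valid, d)
```

With a 2D t of shape (n, m), the expand step would be unnecessary and the mask could be applied to t directly.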
- Parameters:
x (Tensor)
t (Tensor)
y (Tensor)
trajectories (list[list[tuple[float, str]]])
c (Tensor)
- __init__(x, t, y, trajectories, c)
- Parameters:
x (Tensor)
t (Tensor)
y (Tensor)
trajectories (list[list[tuple[float, str]]])
c (Tensor)
- Return type:
None