Model

Multistate joint modeling package.

class MultiStateJointModel(design, params, optimizer=None, *, n_quad=32, n_bisect=32, n_chains=5, init_step_size=0.1, adapt_rate=0.1, target_accept_rate=0.44, n_warmup=100, n_subsample=10, max_iter=1000, tol=0.1, window_size=100, verbose=True)

Nonlinear multistate joint model for longitudinal and survival data.

This class generalizes both linear and standard joint models to accommodate multiple states under a semi-Markov assumption. The model is fully specified by a ModelDesign object, which defines longitudinal and hazard functions, and a ModelParameters object, which contains the associated (modifiable) parameters. Parameters may be shared across functions as specified in the model design. The ModelDesign is fixed at initialization, while ModelParameters are updated in place during fitting.

Model fitting is performed using a stochastic gradient ascent algorithm, and parameter sampling is handled by a Metropolis-Within-Gibbs MCMC procedure. The default settings are based on commonly accepted values, and the step sizes are adapted component-wise dynamically based on the acceptance rate.
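The component-wise step-size adaptation mentioned above can be illustrated with a small, self-contained sampler. This is a sketch in plain Python, not the package's sampler: the function name adaptive_mwg and the multiplicative update rule are assumptions made for illustration, and only the argument names init_step_size, adapt_rate, and target_accept_rate mirror the constructor.

```python
import math
import random


def adaptive_mwg(log_density, x0, n_iter, init_step_size=0.1,
                 adapt_rate=0.1, target_accept_rate=0.44, seed=0):
    """Toy Metropolis-within-Gibbs sampler with per-component step adaptation."""
    rng = random.Random(seed)
    x = list(x0)
    steps = [init_step_size] * len(x)  # one proposal std per component
    current = log_density(x)
    samples = []
    for _ in range(n_iter):
        for j in range(len(x)):
            # Propose a Gaussian move on component j only.
            proposal = list(x)
            proposal[j] += steps[j] * rng.gauss(0.0, 1.0)
            candidate = log_density(proposal)
            accept_prob = math.exp(min(0.0, candidate - current))
            accepted = rng.random() < accept_prob
            if accepted:
                x, current = proposal, candidate
            # Assumed rule: grow the step after an acceptance, shrink it after
            # a rejection, so the acceptance rate drifts toward the target.
            steps[j] *= math.exp(adapt_rate * (float(accepted) - target_accept_rate))
        samples.append(list(x))
    return samples, steps


def log_std_normal(x):
    """Log-density of independent standard normals, up to a constant."""
    return -0.5 * sum(v * v for v in x)


samples, steps = adaptive_mwg(log_std_normal, [0.0, 0.0], n_iter=5000)
```

Starting from a deliberately small step of 0.1, the per-component steps grow until roughly 44% of proposals are accepted, which is the usual target for one-dimensional updates.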

Dynamic prediction is supported via the prediction methods in PredictMixin, which allow both single and double Monte Carlo integration.

Design settings:
  • design: The model specification defining the individual parameters, regression, and link functions.

  • params: Initial values for the model parameters; modifiable during fitting.

Numerical integration settings:
  • n_quad: Number of nodes for Gauss-Legendre quadrature in hazard integration.

  • n_bisect: Number of bisection steps for transition time sampling.
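To see why a fixed number of bisection steps is a natural setting, here is one way transition times can be drawn by inverting a cumulative hazard with bisection. This is only a sketch: the helper sample_transition_time, its arguments, and the Weibull-type cumulative hazard are hypothetical and are not taken from the package.

```python
import math


def sample_transition_time(cum_hazard, t0, t_max, uniform, n_bisect=32):
    """Solve cum_hazard(t) - cum_hazard(t0) = -log(uniform) for t in [t0, t_max].

    `uniform` is a draw from (0, 1]; `cum_hazard` must be non-decreasing.
    """
    target = -math.log(uniform)
    if cum_hazard(t_max) - cum_hazard(t0) < target:
        return math.inf  # no transition occurs before t_max
    lo, hi = t0, t_max
    for _ in range(n_bisect):
        mid = 0.5 * (lo + hi)
        if cum_hazard(mid) - cum_hazard(t0) < target:
            lo = mid  # the root lies to the right of mid
        else:
            hi = mid  # the root lies at or to the left of mid
    return 0.5 * (lo + hi)


def weibull_cum_hazard(t):
    """Illustrative cumulative hazard H(t) = (t / 2) ** 1.5."""
    return (t / 2.0) ** 1.5


t_event = sample_transition_time(weibull_cum_hazard, 0.0, 100.0, uniform=0.3)
t_none = sample_transition_time(weibull_cum_hazard, 0.0, 0.1, uniform=0.3)
```

Each step halves the bracketing interval, so 32 steps shrink an interval of length 100 to about 2e-8, which is typically far below the resolution of the data.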

MCMC settings:
  • sampler: MCMC sampler instance used during model fitting.

  • n_chains: Number of parallel MCMC chains.

  • init_step_size: Initial kernel standard deviation in Metropolis-Within-Gibbs.

  • adapt_rate: Adaptation rate for the step size.

  • target_accept_rate: Target mean acceptance probability.

  • n_warmup: Number of warmup iterations per chain.

  • n_subsample: Number of subsamples between predictions; higher values reduce autocorrelation but increase computation time. A value of one means no subsampling. Results can be very sensitive to this value.

Fitting settings:
  • optimizer: Optimizer for stochastic gradient ascent. If None, fitting is disabled. Recommended: torch.optim.Adam with learning rate 0.01 to 0.5.

  • max_iter: Maximum iterations for gradient ascent.

  • n_samples_summary: Number of samples used to compute the Fisher Information Matrix and model selection criteria; higher values improve accuracy.

  • tol: Tolerance for the \(R^2\) convergence criterion.

  • window_size: Window size for \(R^2\) convergence; default 100. This criterion is scale-agnostic and provides a local stationarity test.

Printing and visualization:
  • verbose: Whether to print progress during fitting and prediction.

  • After fitting, summary and plot_params_history (from jmstate.utils) can be used to display p-values, log-likelihood, AIC, BIC, and the evolution of parameters over iterations.

Variables:
  • design (ModelDesign) – The model specification defining the individual parameters, regression, and link functions.

  • params (ModelParameters) – The modifiable model parameters.

  • optimizer (torch.optim.Optimizer | None) – Optimizer used for fitting.

  • n_quad (int) – Number of Gauss-Legendre quadrature nodes for hazard integration.

  • n_bisect (int) – Number of bisection steps for transition time sampling.

  • sampler (MetropolisWithinGibbsSampler | None) – MCMC sampler instance used during model fitting.

  • n_chains (int) – Number of parallel MCMC chains.

  • init_step_size (float) – Initial kernel standard deviation in Metropolis-Within-Gibbs.

  • adapt_rate (float) – Adaptation rate for the MCMC step size.

  • target_accept_rate (float) – Target acceptance probability.

  • n_warmup (int) – Number of warmup iterations per MCMC chain.

  • n_subsample (int) – Number of subsamples for MCMC iterations.

  • max_iter_fit (int) – Maximum number of iterations for stochastic gradient ascent.

  • tol (float) – Tolerance for \(R^2\) convergence criterion.

  • window_size (int) – Window size for \(R^2\) convergence evaluation.

  • n_samples_summary (int) – Number of posterior samples for computing Fisher Information and selection criteria.

  • verbose (bool) – Flag to print fitting and prediction progress.

  • params_history (list[torch.Tensor]) – History of parameter values as flattened tensors.

  • fim (torch.Tensor | None) – Fisher Information Matrix.

  • loglik (float | None) – Log-likelihood of the fitted model.

  • aic (float | None) – Akaike Information Criterion.

  • bic (float | None) – Bayesian Information Criterion.

Parameters:
  • design (ModelDesign)

  • params (ModelParameters)

  • optimizer (Optimizer | None)

  • n_quad (int)

  • n_bisect (int)

  • n_chains (int)

  • init_step_size (float)

  • adapt_rate (float)

  • target_accept_rate (float)

  • n_warmup (int)

  • n_subsample (int)

  • max_iter (int)

  • tol (float)

  • window_size (int)

  • verbose (bool | int)

Examples

>>> import torch
>>> from jmstate import MultiStateJointModel
>>> # design, params, and data are assumed to be defined beforehand
>>> optimizer = torch.optim.Adam(params.parameters(), lr=0.1)
>>> model = MultiStateJointModel(design, params, optimizer)
>>> model.fit(data)
>>> model.summary()

__init__(design, params, optimizer=None, *, n_quad=32, n_bisect=32, n_chains=5, init_step_size=0.1, adapt_rate=0.1, target_accept_rate=0.44, n_warmup=100, n_subsample=10, max_iter=1000, tol=0.1, window_size=100, verbose=True)

Initialize the multistate joint model with specified design and parameters.

Constructs a joint model based on a user-defined ModelDesign and initial ModelParameters. Provides default settings for numerical integration, MCMC sampling, stochastic gradient fitting, and printing options.

Parameters:
  • design (ModelDesign) – The model specification defining the individual parameters, regression, and link functions.

  • params (ModelParameters) – Initial values for the model parameters; modifiable during fitting.

  • optimizer (torch.optim.Optimizer | None, optional) – Optimizer used for fitting. If None, fitting is disabled. Defaults to None.

  • n_quad (int, optional) – Number of nodes for Gauss-Legendre quadrature in hazard integration. Defaults to 32.

  • n_bisect (int, optional) – Number of bisection steps for transition time sampling. Defaults to 32.

  • n_chains (int, optional) – Number of parallel MCMC chains. Defaults to 5.

  • init_step_size (float, optional) – Initial step size for the MCMC sampler. Defaults to 0.1.

  • adapt_rate (float, optional) – Adaptation rate for the MCMC step size. Defaults to 0.1.

  • target_accept_rate (float, optional) – Target mean acceptance probability for Metropolis-Within-Gibbs. Defaults to 0.44.

  • n_warmup (int, optional) – Number of warmup iterations per MCMC chain. Defaults to 100.

  • n_subsample (int, optional) – Number of subsamples between MCMC updates. Defaults to 10.

  • max_iter (int, optional) – Maximum number of iterations for stochastic gradient ascent. Defaults to 1000.

  • tol (float, optional) – Tolerance for \(R^2\) convergence criterion. Defaults to 0.1.

  • window_size (int, optional) – Window size for \(R^2\) convergence evaluation. Defaults to 100.

  • verbose (bool | int, optional) – Flag to print progress during fitting and prediction. Defaults to True.

compute_surv_logps(sample_data, u)

Compute log survival probabilities at specified times.

Evaluates the log-probability of remaining event-free up to the prediction times \(u\) conditional on individual-level parameters and censoring times. The computation uses the hazard function \(\lambda(t)\) to obtain:

\[\log \mathbb{P}(T^* \geq u \mid T^* > c) = -\int_c^u \lambda(t) \, dt,\]

where \(c\) denotes the censoring time for each individual. In cases with multiple possible transitions, \(\lambda(t)\) sums over all transition-specific hazards:

\[-\int_c^u \sum_{k'} \lambda^{k' \mid k}(t \mid t_0) \, dt,\]

exploiting the Chasles relation (additivity of the integral over adjacent intervals) to simplify computation and improve numerical precision.

The input u must be a matrix of shape \((n, m)\) where \(n\) is the number of individuals and \(m\) is the number of prediction time points.
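The integral above can be approximated with Gauss-Legendre quadrature, integrating interval by interval and accumulating the pieces. The following NumPy sketch illustrates the idea under stated assumptions: the function log_surv and the toy hazard are hypothetical, each row of u is assumed to be increasing with all entries at or after c, and the package's own implementation may be organized differently.

```python
import numpy as np


def log_surv(hazard, c, u, n_quad=32):
    """Return -int_c^u hazard(t) dt for c of shape (n, 1) and u of shape (n, m)."""
    nodes, weights = np.polynomial.legendre.leggauss(n_quad)  # rule on [-1, 1]
    # Consecutive intervals per row: [c, u_1], [u_1, u_2], ..., [u_{m-1}, u_m].
    lower = np.concatenate([c, u[:, :-1]], axis=1)
    half = 0.5 * (u - lower)
    mid = 0.5 * (u + lower)
    # Map the reference nodes onto every interval: shape (n, m, n_quad).
    t = mid[..., None] + half[..., None] * nodes
    pieces = half * (hazard(t) * weights).sum(axis=-1)
    # Additivity of the integral: the value up to u_j is the running sum.
    return -np.cumsum(pieces, axis=1)


def toy_hazard(t):
    """Illustrative hazard lambda(t) = 2 t, so the integral is u**2 - c**2."""
    return 2.0 * t


c = np.array([[0.0], [1.0]])
u = np.array([[1.0, 2.0, 3.0], [1.5, 2.5, 4.0]])
logps = log_surv(toy_hazard, c, u)
```

Integrating over short adjacent intervals rather than over the full range from c to each u keeps every quadrature rule on a narrow interval, which is where a fixed number of nodes is most accurate.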

Parameters:
  • sample_data (SampleData) – The dataset containing covariates, individual-level parameters, trajectories, and censoring information.

  • u (torch.Tensor) – Matrix of evaluation times of shape (n, m).

Raises:
  • ValueError – If u contains NaN or infinite values.

  • ValueError – If u has a shape inconsistent with the number of individuals.

Returns:

Computed survival log-probabilities of shape (n, m), with rows corresponding to individuals and columns to prediction times.

Return type:

torch.Tensor

fit(data)

Fit the model to observed data using maximum likelihood estimation.

Computes the Maximum Likelihood Estimate (MLE) \(\hat{\theta}\) of the model parameters. Optimization is performed using the configured optimizer for up to max_iter iterations. Convergence is assessed via a linearity-based stationarity test on the last window_size iterates: the \(R^2\) statistic measures whether the trajectory of each parameter component is better explained by a linear trend than by a constant. Convergence is declared when all \(R^2\) values are below tol, indicating negligible linear drift.
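The stationarity test described above can be sketched in a few lines of plain Python. The helpers r_squared and has_converged are hypothetical names, and the history is represented here as a list of flat parameter lists; the package's exact computation may differ.

```python
def r_squared(values):
    """R^2 of a least-squares line fitted to values against their index."""
    n = len(values)
    t_mean = (n - 1) / 2
    y_mean = sum(values) / n
    s_ty = sum((t - t_mean) * (y - y_mean) for t, y in enumerate(values))
    s_tt = sum((t - t_mean) ** 2 for t in range(n))
    s_yy = sum((y - y_mean) ** 2 for y in values)
    if s_yy == 0.0:
        return 0.0  # a constant trajectory shows no drift at all
    return s_ty ** 2 / (s_tt * s_yy)


def has_converged(history, window_size=100, tol=0.1):
    """True when no parameter component shows a linear trend in the last window."""
    if len(history) < window_size:
        return False
    window = history[-window_size:]
    n_params = len(window[0])
    return all(
        r_squared([row[j] for row in window]) < tol for j in range(n_params)
    )


# A steadily drifting parameter versus one oscillating around a fixed value.
drifting = [[0.01 * t] for t in range(200)]
stationary = [[1.0 + 0.01 * (-1) ** t] for t in range(200)]
```

Because R^2 is invariant to rescaling of the parameter, the same tolerance can be applied to components of very different magnitudes.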

The fitting procedure leverages the Fisher identity coupled with a stochastic gradient algorithm and a Metropolis-Hastings MCMC sampler. The Fisher identity states:

\[\nabla_\theta \log \mathcal{L}(\theta ; x) = \mathbb{E}_{b \sim p(\cdot \mid x, \theta)} \left( \nabla_\theta \log \mathcal{L}(\theta ; x, b) \right).\]
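
The identity can be checked numerically on a toy model where everything is available in closed form. In the sketch below, which is purely illustrative and unrelated to the package's internals, b ~ N(0, 1) and x | b ~ N(theta + b, 1), so that x ~ N(theta, 2) marginally and b | x ~ N((x - theta) / 2, 1/2). Exact posterior draws stand in for the MCMC sampler.

```python
import random


def fisher_identity_gradient(x, theta, n_draws=200_000, seed=0):
    """Monte Carlo estimate of the marginal score via the Fisher identity."""
    rng = random.Random(seed)
    post_mean = 0.5 * (x - theta)  # mean of b | x, theta
    post_std = 0.5 ** 0.5          # std of b | x, theta
    total = 0.0
    for _ in range(n_draws):
        b = rng.gauss(post_mean, post_std)
        # Gradient in theta of the complete-data log-likelihood log L(theta; x, b).
        total += x - theta - b
    return total / n_draws


x, theta = 1.7, 0.4
mc_gradient = fisher_identity_gradient(x, theta)
exact_gradient = 0.5 * (x - theta)  # gradient in theta of log N(x; theta, 2)
```

In a stochastic gradient scheme, a handful of posterior draws per iteration would replace the large sample used here, giving a noisy but unbiased estimate of the same gradient.
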
Parameters:

data (ModelData) – Dataset containing covariates, longitudinal measurements, trajectories, and censoring times used for fitting.

Raises:

ValueError – If the optimizer has not been initialized prior to fitting.

Returns:

The fitted model instance with estimated parameters.

Return type:

Self

predict_surv_logps(data, u, *, n_samples=1000, double_monte_carlo=False)

Predict survival log-probabilities at specified times.

Computes the predicted log survival probabilities for each individual at the specified prediction times \(u\) by evaluating the survival function conditional on posterior draws of the random effects \(b\). The computation may optionally use a double Monte Carlo procedure to incorporate parameter uncertainty following Rizopoulos (2011).

The predicted quantity is given by:

\[\log \mathbb{P}(T^* \geq u \mid T^* > c) = -\int_c^u \lambda(t) \, dt,\]

where \(c\) denotes the individual censoring time. In the presence of multiple transitions, \(\lambda(t)\) is the sum over all possible transition-specific hazards:

\[-\int_c^u \sum_{k'} \lambda^{k' \mid k}(t \mid t_0) \, dt,\]

using the Chasles relation (additivity of the integral over adjacent intervals) to simplify computation and improve numerical precision.

The input u must be a matrix of shape \((n, m)\) where \(n\) is the number of individuals and \(m\) is the number of prediction time points.

Parameters:
  • data (ModelData) – The dataset containing covariates, observed outcomes, trajectories, and censoring information.

  • u (torch.Tensor) – Matrix of prediction times of shape (n, m).

  • n_samples (int, optional) – Number of posterior samples to draw. Defaults to 1000.

  • double_monte_carlo (bool, optional) – If True, predictions incorporate a double Monte Carlo procedure to sample parameters. Defaults to False.

Raises:
  • ValueError – If double_monte_carlo is True and the model has not been fitted.

  • ValueError – If u contains NaN or infinite values.

  • ValueError – If u has a shape inconsistent with the number of individuals.

Returns:

Predicted survival log-probabilities of shape (n_samples, n, m), stacked along the first dimension.

Return type:

torch.Tensor
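
A common next step is to average the survival probabilities over the posterior draws. The NumPy sketch below shows one numerically stable way to do this on the log scale for an array shaped like the documented output; the helper mean_log_survival and the simulated input are hypothetical, and finite log-probabilities are assumed.

```python
import numpy as np


def mean_log_survival(logps):
    """Log of the draw-averaged survival probability, computed stably.

    `logps` has shape (n_samples, n, m); the result has shape (n, m).
    """
    peak = logps.max(axis=0)  # largest log-probability per (individual, time)
    return peak + np.log(np.exp(logps - peak).mean(axis=0))


# Simulated stand-in for predicted log-probabilities (all values are negative).
rng = np.random.default_rng(0)
logps = -rng.uniform(0.1, 3.0, size=(1000, 4, 5))
log_mean = mean_log_survival(logps)
```

Subtracting the per-entry maximum before exponentiating avoids underflow when survival probabilities are very small.
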

predict_trajectories(data, c, *, max_length=10, n_samples=1000, double_monte_carlo=False)

Predict individual-level trajectories up to specified censoring times.

Simulates the evolution of individual trajectories conditional on posterior draws of the random effects \(b\), up to the censoring times c. Trajectories are truncated to a maximum length of max_length to avoid infinite loops. The simulation algorithm is a variant of Gillespie’s method adapted for individual parameters. If double_monte_carlo is True, then the prediction is computed using the double Monte Carlo procedure described in Rizopoulos (2011).

The input c must be a column vector of shape \((n, 1)\) where \(n\) is the number of individuals.
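The double Monte Carlo idea is commonly described as a two-level draw: first sample the parameters from an asymptotic normal approximation around the estimate, then sample the random effects given those parameters. The NumPy sketch below illustrates that structure only. The function double_monte_carlo, the callables sample_b and predict, and the use of the inverse Fisher Information Matrix as covariance are assumptions for illustration, not the package's implementation.

```python
import numpy as np


def double_monte_carlo(theta_hat, fim, sample_b, predict, n_samples=1000, seed=0):
    """Propagate parameter and random-effect uncertainty into predictions."""
    rng = np.random.default_rng(seed)
    cov = np.linalg.inv(fim)  # assumed asymptotic covariance of the estimate
    # Outer level: one parameter draw per prediction sample.
    thetas = rng.multivariate_normal(theta_hat, cov, size=n_samples)
    # Inner level: draw random effects given each parameter draw, then predict.
    draws = [predict(theta, sample_b(theta, rng)) for theta in thetas]
    return np.stack(draws)


theta_hat = np.array([1.0, -0.5])
fim = np.diag([100.0, 100.0])
out = double_monte_carlo(
    theta_hat,
    fim,
    sample_b=lambda theta, rng: rng.normal(0.0, 1.0),  # toy random-effect draw
    predict=lambda theta, b: np.array([theta[0] + b, theta[1]]),  # toy output
    n_samples=200,
)
```

With single Monte Carlo, the outer level is skipped and every draw of the random effects uses the fitted parameter values.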

Parameters:
  • data (ModelData) – The dataset containing covariates, observed outcomes, trajectories, and censoring information.

  • c (torch.Tensor) – Column vector of censoring times for each individual.

  • max_length (int, optional) – Maximum length of generated trajectories. Defaults to 10.

  • n_samples (int, optional) – Number of posterior samples to draw. Defaults to 1000.

  • double_monte_carlo (bool, optional) – If True, predictions incorporate a double Monte Carlo procedure to sample parameters. Defaults to False.

Raises:
  • ValueError – If double_monte_carlo is True and the model has not been fitted.

  • ValueError – If c contains NaN or infinite values.

  • ValueError – If c has a shape inconsistent with the number of individuals.

Returns:

Predicted trajectories for each individual, organized as a list of lists, with the outer list indexing posterior draws and the inner list indexing individuals.

Return type:

list[list[Trajectory]]

predict_y(data, u, *, n_samples=1000, double_monte_carlo=False)

Predict longitudinal measurements at specified times.

Computes the predicted longitudinal responses for each individual at the specified prediction times \(u\) by evaluating the regression function conditional on posterior draws of the random effects \(b\). The prediction may optionally use a double Monte Carlo procedure to account for parameter uncertainty following Rizopoulos (2011).

The input u must be a matrix of shape \((n, m)\) where \(n\) is the number of individuals and \(m\) is the number of prediction time points.

Parameters:
  • data (ModelData) – The dataset containing covariates, observed outcomes, trajectories, and censoring information.

  • u (torch.Tensor) – Matrix of prediction times of shape (n, m).

  • n_samples (int, optional) – Number of posterior samples to draw. Defaults to 1000.

  • double_monte_carlo (bool, optional) – If True, predictions incorporate a double Monte Carlo procedure to sample parameters. Defaults to False.

Raises:
  • ValueError – If double_monte_carlo is True and the model has not been fitted.

  • ValueError – If u contains NaN or infinite values.

  • ValueError – If u has a shape inconsistent with the number of individuals.

Returns:

Predicted longitudinal outcomes of shape (n_samples, n, m), where predictions are stacked along the first dimension.

Return type:

torch.Tensor
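
Predictions stacked along the first dimension are typically reduced to a point estimate and an uncertainty band. The NumPy sketch below does this for an array shaped like the documented output; the helper summarize_draws and the simulated input are hypothetical.

```python
import numpy as np


def summarize_draws(y_draws, level=0.95):
    """Return the mean and an equal-tailed band over the first dimension."""
    alpha = 1.0 - level
    lower, upper = np.quantile(y_draws, [alpha / 2, 1.0 - alpha / 2], axis=0)
    return y_draws.mean(axis=0), lower, upper


# Simulated stand-in for predictions of shape (n_samples, n, m).
rng = np.random.default_rng(0)
y_draws = rng.normal(2.0, 1.0, size=(2000, 3, 4))
y_mean, y_lower, y_upper = summarize_draws(y_draws)
```

Each returned array has shape (n, m), so it can be plotted directly against the matrix of prediction times.
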

sample_trajectories(sample_data, c, *, max_length=10)

Simulate individual trajectories from the multistate joint model.

Generates sample trajectories for each individual up to the censoring times c, truncating to a maximum of max_length transitions to prevent infinite loops. The simulation employs a variant of Gillespie’s algorithm adapted for individual parameter draws. These sampled trajectories form the basis for posterior predictive checks or downstream predictions in the joint model framework.

The input c must be a column vector of shape \((n, 1)\) where \(n\) is the number of individuals.
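A Gillespie-type simulation of one trajectory can be sketched as follows. This is a deliberately simplified illustration with constant transition hazards (for which the clock-reset, semi-Markov setting coincides with a Markov chain); the function sample_trajectory, the rates dictionary, and the representation of a trajectory as a list of (time, state) pairs are assumptions, not the package's data structures.

```python
import random


def sample_trajectory(rates, start, c, max_length=10, seed=0):
    """Simulate (time, state) pairs until censoring time c or max_length entries."""
    rng = random.Random(seed)
    t, state = 0.0, start
    trajectory = [(t, state)]
    while len(trajectory) < max_length:
        exits = rates.get(state, {})
        if not exits:
            break  # absorbing state: no outgoing transitions
        # First-reaction step: draw a waiting time for every competing
        # transition and keep the earliest one.
        waits = {nxt: rng.expovariate(rate) for nxt, rate in exits.items()}
        nxt = min(waits, key=waits.get)
        t += waits[nxt]
        if t > c:
            break  # the next transition would occur after censoring
        state = nxt
        trajectory.append((t, state))
    return trajectory


# Illustrative illness-death model with constant hazards.
rates = {"healthy": {"ill": 0.5, "dead": 0.1}, "ill": {"dead": 0.3}}
trajectory = sample_trajectory(rates, "healthy", c=5.0)
```

The max_length cap matters when the state graph contains cycles, since a trajectory could otherwise bounce between states many times before reaching the censoring time.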

Parameters:
  • sample_data (SampleData) – The dataset containing covariates, trajectories, and individual-level parameters used for sampling.

  • c (torch.Tensor) – Column vector of censoring times for each individual.

  • max_length (int, optional) – Maximum number of iterations or transitions sampled per trajectory. Defaults to 10.

Raises:
  • ValueError – If c contains NaN or infinite values.

  • ValueError – If c has a shape inconsistent with the number of individuals.

Returns:

List of sampled trajectories, one per individual, with each trajectory truncated at the censoring time.

Return type:

list[Trajectory]

property stderr: Tensor

Computes the estimated standard errors of the model parameters.

The standard errors are derived from the diagonal of the inverse of the estimated Fisher Information Matrix evaluated at the Maximum Likelihood Estimate (MLE). They provide a measure of uncertainty for each parameter and can be used to construct confidence intervals.

\[\mathrm{stderr} = \sqrt{\operatorname{diag}\left( \hat{\mathcal{I}}_n (\hat{\theta})^{-1} \right)}\]

Raises:

ValueError – If the model has not been fitted and the Fisher Information Matrix is unavailable.

Returns:

Vector of standard errors corresponding to each parameter.

Return type:

torch.Tensor
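
The formula above, and the confidence intervals it leads to, can be reproduced on a small example. The NumPy sketch below uses a made-up Fisher Information Matrix and made-up estimates; the helper standard_errors is hypothetical and the 1.96 multiplier corresponds to a 95% Wald interval.

```python
import numpy as np


def standard_errors(fim):
    """Square root of the diagonal of the inverse Fisher Information Matrix."""
    return np.sqrt(np.diag(np.linalg.inv(fim)))


fim = np.array([[4.0, 0.0], [0.0, 25.0]])  # illustrative values only
theta_hat = np.array([1.2, -0.3])          # illustrative estimates only
se = standard_errors(fim)
lower = theta_hat - 1.96 * se
upper = theta_hat + 1.96 * se
```

Note that the diagonal is taken after inversion, so correlations between parameters in the information matrix widen the resulting standard errors.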