ODE-VAE model

We integrate the ODE-based dynamic model into the latent space of a VAE for flexible non-linear dimension reduction, reflecting the assumption that a lower-dimensional underlying dynamic process drives the observed measurements. The following functions are used to define, construct and train an ODE-VAE model.

To jointly optimize all components, i.e., the dynamic model, the VAE for dimension reduction, and the ODE-net for obtaining person-specific ODE parameters, we implement a joint loss function that incorporates all components and optimize it by stochastic gradient descent. This requires differentiating through our ODE estimator and through the calculation of time-dependent inverse-variance weights. Here, we exploit the flexible automatic differentiation system of Zygote.jl to simultaneously obtain gradients with respect to the VAE encoder and decoder parameters and the individual-specific dynamic model parameters in a straightforward way that requires minimal code adaptation. Zygote is particularly useful here because of its powerful source-to-source differentiation, which allows differentiating through arbitrary Julia code, including ODE solvers, user-defined structs, loops and recursion, without any code refactoring or adaptation. For details, see, e.g., Innes et al. (2019).
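
To illustrate this point (as a minimal toy sketch, not the package implementation): Zygote returns gradients for all fields of a user-defined struct from a single call, so toy "encoder", "decoder" and dynamic model parameters can be optimized jointly without any special handling.

    using Zygote

    struct ToyODEVAE
        Wenc::Matrix{Float64}   # toy "encoder" weights
        Wdec::Matrix{Float64}   # toy "decoder" weights
        θ::Vector{Float64}      # individual-specific dynamic model parameter(s)
    end

    function toyloss(m::ToyODEVAE, x, t)
        z0 = m.Wenc * x                # encode the measurement
        zt = exp(m.θ[1] * t) .* z0     # closed-form solution of dz/dt = θ₁ z
        x̂ = m.Wdec * zt               # decode the latent value at time t
        return sum(abs2, x .- x̂)      # toy reconstruction loss
    end

    toy = ToyODEVAE(randn(2, 4), randn(4, 2), [-0.1])
    x, t = randn(4), 2.0

    # a single call yields gradients for all struct fields, i.e., for encoder,
    # decoder and dynamic model parameters simultaneously
    grads = Zygote.gradient(m -> toyloss(m, x, t), toy)[1]
    grads.Wenc, grads.Wdec, grads.θ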

As a result of this joint optimization, the components can influence each other: a latent representation is found that is automatically adapted to the underlying dynamics, while the structure of the ODE system in turn regularizes the representation.

As our ODEs have analytical solutions, differentiating through the latent dynamics estimator does not require backpropagating gradients through a numerical ODE solving step. However, differentiable programming would also allow for differentiating through an ODE solver in each gradient update of the loss function, which can be realized efficiently, e.g., using the adjoint sensitivity method.
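
For illustration, consider a diagonal inhomogeneous linear system dz/dt = diag(a) z + c, loosely corresponding to the params_diagonalinhomogeneous option (a minimal sketch, not the package implementation): its closed-form solution can be differentiated directly, so no numerical solver appears in the gradient path.

    using Zygote

    # closed-form solution of dz/dt = diag(a) * z + c with z(0) = z0
    solve_analytic(z0, a, c, t) = exp.(a .* t) .* z0 .+ (c ./ a) .* (exp.(a .* t) .- 1)

    z0, a, c = [1.0, -0.5], [-0.3, -0.8], [0.2, 0.1]

    # gradients of a scalar function of the trajectory w.r.t. the ODE parameters,
    # obtained without backpropagating through a numerical ODE solving step
    ∇a, ∇c = Zygote.gradient((a, c) -> sum(abs2, solve_analytic(z0, a, c, 1.5)), a, c)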

Defining the model

LatentDynamics.odevaeType
odevae

Struct for an ODE-VAE model, with the following fields:

  • p: number of VAE input dimensions, i.e., number of time-dependent variables
  • q: number of input dimensions for the baseline neural net, i.e., number of baseline variables
  • zdim: number of latent dimensions
  • ODEnet: neural net to map baseline variables to individual-specific ODE parameters (number of ODE parameters depends on the ODE system specified by the dynamics function)
  • encoder: neural net to map input data to latent space
  • encodedμ: neural net layer parameterizing the mean of the latent space
  • encodedlogσ: neural net layer parameterizing the log variance of the latent space
  • decoder: neural net to map latent variable to reconstructed input data
  • decodedμ: neural net layer parameterizing the mean of the reconstructed input data
  • decodedlogσ: neural net layer parameterizing the log variance of the reconstructed input data
  • dynamics: function to map a parameter vector (= the output of the ODEnet) to the system matrix and constant vector of the ODE system; one of params_fullinhomogeneous, params_offdiagonalinhomogeneous, params_diagonalinhomogeneous, params_driftonly, params_fullhomogeneous, params_offdiagonalhomogeneous, params_diagonalhomogeneous
source
LatentDynamics.odevaeMethod
odevae(modelargs::ModelArgs)

Function to initialize the ODE-VAE model according to the arguments passed in modelargs.

Returns an odevae model.

source
LatentDynamics.ModelArgsType
ModelArgs

Struct to store model arguments, can be constructed with keyword arguments to set the following fields:

  • p: number of VAE input dimensions, i.e., number of time-dependent variables
  • q: number of input dimensions for the baseline neural net, i.e., number of baseline variables
  • zdim: number of latent dimensions
  • dynamics: function to map a parameter vector (= the output of the ODEnet) to the system matrix and constant vector of the ODE system; one of params_fullinhomogeneous, params_offdiagonalinhomogeneous, params_diagonalinhomogeneous, params_driftonly, params_fullhomogeneous, params_offdiagonalhomogeneous, params_diagonalhomogeneous
  • seed: random seed for reproducibility
  • bottleneck: whether to use a bottleneck layer in the ODEnet to reduce the number of effective parameters for higher-dimensional systems
  • init_scaled: whether to initialize the ODEnet with scaled weights
  • scale_sigmoid: scaling factor for the sigmoid function used to shift the ODE parameters to a sensible range, acting as a prior
  • add_diagonal: whether to add a diagonal transformation to the output of the ODEnet, adding flexibility after the sigmoid transformation
source
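
For illustration, a model could be set up roughly as follows. This is a hedged sketch: the keyword values are arbitrary placeholders, some fields may have package-side defaults, and ModelArgs, odevae and the dynamics functions are assumed to be exported as documented above.

    using LatentDynamics

    modelargs = ModelArgs(
        p = 10,                                # number of time-dependent variables
        q = 5,                                 # number of baseline variables
        zdim = 2,                              # number of latent dimensions
        dynamics = params_diagonalhomogeneous, # simple diagonal homogeneous system
        seed = 42,
        bottleneck = false,
        init_scaled = true,
        scale_sigmoid = 2,
        add_diagonal = true
    )

    m = odevae(modelargs)   # initialize encoder, decoder and ODEnet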

Training the model

LatentDynamics.LossArgsType
LossArgs

Struct to store loss arguments, can be constructed with keyword arguments to set the following fields:

  • λ_μpenalty: weight for the penalty that encourages consistency of the mean before and after solving the ODEs
  • λ_variancepenalty: weight for the penalty on the variance of the ODE estimator
  • variancepenaltytype: type of penalty on the variance of the ODE estimator; one of :ratio_sum, :sum_ratio, :log_diff
  • variancepenaltyoffset: offset used in the penalty on the variance of the latent space
  • firstonly: whether to use only the first time point for solving the ODE (if false, an ODE is solved with each time point as initial condition and the individual solutions are averaged)
  • weighting: whether to calculate inverse-variance weights for the contribution of other time points in the ODE trajectory estimator or use just equal weights for all ODE solutions
  • skipt0: whether to skip the first time point in the ODE estimator (to prevent the model from using just the initial condition and pushing the weights of all other solutions to zero)
source
LatentDynamics.lossFunction
loss(X, Y, t, m::odevae; args::LossArgs)

Compute the loss of the ODE-VAE model m on a batch of data, consisting of time-dependent variables X, baseline variables Y and time points t.

Details of the loss function behaviour, including additional penalties, are controlled by the keyword arguments args of type LossArgs, see ?LossArgs for details.

Returns the mean ELBO, where the ODE estimator of the underlying trajectory is used to decode the latent value at the time points t and obtain a reconstruction according to these smooth latent dynamics as specified by the ODE system.

source
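
Continuing the construction sketch from above (reusing m), the loss for one individual's data could be evaluated roughly as follows. Shapes and values are placeholders: X is assumed to be a p × nₜ matrix of time-dependent measurements, Y the corresponding baseline vector and t the vector of measurement time points, matching the per-patient structure described for train_model! below.

    # loss arguments; the values shown here are placeholders, not package defaults
    lossargs = LossArgs(
        λ_μpenalty = 0.0,
        λ_variancepenalty = 1.0,
        variancepenaltytype = :sum_ratio,
        variancepenaltyoffset = 1.0,
        firstonly = false,
        weighting = true,
        skipt0 = true
    )

    X = randn(Float32, 10, 4)            # p = 10 variables at 4 time points
    Y = randn(Float32, 5)                # q = 5 baseline variables
    t = Float32[0.0, 0.5, 1.0, 2.0]      # measurement time points

    l = loss(X, Y, t, m; args = lossargs)
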
LatentDynamics.train_model!Function
train_model!(m::odevae, 
    xs, xs_baseline, tvals, 
    lr, epochs, args::LossArgs; 
    selected_ids=nothing, 
    verbose::Bool=true, 
    plotting::Bool=true
    )

Train the ODE-VAE model m on a dataset of time-dependent variables xs, baseline variables xs_baseline and time points tvals. The structure of these is assumed to be as in the SMATestData and simdata structs.

Arguments

  • m: the ODE-VAE model to train
  • xs: a vector of matrices of time-dependent variables for each patient
  • xs_baseline: a vector of vectors of baseline variables for each patient
  • tvals: a vector of vectors of time points for each patient
  • lr: the learning rate of the ADAM optimizer
  • epochs: the number of epochs to train for
  • args: arguments controlling the loss function behaviour, see ?LossArgs for details
  • selected_ids: the IDs of the patients to plot during training to monitor progress, if nothing (default) then 12 random IDs are selected
  • verbose: whether to print the epoch and loss value during training
  • plotting: whether to visualize the learnt latent trajectories of selected patients (those with the selected_ids)

Returns

  • m: the trained ODE-VAE model
source
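
Putting the pieces together, a training call could look roughly as follows. The toy data below is a placeholder (real data is assumed to be structured as in the SMATestData and simdata structs), and m and lossargs are reused from the sketches above.

    n_patients = 20
    xs          = [randn(Float32, 10, 4) for _ in 1:n_patients]        # time-dependent variables
    xs_baseline = [randn(Float32, 5) for _ in 1:n_patients]            # baseline variables
    tvals       = [Float32[0.0, 0.5, 1.0, 2.0] for _ in 1:n_patients]  # time points

    m = train_model!(m, xs, xs_baseline, tvals,
        0.001,      # learning rate of the ADAM optimizer
        20,         # number of epochs
        lossargs;   # loss function behaviour, see ?LossArgs
        verbose = true,
        plotting = false
    )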