ODE-VAE model
We integrate the ODE-based dynamic model in the latent space of a VAE for flexible non-linear dimension reduction, reflecting the assumption of a lower-dimensional underlying dynamic process driving the observed measurements. The following functions are used to define, construct and train an ODE-VAE model.
To jointly optimize all components, i.e., the dynamic model, the VAE for dimension reduction, and the ODE-net for obtaining person-specific ODE parameters, we implement a joint loss function that incorporates all components and optimize it by stochastic gradient descent. This requires differentiating through our ODE estimator and through the calculation of time-dependent inverse-variance weights. Here, we exploit the flexible automatic differentiation system of Zygote.jl to simultaneously obtain gradients with respect to the VAE encoder and decoder parameters and the individual-specific dynamic model parameters in a straightforward way that requires minimal code adaptation. Zygote is particularly useful here because of its powerful source-to-source differentiation, which allows differentiating through arbitrary Julia code, including ODE solvers, user-defined structs, loops and recursion, without any code refactoring or adaptation. For details, see, e.g., Innes et al. (2019).
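To illustrate this, the following minimal sketch shows Zygote differentiating through a user-defined struct and a loop without any refactoring; the struct and function here are purely illustrative and not part of the package (assumes Zygote.jl is installed):

```julia
using Zygote

# An arbitrary user-defined struct and a function with a loop --
# no annotations or code adaptation are needed for differentiation.
struct Affine
    a::Float64
    b::Float64
end

function trajectory_sum(m::Affine, x, n)
    s = 0.0
    for _ in 1:n
        x = m.a * x + m.b   # simple discrete dynamics step
        s += x
    end
    return s
end

# One call yields gradients with respect to both the struct parameters
# (as a NamedTuple with fields a and b) and the input x simultaneously.
grads = Zygote.gradient((m, x) -> trajectory_sum(m, x, 3), Affine(0.5, 1.0), 2.0)
```

The same mechanism extends to differentiating through the ODE estimator and the inverse-variance weight calculation in the joint loss.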
As a result of this joint optimization, the components can influence each other: a latent representation can be found that is automatically adapted to the underlying dynamics, while the structure of the ODE system in turn regularizes the representation.
As our ODEs have analytical solutions, differentiation through the latent dynamics estimator does not require backpropagating gradients through a numerical ODE solving step. However, differentiable programming also allows for differentiating through ODE solvers in each loss function gradient update, which can be realized efficiently, e.g., using the adjoint sensitivity method.
Defining the model
LatentDynamics.odevae — Type
odevae
Struct for an ODE-VAE model, with the following fields:
- p: number of VAE input dimensions, i.e., number of time-dependent variables
- q: number of input dimensions for the baseline neural net, i.e., number of baseline variables
- zdim: number of latent dimensions
- ODEnet: neural net to map baseline variables to individual-specific ODE parameters (the number of ODE parameters depends on the ODE system specified by the dynamics function)
- encoder: neural net to map input data to the latent space
- encodedμ: neural net layer parameterizing the mean of the latent space
- encodedlogσ: neural net layer parameterizing the log variance of the latent space
- decoder: neural net to map latent variables to the reconstructed input data
- decodedμ: neural net layer parameterizing the mean of the reconstructed input data
- decodedlogσ: neural net layer parameterizing the log variance of the reconstructed input data
- dynamics: function to map a parameter vector (= the output of the ODEnet) to the system matrix and constant vector of the ODE system; one of params_fullinhomogeneous, params_offdiagonalinhomogeneous, params_diagonalinhomogeneous, params_driftonly, params_fullhomogeneous, params_offdiagonalhomogeneous, params_diagonalhomogeneous
LatentDynamics.odevae — Method
odevae(modelargs::ModelArgs)
Function to initialize the ODE-VAE model according to the arguments passed in modelargs.
Returns an odevae model.
LatentDynamics.ModelArgs — Type
ModelArgs
Struct to store model arguments; can be constructed with keyword arguments to set the following fields:
- p: number of VAE input dimensions, i.e., number of time-dependent variables
- q: number of input dimensions for the baseline neural net, i.e., number of baseline variables
- zdim: number of latent dimensions
- dynamics: function to map a parameter vector (= the output of the ODEnet) to the system matrix and constant vector of the ODE system; one of params_fullinhomogeneous, params_offdiagonalinhomogeneous, params_diagonalinhomogeneous, params_driftonly, params_fullhomogeneous, params_offdiagonalhomogeneous, params_diagonalhomogeneous
- seed: random seed for reproducibility
- bottleneck: whether to use a bottleneck layer in the ODEnet to reduce the number of effective parameters for higher-dimensional systems
- init_scaled: whether to initialize the ODEnet with scaled weights
- scale_sigmoid: scaling factor for the sigmoid function used to shift the ODE parameters to a sensible range, acting as a prior
- add_diagonal: whether to add a diagonal transformation to the output of the ODEnet to add flexibility after the sigmoid transformation
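Putting the two together, a model could be set up as follows. This is a hedged usage sketch based on the documented keyword names: it assumes the listed names are exported by LatentDynamics, and all numeric values are purely illustrative.

```julia
using LatentDynamics

# Illustrative dimensions: 10 time-dependent variables, 5 baseline
# variables, and a 2-dimensional latent space.
modelargs = ModelArgs(
    p = 10,                               # VAE input dimensions
    q = 5,                                # baseline input dimensions
    zdim = 2,                             # latent dimensions
    dynamics = params_fullinhomogeneous,  # ODE system structure
    seed = 42,                            # for reproducibility
)

m = odevae(modelargs)   # returns an initialized odevae model
```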
Training the model
LatentDynamics.LossArgs — Type
LossArgs
Struct to store loss arguments; can be constructed with keyword arguments to set the following fields:
- λ_μpenalty: weight for the penalty that encourages consistency of the mean before and after solving the ODEs
- λ_variancepenalty: weight for the penalty on the variance of the ODE estimator
- variancepenaltytype: type of penalty on the variance of the ODE estimator; one of :ratio_sum, :sum_ratio, :log_diff
- variancepenaltyoffset: offset used in the penalty on the variance of the latent space
- firstonly: whether to use only the first time point for solving the ODE (if false, an ODE is solved with each time point as initial condition and the individual solutions are averaged)
- weighting: whether to calculate inverse-variance weights for the contributions of other time points in the ODE trajectory estimator, or to use equal weights for all ODE solutions
- skipt0: whether to skip the first time point in the ODE estimator (to prevent the model from using just the initial condition and pushing the weights of all other solutions to zero)
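For example, a loss configuration could be assembled from these fields as in the following sketch; only the field names are taken from the docstring above, and all values shown are illustrative choices, not recommended defaults.

```julia
using LatentDynamics

# Illustrative loss configuration (values are examples only).
lossargs = LossArgs(
    λ_μpenalty = 1.0,                  # mean-consistency penalty weight
    λ_variancepenalty = 0.1,           # variance penalty weight
    variancepenaltytype = :ratio_sum,  # one of the documented penalty types
    firstonly = false,                 # average ODE solutions from all time points
    weighting = true,                  # use inverse-variance weights
    skipt0 = true,                     # skip the initial condition in the estimator
)
```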
LatentDynamics.loss — Function
loss(X, Y, t, m::odevae; args::LossArgs)
Compute the loss of the ODE-VAE model m on a batch of data, consisting of time-dependent variables X, baseline variables Y and time points t.
Details of the loss function behaviour, including additional penalties, are controlled by the keyword arguments args of type LossArgs, see ?LossArgs for details.
Returns the mean ELBO, where the ODE estimator of the underlying trajectory is used to decode the latent value at the time points t and obtain a reconstruction according to these smooth latent dynamics as specified by the ODE system.
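A single loss evaluation based on the signature above might then look as follows; the array shapes are assumptions inferred from the field descriptions (X with p rows and one column per time point, Y of length q), and m and lossargs are presumed to have been constructed beforehand.

```julia
# Hypothetical single-patient loss evaluation (shapes are assumptions).
X = rand(10, 4)            # 10 time-dependent variables at 4 time points
Y = rand(5)                # 5 baseline variables
t = [0.0, 0.5, 1.0, 1.5]   # observation time points

batchloss = loss(X, Y, t, m; args = lossargs)   # mean ELBO-based loss value
```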
LatentDynamics.train_model! — Function
train_model!(m::odevae,
    xs, xs_baseline, tvals,
    lr, epochs, args::LossArgs;
    selected_ids=nothing,
    verbose::Bool=true,
    plotting::Bool=true
)
Train the ODE-VAE model m on a dataset of time-dependent variables xs, baseline variables xs_baseline and time points tvals. The structure of these is assumed to be as in the SMATestData and simdata structs.
Arguments
- m: the ODE-VAE model to train
- xs: a vector of matrices of time-dependent variables for each patient
- xs_baseline: a vector of vectors of baseline variables for each patient
- tvals: a vector of vectors of time points for each patient
- lr: the learning rate of the ADAM optimizer
- epochs: the number of epochs to train for
- args: arguments controlling the loss function behaviour, see ?LossArgs for details
- selected_ids: the IDs of the patients to plot during training to monitor progress; if nothing (default), 12 random IDs are selected
- verbose: whether to print the epoch and loss value during training
- plotting: whether to visualize the learnt latent trajectories of selected patients (those with the selected_ids)
Returns
m: the trained ODE-VAE model
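An end-to-end training call following the documented positional-argument order could then look like this sketch; it assumes m and lossargs have been constructed as above, that xs, xs_baseline and tvals are prepared as in the SMATestData and simdata structs, and the learning rate and epoch count are illustrative.

```julia
# Hypothetical training run (hyperparameter values are illustrative).
trained_m = train_model!(m,
    xs, xs_baseline, tvals,
    0.001,            # ADAM learning rate
    20,               # number of epochs
    lossargs;
    verbose = true,   # print epoch and loss during training
    plotting = false, # skip latent-trajectory plots
)
```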