Model training

mutable struct TrainingArgs

Struct to store hyperparameters to control and customise the training process of an scVAE model. Can be constructed using keywords.

Keyword arguments:

  • trainsize::Float32=0.9f0: proportion of data to be used for training when using a train-test split for training. Has no effect when train_test_split==false.
  • train_test_split::Bool=false: whether or not to randomly split the data into a training and a test set.
  • batchsize::Int=128: batch size to be used when partitioning the data into minibatches for training based on stochastic gradient descent.
  • max_epochs::Int=400: number of epochs to train the model.
  • lr::Float64=1e-3: learning rate (= step size) of the ADAM optimiser during the stochastic gradient descent optimisation for model training (for details, see ?ADAM).
  • weight_decay::Float32=0.0f0: rate of weight decay to apply in the ADAM optimiser (for details, see ?ADAM).
  • n_steps_kl_warmup::Union{Int, Nothing}=nothing: number of steps (one step is one gradient descent optimiser update for one batch) over which to gradually increase (warm up, anneal) the weight of the regularising KL-divergence term in the loss function. This term enforces consistency between the variational posterior and the standard normal prior. Empirically, the warm-up improves model inference.
  • n_epochs_kl_warmup::Union{Int, Nothing}=400: number of epochs (one epoch is one update for each of the batches) over which to gradually increase (warm up, anneal) the weight of the regularising KL-divergence term in the loss function. This term enforces consistency between the variational posterior and the standard normal prior. Empirically, the warm-up improves model inference.
  • progress::Bool=true: whether or not to print a progress bar and the current value of the loss function to the REPL.
  • register_losses::Bool=false: whether or not to record the values of the different loss components after each training epoch in the loss_registry of the scVAE model. If true, for each loss component (reconstruction error, KL divergences, total loss), an array will be created in the dictionary with the name of the loss component as key, where after each epoch, the value of the component is saved.
  • verbose::Bool=false: whether or not to print the current epoch and the value of the loss function every verbose_freq epochs. Only takes effect if progress==false.
  • verbose_freq::Int=10: frequency with which to display the current epoch and current value of the loss function (only if progress==false and verbose==true).
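
The keyword construction and the KL warm-up options above can be illustrated with a small self-contained sketch. DemoTrainingArgs and demo_kl_weight are hypothetical stand-ins written for this example, not the package's own definitions. The linear ramp from 0 to 1 and the precedence of the epoch-based setting over the step-based one are assumptions.

```julia
# Stand-in for illustration only: mirrors a few documented fields and defaults.
# This is NOT the package's TrainingArgs definition.
Base.@kwdef mutable struct DemoTrainingArgs
    max_epochs::Int = 400
    batchsize::Int = 128
    n_steps_kl_warmup::Union{Int, Nothing} = nothing
    n_epochs_kl_warmup::Union{Int, Nothing} = 400
end

# Hypothetical helper: linear warm-up of the KL weight from 0 to 1.
# Assumption: the epoch-based setting takes precedence when both are set.
function demo_kl_weight(args::DemoTrainingArgs, epoch::Int, step::Int)
    if !isnothing(args.n_epochs_kl_warmup)
        return min(1.0f0, Float32(epoch / args.n_epochs_kl_warmup))
    elseif !isnothing(args.n_steps_kl_warmup)
        return min(1.0f0, Float32(step / args.n_steps_kl_warmup))
    else
        return 1.0f0  # no warm-up: full KL weight from the start
    end
end

# Only the keywords that differ from the defaults need to be passed.
args = DemoTrainingArgs(max_epochs = 50, n_epochs_kl_warmup = 10)

# KL weight at each epoch: ramps up over the first 10 epochs, then stays at 1.
weights = [demo_kl_weight(args, epoch, 0) for epoch in 1:args.max_epochs]
```

Under these assumptions, the KL term contributes half its final weight at epoch 5 and its full weight from epoch 10 onwards.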

train_model!(m::scVAE, adata::AnnData, training_args::TrainingArgs; batch_key::Symbol=:batch)

Trains an scVAE model on an AnnData object; the behaviour is controlled by a TrainingArgs object. Defines the ADAM optimiser, collects the model parameters, optionally splits the data into training and test data, and initialises a Flux.DataLoader that stores the data in the count matrix of the AnnData object in batches. Updates the model parameters via stochastic gradient descent for the specified number of epochs and optionally prints progress and current loss values.

Returns the trained scVAE model.
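
The data handling described above (optional random split, then partitioning into minibatches, then one update per batch in every epoch) can be sketched in plain Julia. This is a simplified stand-in for what a data loader does, not the package's code. All demo_ names are hypothetical, and rounding the training set size to the nearest integer is an assumption.

```julia
using Random

# Hypothetical helper: randomly split cell indices into training and test sets.
function demo_split_indices(rng::AbstractRNG, ncells::Int, trainsize::Real)
    perm = randperm(rng, ncells)
    ntrain = round(Int, trainsize * ncells)
    return perm[1:ntrain], perm[ntrain+1:end]
end

# Hypothetical helper: partition indices into minibatches of at most `batchsize`.
demo_minibatches(inds::AbstractVector{Int}, batchsize::Int) =
    collect(Iterators.partition(inds, batchsize))

# Hypothetical helper: counts the optimiser updates a training run would perform.
function demo_count_updates(batches, max_epochs::Int)
    n = 0
    for epoch in 1:max_epochs, batch in batches
        n += 1  # placeholder for one optimiser update on `batch`
    end
    return n
end

rng = MersenneTwister(42)
train_inds, test_inds = demo_split_indices(rng, 1000, 0.9f0)
batches = demo_minibatches(train_inds, 128)
n_updates = demo_count_updates(batches, 2)
```

With 1000 cells, trainsize of 0.9 and batchsize of 128, this gives 900 training cells in 8 minibatches, the last of which holds only 4 cells.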

train_supervised_model!(m::scVAE, adata::AnnData, labels::AbstractVecOrMat{S}, training_args::TrainingArgs) where S <: Real

Trains an scVAE model on an AnnData object, with the latent representation additionally trained in a supervised way to match the provided labels. The behaviour is controlled by a TrainingArgs object.

Defines the ADAM optimiser, collects the model parameters, optionally splits the data into training and test data, and initialises a Flux.DataLoader that stores, in batches, the data in the count matrix of the AnnData object together with the corresponding labels for the supervision of the latent representation.

The loss function used is the ELBO with an additional supervised term (see the function supervised_loss in src/ModelFunctions.jl). In addition to the scVAE model and the count data, it takes the provided labels as input; these need to have the same dimension as the latent representation. The mean squared error between the latent representation and the labels is calculated and added to the standard ELBO loss.

Updates the model parameters via stochastic gradient descent for the specified number of epochs and optionally prints progress and current loss values.

Returns the trained scVAE model.
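
The supervised term described above can be sketched as follows. This is a minimal illustration and not the supervised_loss function from the package. The demo_ names are hypothetical, the ELBO value is passed in as a plain number, and adding the mean squared error with unit weight is an assumption.

```julia
using Statistics

# Mean squared error between latent representation and labels.
# Both arrays must have the same dimensions.
function demo_supervised_term(z::AbstractVecOrMat{<:Real}, labels::AbstractVecOrMat{<:Real})
    size(z) == size(labels) ||
        throw(DimensionMismatch("labels must match the latent representation"))
    return mean(abs2, z .- labels)
end

# Hypothetical combination: a given ELBO loss value plus the supervised term.
demo_supervised_loss(elbo_loss::Real, z, labels) =
    elbo_loss + demo_supervised_term(z, labels)

# Toy latent representation and labels (2 latent dimensions x 2 cells).
z = Float32[0.0 1.0; 2.0 3.0]
labels = Float32[0.0 0.0; 2.0 1.0]

# Squared differences are 0, 1, 0 and 4, so the supervised term is 1.25.
total = demo_supervised_loss(10.0f0, z, labels)
```

Labels whose dimensions differ from those of the latent representation raise a DimensionMismatch in this sketch, which reflects the documented requirement that both have the same dimension.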
