Model training
scVI.TrainingArgs
— Type
mutable struct TrainingArgs
Struct to store hyperparameters to control and customise the training process of an scVAE model. Can be constructed using keywords.
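Keyword construction can be illustrated without loading scVI.jl. The sketch below defines a hypothetical stand-in, TrainingArgsSketch, using Base.@kwdef. Its field names and defaults are copied from the keyword arguments listed below, but it is an assumption that the real TrainingArgs is defined the same way.

```julia
# Hypothetical stand-in for scVI.TrainingArgs, for illustration only.
# Field names and defaults mirror the documented keyword arguments;
# the real struct may be defined differently.
Base.@kwdef mutable struct TrainingArgsSketch
    trainsize::Float32 = 0.9f0
    train_test_split::Bool = false
    batchsize::Int = 128
    max_epochs::Int = 400
    lr::Float64 = 1e-3
    weight_decay::Float32 = 0.0f0
    n_steps_kl_warmup::Union{Int,Nothing} = nothing
    n_epochs_kl_warmup::Union{Int,Nothing} = 400
    progress::Bool = true
    register_losses::Bool = false
    verbose::Bool = false
    verbose_freq::Int = 10
end

# Override only the keywords you need; all others keep their defaults.
args = TrainingArgsSketch(max_epochs = 50, batchsize = 256, register_losses = true)

# Because the struct is mutable, fields can also be changed after construction.
args.lr = 5e-4
```

With scVI.jl loaded, the same keyword-call pattern applies to TrainingArgs itself, since it can be constructed using keywords.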
Keyword arguments:
- trainsize::Float32=0.9f0: proportion of data to be used for training when using a train-test split. Has no effect when train_test_split==false.
- train_test_split::Bool=false: whether or not to randomly split the data into a training and a test set.
- batchsize::Int=128: batch size used when partitioning the data into minibatches for training with stochastic gradient descent.
- max_epochs::Int=400: number of epochs to train the model.
- lr::Float64=1e-3: learning rate (= step size) of the ADAM optimiser used for stochastic gradient descent during model training (for details, see ?ADAM).
- weight_decay::Float32=0.0f0: rate of weight decay to apply in the ADAM optimiser (for details, see ?ADAM).
- n_steps_kl_warmup::Union{Int, Nothing}=nothing: number of steps (one optimiser update on one minibatch) over which to gradually increase (warm up, anneal) the weight of the regularising KL-divergence term in the loss function, which ensures consistency between the variational posterior and the standard normal prior. Empirically, this improves model inference.
- n_epochs_kl_warmup::Union{Int, Nothing}=400: number of epochs (one update pass over all minibatches) over which to gradually increase (warm up, anneal) the weight of the regularising KL-divergence term in the loss function, which ensures consistency between the variational posterior and the standard normal prior. Empirically, this improves model inference.
- progress::Bool=true: whether or not to print a progress bar and the current value of the loss function to the REPL.
- register_losses::Bool=false: whether or not to record the values of the different loss components after each training epoch in the loss_registry of the scVAE model. If true, an array is created in this dictionary for each loss component (reconstruction error, KL divergences, total loss), keyed by the name of the component, and the value of the component is appended to it after each epoch.
- verbose::Bool=false: whether or not to print the current epoch and the value of the loss function every verbose_freq epochs. Only takes effect if progress==false.
- verbose_freq::Int=10: frequency (in epochs) with which to display the current epoch and the current value of the loss function (only if progress==false and verbose==true).
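The KL warm-up can be pictured as a weight on the KL term that grows from 0 to 1 over the warm-up period. The following is a minimal sketch assuming a linear ramp. The function names kl_weight_sketch and annealed_loss, and the rule that an epoch-based setting is checked before a step-based one, are assumptions for illustration and are not taken from scVI.jl.

```julia
# Hypothetical linear KL warm-up schedule (an assumption, not scVI.jl's code).
# Returns the factor multiplying the KL-divergence term in the loss.
function kl_weight_sketch(epoch::Int, step::Int;
                          n_epochs_kl_warmup::Union{Int,Nothing} = 400,
                          n_steps_kl_warmup::Union{Int,Nothing} = nothing)
    if !isnothing(n_epochs_kl_warmup)
        # epoch-based warm-up: reaches full weight after n_epochs_kl_warmup epochs
        return Float32(min(1.0, epoch / n_epochs_kl_warmup))
    elseif !isnothing(n_steps_kl_warmup)
        # step-based warm-up: one step = one optimiser update on one minibatch
        return Float32(min(1.0, step / n_steps_kl_warmup))
    else
        # no warm-up: the KL term has full weight from the start
        return 1.0f0
    end
end

# The annealed objective: reconstruction error plus weighted KL divergence.
annealed_loss(reconstruction_error, kl_divergence, w) = reconstruction_error + w * kl_divergence
```

With the default n_epochs_kl_warmup=400, kl_weight_sketch(100, 0) returns 0.25f0, so after 100 epochs the KL term counts a quarter as much as it will once the warm-up has finished.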
scVI.train_model!
— Function
train_model!(m::scVAE, adata::AnnData, training_args::TrainingArgs; batch_key::Symbol=:batch)
Trains an scVAE model on an AnnData object, where the behaviour is controlled by a TrainingArgs object: defines the ADAM SGD optimiser, collects the model parameters, optionally splits the data into training and test data, and initialises a Flux.DataLoader that stores the data from the countmatrix of the AnnData object in minibatches. Then updates the model parameters via stochastic gradient descent for the specified number of epochs, optionally printing progress and current loss values.
Returns the trained scVAE model.
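The data-handling steps described here (an optional train-test split, then shuffled minibatches of batchsize cells) can be sketched with the Julia standard library alone. The helpers split_indices and minibatches are hypothetical. The sketch assumes cells are stored as columns, because Flux.DataLoader splits observations along the last dimension, and it does not reproduce how train_model! arranges the countmatrix internally.

```julia
using Random

# Hypothetical helper: indices of training and test cells.
# If train_test_split is false, all cells are used and trainsize is ignored.
function split_indices(n_cells::Int, trainsize::Real, train_test_split::Bool;
                       rng::AbstractRNG = Random.default_rng())
    if !train_test_split
        return (collect(1:n_cells), Int[])
    end
    perm = randperm(rng, n_cells)
    n_train = round(Int, trainsize * n_cells)
    return (perm[1:n_train], perm[n_train+1:end])
end

# Hypothetical helper: shuffle the training cells and cut them into minibatches.
# Assumes cells are the columns of X (genes x cells); the last batch may be smaller.
function minibatches(X::AbstractMatrix, idx::AbstractVector{Int}, batchsize::Int;
                     rng::AbstractRNG = Random.default_rng())
    shuffled = shuffle(rng, idx)
    return [X[:, b] for b in Iterators.partition(shuffled, batchsize)]
end

# Example: 1000 cells with a 90/10 train-test split and batchsize 128.
X = rand(Float32, 20, 1000)   # toy data: 20 genes x 1000 cells
train_idx, test_idx = split_indices(size(X, 2), 0.9f0, true; rng = MersenneTwister(1))
batches = minibatches(X, train_idx, 128; rng = MersenneTwister(2))
```

With 900 training cells and batchsize=128, this yields eight minibatches, the last of which holds the remaining 4 cells.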
scVI.train_supervised_model!
— Function
train_supervised_model!(m::scVAE, adata::AnnData, labels::AbstractVecOrMat{S}, training_args::TrainingArgs) where S <: Real
Trains an scVAE model on an AnnData object, where the latent representation is additionally trained in a supervised way to match the provided labels, and where the behaviour is controlled by a TrainingArgs object:
Defines the ADAM SGD optimiser, collects the model parameters, optionally splits the data into training and test data, and initialises a Flux.DataLoader that stores, in minibatches, the data from the countmatrix of the AnnData object together with the corresponding labels used to supervise the latent representation.
The loss function is the ELBO with an additional supervised term (see the function supervised_loss in src/ModelFunctions.jl). In addition to the scVAE model and the count data, it takes the provided labels as input, which need to have the same dimension as the latent representation. The mean squared error between the latent representation and the labels is calculated and added to the standard ELBO loss.
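The supervised term can be written as a small function. This is a hedged sketch, not the actual supervised_loss from src/ModelFunctions.jl: the name supervised_loss_sketch is hypothetical, the ELBO-based loss is passed in as a precomputed number, and the mean squared error is assumed to average over all entries (as Flux.mse does by default) with no extra weighting factor.

```julia
using Statistics

# Hypothetical sketch of the supervised objective (not scVI.jl's supervised_loss).
# neg_elbo: the standard ELBO-based loss for the current minibatch (precomputed here)
# z:        latent representation, e.g. latent_dim x n_cells
# labels:   supervision targets with exactly the same dimensions as z
function supervised_loss_sketch(neg_elbo::Real, z::AbstractVecOrMat, labels::AbstractVecOrMat)
    size(z) == size(labels) ||
        throw(DimensionMismatch("labels must have the same dimension as the latent representation"))
    mse = mean(abs2, z .- labels)   # mean squared error over all entries
    return neg_elbo + mse
end
```

For example, if z and the labels differ by 2 in exactly one of four entries, the added term is 4/4 = 1, so an ELBO-based loss of 10.0 becomes 11.0.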
Updates the model parameters via stochastic gradient descent for the specified number of epochs, optionally printing progress and current loss values.
Returns the trained scVAE model.
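Both training functions can report or record loss values, as governed by the progress, verbose, verbose_freq and register_losses fields of TrainingArgs. The following sketch shows one way such per-epoch bookkeeping could work. The helper names, the loss-component keys and the stand-in loss values are all hypothetical, and the real loss_registry of an scVAE model may use different component names.

```julia
# Hypothetical sketch of per-epoch loss bookkeeping (not scVI.jl's code).
# registry: component name => one value per epoch, mirroring the loss_registry description.
function record_losses!(registry::Dict{String,Vector{Float32}}, losses::Dict{String,<:Real})
    for (name, value) in losses
        push!(get!(registry, name, Float32[]), Float32(value))
    end
    return registry
end

# Printing rule from TrainingArgs: only without a progress bar, with verbose on,
# and every verbose_freq epochs.
should_print(epoch::Int; progress::Bool = true, verbose::Bool = false, verbose_freq::Int = 10) =
    !progress && verbose && epoch % verbose_freq == 0

function toy_training_log(max_epochs::Int; register_losses::Bool = true,
                          progress::Bool = false, verbose::Bool = true, verbose_freq::Int = 10)
    registry = Dict{String,Vector{Float32}}()
    for epoch in 1:max_epochs
        # stand-in values; a real run would compute these from the model and data
        reconstruction_error = 100.0 / epoch
        kl_divergence = 1.0
        losses = Dict("reconstruction_error" => reconstruction_error,
                      "kl_divergence" => kl_divergence,
                      "total_loss" => reconstruction_error + kl_divergence)
        register_losses && record_losses!(registry, losses)
        if should_print(epoch; progress = progress, verbose = verbose, verbose_freq = verbose_freq)
            println("epoch $epoch: loss = $(losses["total_loss"])")
        end
    end
    return registry
end
```

Calling toy_training_log(20) with these defaults prints a line at epochs 10 and 20 and returns a dictionary holding 20 values per loss component.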