Train
Auto-generated documentation for labelprop.train module.
inference
Perform inference using a trained LabelProp model.
Arguments
- datamodule (DataModule) - The data module used for training.
- model_PARAMS (dict) - The parameters used to initialize the LabelProp model.
- ckpt (str) - The path to the checkpoint file of the trained model.
- **kwargs - Additional keyword arguments.
Returns
Tuple - A tuple containing the predicted labels for the up direction, the down direction, and the fusion of both directions.
Signature
```python
def inference(datamodule, model_PARAMS, ckpt, **kwargs): ...
```
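A minimal sketch of consuming the three-element return value, assuming each prediction is a flat sequence of per-voxel labels (the actual array type and shape are not documented here). The `inference` function below is a hypothetical stand-in that only mirrors the documented signature and returns fixed placeholder labels; `directional_agreement` is an illustrative helper that is not part of labelprop.

```python
from typing import Any, Dict, List, Sequence, Tuple


# Hypothetical stand-in with the documented signature of labelprop.train.inference.
# It does NOT load a checkpoint or run a model; it returns placeholder labels.
def inference(datamodule: Any, model_PARAMS: Dict[str, Any], ckpt: str,
              **kwargs: Any) -> Tuple[List[int], List[int], List[int]]:
    y_up = [0, 1, 1, 0]     # placeholder labels, "up" direction
    y_down = [0, 1, 0, 0]   # placeholder labels, "down" direction
    y_fused = [0, 1, 1, 0]  # placeholder labels, fused result
    return y_up, y_down, y_fused


def directional_agreement(y_up: Sequence[int], y_down: Sequence[int]) -> float:
    """Fraction of positions where the up and down predictions carry the same label."""
    if len(y_up) != len(y_down) or not y_up:
        raise ValueError("predictions must be non-empty and of equal length")
    matches = sum(a == b for a, b in zip(y_up, y_down))
    return matches / len(y_up)


# The tuple is unpacked in the documented order: up, down, fused.
Y_up, Y_down, Y_fused = inference(None, {}, ckpt="path/to/model.ckpt")  # hypothetical path
print(directional_agreement(Y_up, Y_down))  # 0.75
```

Comparing the two directional predictions like this is one cheap sanity check before relying on the fused result.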
train
Train the model using the given data module and parameters.
Arguments
- datamodule - The data module used for training.
- model_PARAMS - The parameters for the model.
- max_epochs - The maximum number of epochs for training.
- ckpt - The checkpoint path for resuming training (optional).
- pretraining - Whether to perform pretraining (default: False).
- **kwargs - Additional keyword arguments.
Returns
- model - The trained model.
- best_ckpt - The path to the best model checkpoint.
Signature
```python
def train(
    datamodule, model_PARAMS, max_epochs, ckpt=None, pretraining=False, **kwargs
): ...
```
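The sketch below shows how the two return values of `train` could be wired into `inference`, with `best_ckpt` passed as the `ckpt` argument. Both functions are hypothetical stand-ins that only mirror the documented signatures. The empty `params` dict, the checkpoint paths, and the pretrain-then-resume flow (a `pretraining=True` run whose checkpoint is then passed as `ckpt`) are assumptions made for illustration, not a documented workflow.

```python
from typing import Any, Dict, Optional, Tuple


# Hypothetical stand-in with the documented signature of labelprop.train.train.
# It records its inputs instead of training a model.
def train(datamodule: Any, model_PARAMS: Dict[str, Any], max_epochs: int,
          ckpt: Optional[str] = None, pretraining: bool = False,
          **kwargs: Any) -> Tuple[Dict[str, Any], str]:
    model = {"resumed_from": ckpt, "pretraining": pretraining, "max_epochs": max_epochs}
    stage = "pretrain" if pretraining else "finetune"
    best_ckpt = f"checkpoints/{stage}-best.ckpt"  # hypothetical checkpoint path
    return model, best_ckpt


# Hypothetical stand-in with the documented signature of labelprop.train.inference.
def inference(datamodule: Any, model_PARAMS: Dict[str, Any], ckpt: str,
              **kwargs: Any) -> Tuple[str, str, str]:
    return f"up from {ckpt}", f"down from {ckpt}", f"fused from {ckpt}"


params: Dict[str, Any] = {}  # hypothetical; the real keys depend on the LabelProp model

# Stage 1 (assumed flow): a pretraining run.
_, pretrain_ckpt = train(None, params, max_epochs=5, pretraining=True)

# Stage 2 (assumed flow): resume training from the pretraining checkpoint.
model, best_ckpt = train(None, params, max_epochs=20, ckpt=pretrain_ckpt)

# The best checkpoint path returned by train() is what inference() takes as `ckpt`.
Y_up, Y_down, Y_fused = inference(None, params, ckpt=best_ckpt)
print(best_ckpt)  # checkpoints/finetune-best.ckpt
```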
train_and_eval
Train and evaluate a LabelProp model.
Arguments
- datamodule (DataModule) - The data module containing the dataset.
- model_PARAMS (dict) - The parameters for the LabelProp model.
- max_epochs (int) - The maximum number of epochs to train the model.
- ckpt (str, optional) - The path to a checkpoint file to load the model from. Defaults to None.
Returns
tuple - A tuple containing, in order, the trained model, the labels propagated in the up direction (Y_up), the labels propagated in the down direction (Y_down), and the evaluation results.
Signature
```python
def train_and_eval(datamodule, model_PARAMS, max_epochs, ckpt=None): ...
```
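A minimal sketch of unpacking the four-element result, using a hypothetical stand-in that only mirrors the documented signature and returns placeholder values. Per that signature, `train_and_eval` accepts neither `pretraining` nor `**kwargs`, and it returns four values where `train` returns two. The structure of the evaluation results is not documented here, so the empty dict below is purely a placeholder.

```python
from typing import Any, Dict, List, Optional, Tuple


# Hypothetical stand-in with the documented signature of
# labelprop.train.train_and_eval; it returns placeholder values only.
def train_and_eval(datamodule: Any, model_PARAMS: Dict[str, Any], max_epochs: int,
                   ckpt: Optional[str] = None) -> Tuple[Any, List[int], List[int], Dict[str, Any]]:
    model = object()    # placeholder for the trained model
    Y_up = [0, 1, 1]    # placeholder labels propagated in the up direction
    Y_down = [0, 1, 0]  # placeholder labels propagated in the down direction
    results: Dict[str, Any] = {}  # placeholder; the real results format is not documented
    return model, Y_up, Y_down, results


# Unpack in the documented order: model, Y_up, Y_down, evaluation results.
model, Y_up, Y_down, results = train_and_eval(None, {}, max_epochs=5)
```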