modlee.model.callbacks module

class modlee.model.callbacks.DataMetafeaturesCallback(data_snapshot_size=10000000.0, DataMetafeatures=<class 'modlee.data_metafeatures.DataMetafeatures'>, *args, **kwargs)

Bases: ModleeCallback

Callback to calculate and log data meta-features.

on_train_start(trainer: Trainer, pl_module: LightningModule) → None

Called when training begins.
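
The signature suggests that data meta-features are computed on a bounded snapshot of the training data (`data_snapshot_size`). The plain-Python sketch below illustrates that idea only. It assumes the snapshot size counts individual scalar values, and the helper names (`take_snapshot`, `simple_metafeatures`) and the chosen statistics are hypothetical, not Modlee's implementation. Because the documented default is the float `10000000.0`, the sketch converts the size with `int()`.

```python
import itertools
import statistics
from typing import Dict, Iterable, List


def take_snapshot(batches: Iterable[List[float]], data_snapshot_size: float) -> List[float]:
    """Hypothetical: collect values from successive batches until the snapshot
    holds at most `data_snapshot_size` scalars (assumed size unit)."""
    flat = itertools.chain.from_iterable(batches)
    return list(itertools.islice(flat, int(data_snapshot_size)))


def simple_metafeatures(snapshot: List[float]) -> Dict[str, float]:
    """Hypothetical: a few illustrative summary statistics of the snapshot."""
    return {
        "count": float(len(snapshot)),
        "mean": statistics.fmean(snapshot),
        "stdev": statistics.pstdev(snapshot),
        "min": min(snapshot),
        "max": max(snapshot),
    }


batches = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
snapshot = take_snapshot(batches, data_snapshot_size=5)
print(simple_metafeatures(snapshot))
```

Capping the snapshot keeps the cost of meta-feature extraction bounded regardless of dataset size.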

class modlee.model.callbacks.LogCodeTextCallback(kwargs_to_cache={}, *args, **kwargs)

Bases: ModleeCallback

Callback to log the model as code and text.

on_train_start(trainer: Trainer, pl_module: LightningModule) → None

Called when training begins.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) → None

Called when fit, validate, test, predict, or tune begins.
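
During a fit, `setup` runs before `on_train_start`, so a callback with both hooks can prepare state in `setup` and use it once training starts. The toy trainer below is a simplified, hypothetical stand-in for PyTorch Lightning's `Trainer`, written only to show the relative order of the hooks documented on this page (`setup`, `on_train_start`, `on_train_epoch_end`, `on_fit_end`); the real trainer calls many more hooks in between.

```python
from typing import List


class RecordingCallback:
    """Hypothetical callback that records which hooks were called, in order."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def setup(self, trainer, pl_module, stage: str) -> None:
        self.calls.append(f"setup:{stage}")

    def on_train_start(self, trainer, pl_module) -> None:
        self.calls.append("on_train_start")

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        self.calls.append(f"on_train_epoch_end:{trainer.current_epoch}")

    def on_fit_end(self, trainer, pl_module) -> None:
        self.calls.append("on_fit_end")


class ToyTrainer:
    """Simplified stand-in for a Lightning Trainer: it dispatches only the
    hooks used on this page, in the same relative order."""

    def __init__(self, callbacks, max_epochs: int = 2) -> None:
        self.callbacks = callbacks
        self.max_epochs = max_epochs
        self.current_epoch = 0

    def fit(self, pl_module) -> None:
        for cb in self.callbacks:
            cb.setup(self, pl_module, stage="fit")
        for cb in self.callbacks:
            cb.on_train_start(self, pl_module)
        for epoch in range(self.max_epochs):
            self.current_epoch = epoch
            # ... training batches would run here ...
            for cb in self.callbacks:
                cb.on_train_epoch_end(self, pl_module)
        for cb in self.callbacks:
            cb.on_fit_end(self, pl_module)


recorder = RecordingCallback()
ToyTrainer(callbacks=[recorder], max_epochs=2).fit(pl_module=None)
print(recorder.calls)
```

In PyTorch Lightning itself, callback instances are registered through the `callbacks` argument of `Trainer`.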

class modlee.model.callbacks.LogModelCheckpointCallback(monitor='val_loss', filename='model_checkpoint', temp_dir_path='/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/modlee/tmp/checkpoints', save_top_k=1, mode='min', verbose=True, *args, **kwargs)

Bases: ModelCheckpoint

Callback to log the best-performing model in a training routine, based on the monitored metric (by default, the lowest val_loss).
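
With the defaults `monitor='val_loss'`, `mode='min'`, and `save_top_k=1`, only the checkpoint with the lowest validation loss seen so far is kept. The plain-Python sketch below illustrates that selection rule. `TopKTracker` is a hypothetical helper, not part of Modlee or Lightning; the real logic lives in Lightning's `ModelCheckpoint`, which this class subclasses.

```python
import heapq
from typing import List, Optional, Tuple


class TopKTracker:
    """Hypothetical helper that keeps the `save_top_k` best checkpoints.

    mode="min" treats lower scores as better (e.g. val_loss); mode="max"
    treats higher scores as better. As in Lightning's ModelCheckpoint,
    save_top_k=0 keeps nothing and save_top_k=-1 keeps everything.
    """

    def __init__(self, save_top_k: int = 1, mode: str = "min") -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.save_top_k = save_top_k
        # Convert every score to a "badness" where lower is always better.
        self.sign = 1.0 if mode == "min" else -1.0
        # heapq is a min-heap; storing -badness puts the worst kept entry at index 0.
        self._heap: List[Tuple[float, str]] = []

    def update(self, score: float, path: str) -> Tuple[bool, Optional[str]]:
        """Offer a new checkpoint; return (kept, path_of_evicted_checkpoint)."""
        if self.save_top_k == 0:
            return False, None
        badness = self.sign * score
        if self.save_top_k < 0 or len(self._heap) < self.save_top_k:
            heapq.heappush(self._heap, (-badness, path))
            return True, None
        worst_badness = -self._heap[0][0]
        if badness < worst_badness:
            _, evicted = heapq.heapreplace(self._heap, (-badness, path))
            return True, evicted
        return False, None

    def best_paths(self) -> List[str]:
        """Paths of the kept checkpoints, best first."""
        return [path for _, path in sorted(self._heap, reverse=True)]


tracker = TopKTracker(save_top_k=1, mode="min")
history = []
for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.5]):
    history.append(tracker.update(val_loss, f"epoch={epoch}.ckpt"))
print(history)
print(tracker.best_paths())
```

A heap keeps the worst retained checkpoint at the front, so deciding whether a new candidate displaces it costs O(log k).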

log_and_clean_checkpoint(trainer, pl_module, step)

Log the latest checkpoint to MLflow and remove the local file.

Parameters:
  • trainer – The PyTorch Lightning Trainer instance.

  • pl_module – The LightningModule being trained.

  • step – The step (‘train’ or ‘val’) that triggered the checkpoint.
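
The log-then-delete pattern described for `log_and_clean_checkpoint` can be sketched as follows. To keep the example self-contained, a hypothetical `log_artifact` function copies the file into a local "artifact store" directory in place of an MLflow upload (in MLflow itself, `mlflow.log_artifact(local_path, artifact_path=None)` plays this role). Grouping stored files by `step` is also an assumption of this sketch.

```python
import os
import shutil
import tempfile


def log_artifact(local_path: str, artifact_dir: str) -> str:
    """Hypothetical stand-in for an MLflow artifact upload: copy the file
    into `artifact_dir` and return the stored path."""
    os.makedirs(artifact_dir, exist_ok=True)
    return shutil.copy2(local_path, artifact_dir)


def log_and_clean(local_path: str, artifact_dir: str, step: str) -> str:
    """Log a checkpoint for the given step ('train' or 'val'), then delete
    the local copy so checkpoints do not accumulate on disk."""
    if step not in ("train", "val"):
        raise ValueError("step must be 'train' or 'val'")
    stored = log_artifact(local_path, os.path.join(artifact_dir, step))
    os.remove(local_path)
    return stored


with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "model_checkpoint.ckpt")
    with open(ckpt, "wb") as f:
        f.write(b"fake checkpoint bytes")
    stored = log_and_clean(ckpt, os.path.join(tmp, "artifacts"), step="val")
    local_removed = not os.path.exists(ckpt)
    with open(stored, "rb") as f:
        stored_bytes = f.read()
print(local_removed, stored_bytes)
```

Deleting only after the copy succeeds means a failed upload leaves the local checkpoint in place.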

on_fit_end(trainer, pl_module)

Called when fit ends.

on_train_epoch_end(trainer, pl_module)

Save a checkpoint at the end of the training epoch.

class modlee.model.callbacks.LogONNXCallback

Bases: ModleeCallback

Callback for logging the model in its ONNX representation. Deprecated; it will be combined with LogCodeTextCallback.

on_train_start(trainer: Trainer, pl_module: LightningModule) → None

Called when training begins.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) → None

Called when fit, validate, test, predict, or tune begins.

class modlee.model.callbacks.LogOutputCallback(*args, **kwargs)

Bases: Callback

Callback to log the output metrics for each batch.
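
One way to picture per-batch output logging: in PyTorch Lightning, the hook `on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)` receives the outputs of `training_step`. The framework-free sketch below records one value per batch index. The class name, the choice of `'loss'` as the logged key, and the use of plain floats in place of tensors are all assumptions of this sketch.

```python
from typing import Any, Dict, List, Tuple


class BatchOutputRecorder:
    """Hypothetical per-batch logger: stores (batch_idx, value) pairs for
    one key of the training-step output."""

    def __init__(self, key: str = "loss") -> None:
        self.key = key
        self.history: List[Tuple[int, float]] = []

    def on_train_batch_end(self, trainer: Any, pl_module: Any,
                           outputs: Dict[str, float], batch: Any, batch_idx: int) -> None:
        # Only dict outputs containing the chosen key are recorded in this sketch.
        if isinstance(outputs, dict) and self.key in outputs:
            self.history.append((batch_idx, float(outputs[self.key])))

    def mean(self) -> float:
        """Average of the recorded values."""
        return sum(v for _, v in self.history) / len(self.history)


recorder = BatchOutputRecorder()
for idx, loss in enumerate([0.8, 0.6, 0.4]):
    recorder.on_train_batch_end(None, None, {"loss": loss}, batch=None, batch_idx=idx)
print(recorder.history, recorder.mean())
```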

class modlee.model.callbacks.LogParamsCallback

Bases: Callback

Callback to log parameters at the start of training.

on_train_start(trainer: Trainer, pl_module: LightningModule) → None

Called when training begins.
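
Experiment trackers such as MLflow store parameters as flat key/value pairs, so nested hyperparameters are often flattened before logging. The helper below is a hypothetical sketch of such a step. It is not taken from Modlee, and which parameters this callback actually collects is not specified here.

```python
from typing import Any, Dict


def flatten_params(params: Dict[str, Any], prefix: str = "", sep: str = ".") -> Dict[str, str]:
    """Hypothetical: flatten nested dicts into dotted keys and stringify values,
    e.g. {"optim": {"lr": 0.1}} -> {"optim.lr": "0.1"}."""
    flat: Dict[str, str] = {}
    for key, value in params.items():
        full_key = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Recurse into nested groups, extending the dotted prefix.
            flat.update(flatten_params(value, full_key, sep))
        else:
            flat[full_key] = str(value)
    return flat


hparams = {"batch_size": 32, "optim": {"name": "adam", "lr": 0.001}, "epochs": 10}
print(flatten_params(hparams))
```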

class modlee.model.callbacks.LogTransformsCallback

Bases: ModleeCallback

Callback to log the transforms applied to the dataset, if they were applied with torchvision.transforms.

on_train_start(trainer, pl_module)

Called when training begins.
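
In torchvision, datasets derived from `VisionDataset` keep their input transform in a `transform` attribute, and `transforms.Compose` keeps its steps in a `transforms` list. Assuming that layout, a loggable description of the pipeline could be built as below. `describe_transforms` and the small stand-in classes are hypothetical, included so the example runs without torchvision installed.

```python
from typing import Any, List


def describe_transforms(dataset: Any) -> List[str]:
    """Hypothetical: return one string per transform step, or [] when no
    transform is set. Assumes `dataset.transform` is either a single callable
    or a Compose-like object exposing a `transforms` list."""
    transform = getattr(dataset, "transform", None)
    if transform is None:
        return []
    steps = getattr(transform, "transforms", [transform])
    return [repr(step) for step in steps]


# Minimal stand-ins so the sketch runs without torchvision.
class Normalize:
    def __init__(self, mean: float, std: float) -> None:
        self.mean, self.std = mean, std

    def __repr__(self) -> str:
        return f"Normalize(mean={self.mean}, std={self.std})"


class Compose:
    def __init__(self, transforms: List[Any]) -> None:
        self.transforms = transforms


class ToyDataset:
    def __init__(self, transform: Any = None) -> None:
        self.transform = transform


print(describe_transforms(ToyDataset(Compose([Normalize(0.5, 0.25)]))))
print(describe_transforms(ToyDataset()))
```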

class modlee.model.callbacks.ModelMetafeaturesCallback(ModelMetafeatures=<class 'modlee.model_metafeatures.ModelMetafeatures'>)

Bases: ModleeCallback

Callback to calculate and log model meta-features.

on_train_start(trainer: Trainer, pl_module: LightningModule) → None

Called when training begins.

class modlee.model.callbacks.ModleeCallback

Bases: Callback

Base class for Modlee-specific callbacks.

get_input(trainer, pl_module)

Get an input (one element from a batch) from a trainer’s dataloader.

Parameters:
  • trainer – The trainer with the dataloader.

  • pl_module – The model module, used for loading the data input to the correct device.

Returns:

An input from the batch.
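
The behavior described for `get_input` (take one element of a batch from the trainer's dataloader) can be sketched as follows. The sketch assumes each batch is an `(inputs, targets)` pair and uses nested lists in place of tensors. `first_input` is a hypothetical name, and moving the input to the module's device, which the real method handles via `pl_module`, is omitted. Slicing with `inputs[:1]` keeps a batch dimension of size 1; whether the real method does the same is not stated here.

```python
from typing import Any, Iterable, Sequence, Tuple


def first_input(dataloader: Iterable[Tuple[Sequence[Any], Sequence[Any]]]) -> Sequence[Any]:
    """Hypothetical: return the first sample of the first batch's inputs,
    assuming every batch is an (inputs, targets) pair."""
    inputs, _targets = next(iter(dataloader))
    # Slice rather than index so the result keeps a batch dimension of size 1.
    return inputs[:1]


loader = [([[1, 2], [3, 4], [5, 6]], [0, 1, 0]), ([[7, 8]], [1])]
print(first_input(loader))
```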

class modlee.model.callbacks.PushServerCallback

Bases: Callback

Callback to push run assets to the server at the end of training.

on_fit_end(trainer: Trainer, pl_module: LightningModule) → None

Called when fit ends.