
Image Classification Model Recommendation

This example notebook uses the modlee package to obtain and train a recommended model for image classification on the CIFAR10 dataset from torchvision.

Here is a video explanation of this exercise.

First, import the torch- and Lightning-related packages and set your Modlee API key.

import os

import lightning.pytorch as pl
import torch
import torchvision
import torchvision.transforms as transforms

# Set your API key (replace the placeholder with your own key)
os.environ['MODLEE_API_KEY'] = "replace-with-your-api-key"

Next, initialize the package.

import modlee

# Initialize the Modlee package
modlee.init(api_key=os.environ.get('MODLEE_API_KEY'))
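
If the placeholder key is left in place, the mistake may only surface later in the notebook. The following is a minimal sketch of an early check, assuming the key is stored in the MODLEE_API_KEY environment variable as above. The helper name `require_api_key` and the placeholder comparison are illustrative assumptions and are not part of the modlee package.

```python
import os

PLACEHOLDER = "replace-with-your-api-key"  # the placeholder string used in this notebook


def require_api_key(env=os.environ, var="MODLEE_API_KEY"):
    """Return the API key from `env`, or raise if it is missing or still the placeholder.

    Hypothetical helper for illustration; not part of the modlee package.
    """
    key = env.get(var, "").strip()
    if not key or key == PLACEHOLDER:
        raise RuntimeError(f"Set {var} to your own API key before calling modlee.init().")
    return key


# Demonstration with a plain dict standing in for os.environ
print(require_api_key({"MODLEE_API_KEY": "example-key"}))  # prints: example-key
```

In the notebook, the result could then be passed along as `modlee.init(api_key=require_api_key())`.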

Now, we can create a dataloader from CIFAR10.

# Named `transform` (singular) to avoid shadowing the torchvision.transforms module
transform = transforms.Compose([
    transforms.ToTensor(),  # converts images to PyTorch tensors with values in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # rescales each color channel to [-1, 1]
])

train_dataset = torchvision.datasets.CIFAR10(  # downloads CIFAR10 if it is not already in ./data
    root='./data',
    train=True,  # load the training split
    download=True,
    transform=transform,  # apply the transformations defined above
)

val_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,  # load the test split, used here for validation
    download=True,
    transform=transform,
)

train_dataloader = torch.utils.data.DataLoader(  # batches the dataset for training
    train_dataset,
    batch_size=16,  # load the images in groups of 16
)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=16,
)
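
To make the effect of the Normalize step concrete, here is a small sketch of its per-channel arithmetic in plain Python. It assumes the mean and standard deviation of 0.5 used above; the function name `normalize` is only for illustration and stands in for what torchvision does on whole tensors.

```python
def normalize(value, mean=0.5, std=0.5):
    """Apply the per-channel formula used by Normalize: (value - mean) / std."""
    return (value - mean) / std


# ToTensor yields values in [0, 1]; check where the endpoints and midpoint land
for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
# 0.0 -> -1.0
# 0.5 -> 0.0
# 1.0 -> 1.0
```

With these settings, every color channel ends up in the range [-1, 1], centered on zero.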

Create a modlee recommender object and fit it to the dataset. Fitting calculates the dataset's metafeatures and sends them to the server. The server returns a recommended model for the dataset, which is assigned to recommender.model.

# create a Modlee recommender object
recommender = modlee.recommender.ImageClassificationRecommender(
    num_classes=10
)

# recommender analyzes training data to suggest best model
recommender.fit(train_dataloader)

# retrieve the recommended model
modlee_model = recommender.model
print(f"\nRecommended model: \n{modlee_model}")

INFO:Analyzing dataset based on data metafeatures...
INFO:Finished analyzing dataset.
INFO:The model is available at the recommender object's `model` attribute.

Recommended model:
RecommendedModel(
  (model): GraphModule(
    (Conv): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
    (Conv_1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (Relu): ReLU()
    (MaxPool): MaxPool2d(kernel_size=[3, 3], stride=[2, 2], padding=[1, 1], dilation=[1, 1], ceil_mode=False)
    (Conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (Relu_1): ReLU()
    (Conv_3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (Add): OnnxBinaryMathOperation()
    (Relu_2): ReLU()
    (Conv_4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (Relu_3): ReLU()
    (Conv_5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    ...
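
The printed layers can be checked by hand. The sketch below traces the spatial size of a 32x32 CIFAR10 image through the first few layers, assuming they are applied in the order listed, and counts the parameters of Conv_1. The helper names are illustrative; the size formula is the standard one for convolution and pooling layers with dilation 1.

```python
def conv_output_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv2d_params(in_channels, out_channels, kernel, bias=True):
    """Number of weights (plus optional biases) in a square-kernel Conv2d layer."""
    return in_channels * out_channels * kernel * kernel + (out_channels if bias else 0)


size = 32  # CIFAR10 images are 32x32
size = conv_output_size(size, kernel=1)                       # Conv: 1x1, stride 1
size = conv_output_size(size, kernel=7, stride=2, padding=3)  # Conv_1
print("after Conv_1:", size)   # after Conv_1: 16
size = conv_output_size(size, kernel=3, stride=2, padding=1)  # MaxPool
print("after MaxPool:", size)  # after MaxPool: 8
size = conv_output_size(size, kernel=3, stride=1, padding=1)  # Conv_2
print("after Conv_2:", size)   # after Conv_2: 8

print(conv2d_params(3, 64, 7))  # 9472 (assuming Conv_1 has a bias term)
```

The 3x3 convolutions with stride 1 and padding 1 keep the spatial size unchanged, which is what lets the Add operation combine their output with an earlier feature map of the same shape.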

Next, train the model inside a modlee.start_run() context, which automatically documents the metafeatures of the run.

with modlee.start_run() as run:
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(
        model=modlee_model,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader
    )

  | Name  | Type        | Params
--------------------------------------
0 | model | GraphModule | 11.7 M
--------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.779    Total estimated model params size (MB)
Epoch 0: 100%|██████████| 3125/3125 [01:14<00:00, 41.86it/s, v_num=0]
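
The numbers in this output can be sanity-checked with a little arithmetic. The sketch below assumes the 50,000-image CIFAR10 training split, the batch size of 16 chosen earlier, and 32-bit floating-point weights (4 bytes per parameter); the last of these is an assumption about how the size estimate is computed.

```python
import math

train_images = 50_000   # size of the CIFAR10 training split
batch_size = 16
steps_per_epoch = math.ceil(train_images / batch_size)
print(steps_per_epoch)  # 3125

params = 11.7e6         # rounded parameter count from the summary
bytes_per_param = 4     # assumes 32-bit floating-point weights
print(round(params * bytes_per_param / 1e6, 1))  # 46.8
```

The step count matches the 3125 iterations in the progress bar, and the size estimate is close to the reported 46.779 MB. The small gap presumably comes from rounding the parameter count to 11.7 M.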

Finally, we can view the assets saved during training. Modlee saves them automatically under the run directory, so they remain available for later reference and collaboration.

last_run_path = modlee.last_run_path()
print(f"Run path: {last_run_path}")
artifacts_path = os.path.join(last_run_path, 'artifacts')
artifacts = sorted(os.listdir(artifacts_path))
print(f"Saved artifacts: {artifacts}")

Run path: /home/ubuntu/projects/modlee_pypi/examples/mlruns/0/ff1e754d6401438fba506a0d98ca1f91
Saved artifacts: ['cached_vars', 'checkpoints', 'model', 'model.py', 'model_graph.py', 'model_graph.txt', 'model_size', 'model_summary.txt', 'stats_rep', 'transforms.txt']
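
Several of the saved artifacts look like plain-text files that can be inspected directly. The sketch below picks them out by file extension from the list printed above and defines a small reader. The helper `read_artifact` is hypothetical and not part of modlee, and whether the remaining entries are files or directories is not assumed here.

```python
from pathlib import Path

# Artifact names copied from the output above
artifacts = ['cached_vars', 'checkpoints', 'model', 'model.py', 'model_graph.py',
             'model_graph.txt', 'model_size', 'model_summary.txt', 'stats_rep',
             'transforms.txt']

# Select entries that look like plain-text files, judging by extension only
text_like = [name for name in artifacts if Path(name).suffix in {'.py', '.txt'}]
print(text_like)
# ['model.py', 'model_graph.py', 'model_graph.txt', 'model_summary.txt', 'transforms.txt']


def read_artifact(artifacts_path, name):
    """Return the text of one artifact file (hypothetical helper, not part of modlee)."""
    return (Path(artifacts_path) / name).read_text()
```

In the notebook, `print(read_artifact(artifacts_path, 'model_summary.txt'))` would then display the saved model summary.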