Image Segmentation

In this tutorial, we will build an image segmentation model using the Pascal VOC 2012 dataset, leveraging the Modlee package for experimentation.

Steps Overview:

  1. Setup and Initialization

  2. Dataset Preparation

  3. Model Definition

  4. Model Training

  5. Results and Artifacts Retrieval

First, we will import the necessary libraries and set up the environment.

import os
import random

import torch
import torchvision
from torch.utils.data import DataLoader, Subset
import lightning.pytorch as pl
import modlee

Next, we will set the Modlee API key and initialize the Modlee package. You can find your API key in the Modlee dashboard.

Replace replace-with-your-api-key with your API key.

modlee.init(api_key="replace-with-your-api-key")
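Hard-coding a key in a notebook makes it easy to leak when the notebook is shared. One possible alternative is to read it from an environment variable. The sketch below is only an illustration: the variable name `MODLEE_API_KEY` and the helper `get_api_key` are hypothetical names chosen here, not something Modlee prescribes.

```python
import os

# NOTE: `MODLEE_API_KEY` is a hypothetical variable name chosen for this sketch.
ENV_VAR = "MODLEE_API_KEY"


def get_api_key(env=None):
    """Return the API key from the environment, or raise a clear error."""
    env = os.environ if env is None else env
    key = env.get(ENV_VAR, "").strip()
    # Reject an empty value and the unedited placeholder from the tutorial
    if not key or key == "replace-with-your-api-key":
        raise RuntimeError(f"Set the {ENV_VAR} environment variable to your API key.")
    return key


# Usage (requires the variable to be set before running):
# modlee.init(api_key=get_api_key())
```

Failing early with a clear message is usually easier to debug than an authentication error later in the run.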

Now, we will define transformations for the input images and segmentation masks. Both are resized to 256x256 pixels and converted to PyTorch tensors, so that every sample has the same shape.

# Define the transformations applied to the images and masks
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256)),  # Resize all images to 256x256
    torchvision.transforms.ToTensor()           # Convert images to PyTorch tensors
])

target_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256)),  # Resize the masks to match the image size
    torchvision.transforms.ToTensor()           # Convert masks to PyTorch tensors
])

Next, we will load the Pascal VOC 2012 dataset using the VOCSegmentation class.

# Prepare the VOC 2012 dataset for segmentation tasks
train_dataset = torchvision.datasets.VOCSegmentation(
    root='./data', year='2012', image_set='train', download=True,
    transform=transform,               # `transform` applies to input images
    target_transform=target_transform  # `target_transform` applies to segmentation masks
)

val_dataset = torchvision.datasets.VOCSegmentation(
    root='./data', year='2012', image_set='val', download=True,
    transform=transform,
    target_transform=target_transform
)

To speed up training, we will work with smaller subsets of the training and validation datasets. Each subset is a random 10% sample of its split.

# Use only a subset of the training and validation data to speed up training

# Randomly pick 10% of the training indices and wrap them in a Subset
subset_size = int(len(train_dataset) * 0.1)
train_indices = random.sample(range(len(train_dataset)), subset_size)
train_subset = Subset(train_dataset, train_indices)

# Do the same for the validation data
val_subset_size = int(len(val_dataset) * 0.1)
val_indices = random.sample(range(len(val_dataset)), val_subset_size)
val_subset = Subset(val_dataset, val_indices)
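Because `random.sample` draws from the global random state, each run selects a different subset. If reproducible subsets are wanted, one option is a seeded local generator. The following is a minimal sketch; the helper `sample_indices`, the seed value, and the dataset length of 1000 are illustrative assumptions, not part of the tutorial.

```python
import random


def sample_indices(dataset_len, fraction, seed):
    """Pick int(dataset_len * fraction) distinct indices, reproducibly."""
    rng = random.Random(seed)  # local generator; leaves the global state alone
    k = int(dataset_len * fraction)
    return rng.sample(range(dataset_len), k)


# Hypothetical dataset length, used only for illustration
first = sample_indices(1000, 0.1, seed=42)
second = sample_indices(1000, 0.1, seed=42)

print(len(first))       # 100
print(first == second)  # True
```

The returned list can be passed to `Subset` in the same way as `train_indices` above, for example `Subset(train_dataset, sample_indices(len(train_dataset), 0.1, seed=42))`.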

We will now create DataLoader instances for both the training and validation subsets.

# Create DataLoader for both training and validation data
train_dataloader = DataLoader(train_subset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_subset, batch_size=8, shuffle=False)
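Before training, it can help to check the shapes a loader yields. The sketch below uses random tensors in place of the VOC subsets so that it runs without downloading anything; the dataset of 20 samples is a made-up stand-in.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Random tensors standing in for 20 transformed (image, mask) pairs
images = torch.rand(20, 3, 256, 256)
masks = torch.rand(20, 1, 256, 256)
fake_dataset = TensorDataset(images, masks)

loader = DataLoader(fake_dataset, batch_size=8, shuffle=False)

# Collect the (image batch, mask batch) shapes for every batch
batch_shapes = [(tuple(x.shape), tuple(y.shape)) for x, y in loader]
for x_shape, y_shape in batch_shapes:
    print(x_shape, y_shape)
# The last batch is smaller because 20 is not a multiple of 8
```

The same loop over `train_dataloader` should report image batches of shape (8, 3, 256, 256) and mask batches of shape (8, 1, 256, 256), with a possibly smaller final batch.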

Next, we will create the image segmentation model within Modlee’s framework. We offer two different approaches for selecting a model:

Option 1: Use a Recommended Modlee Model

If you’d like to start with a benchmark solution, Modlee provides pre-trained and optimized models for specific tasks. You can retrieve a recommended model as follows:

recommender = modlee.recommender.from_modality_task(
    modality='image',
    task='segmentation',
    in_channels=3
)

recommender.fit(train_dataloader)
recommended_modlee_model = recommender.model

Option 2: Define Your Own Modlee Model

If you want to experiment with a custom architecture, you can define your own model. Below, we provide a model featuring an encoder for extracting relevant features and a decoder for generating the segmentation mask.

# Define the image segmentation model
class ImageSegmentation(modlee.model.ImageSegmentationModleeModel):
    def __init__(self, in_channels=3):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1),
            torch.nn.ReLU(),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 1, kernel_size=1),
        )
        self.loss_fn = torch.nn.BCEWithLogitsLoss()

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        output_size = x.shape[2:]
        decoded = torch.nn.functional.interpolate(decoded, size=output_size, mode='bilinear', align_corners=False)
        return decoded

    def training_step(self, batch):
        x, y = batch
        logits = self.forward(x)
        loss = self.loss_fn(logits, y.float())
        return loss

    def validation_step(self, batch):
        x, y = batch
        logits = self.forward(x)
        loss = self.loss_fn(logits, y.float())
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

model = ImageSegmentation(in_channels=3)  # Initialize the model
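To make the data flow through this architecture concrete, the sketch below traces tensor shapes through the same layers, built with plain `torch.nn` modules so that it runs without Modlee. The 64x64 dummy input is an arbitrary choice to keep the check fast.

```python
import torch

encoder = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(64, 128, kernel_size=3, padding=1),
    torch.nn.ReLU(),
)
decoder = torch.nn.Sequential(
    torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
    torch.nn.ReLU(),
    torch.nn.Conv2d(64, 1, kernel_size=1),
)

x = torch.rand(2, 3, 64, 64)  # small dummy batch
with torch.no_grad():
    encoded = encoder(x)
    decoded = decoder(encoded)
    resized = torch.nn.functional.interpolate(
        decoded, size=x.shape[2:], mode='bilinear', align_corners=False
    )

print(encoded.shape)  # padding=1 keeps the 64x64 resolution
print(decoded.shape)  # the transposed convolution doubles it to 128x128
print(resized.shape)  # interpolation brings it back to the input size
```

Since the encoder does not downsample, the transposed convolution overshoots the input resolution, and the final `interpolate` call is what guarantees that the output matches the mask size.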

Now we can train the model with PyTorch Lightning for one epoch, validating on the validation subset. For this example, we will continue with the custom model from Option 2.

# Train the model using Modlee and PyTorch Lightning
with modlee.start_run() as run:
    # Set `max_epochs=1` to train for 1 epoch
    trainer = pl.Trainer(max_epochs=1)

    # Fit the model on the training and validation data
    trainer.fit(model=model,
                train_dataloaders=train_dataloader,
                val_dataloaders=val_dataloader)
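The model returns raw logits, not a mask. One common way to inspect predictions is to apply a sigmoid and a threshold. The sketch below is an illustration only: the helper `logits_to_mask`, the 0.5 threshold, and the logit values are assumptions made here, not part of Modlee.

```python
import torch


def logits_to_mask(logits, threshold=0.5):
    """Convert raw logits to a 0/1 mask via sigmoid and a threshold."""
    probs = torch.sigmoid(logits)
    return (probs > threshold).to(torch.uint8)


# Made-up logits for a single 2x2 prediction: shape (batch, channel, H, W)
logits = torch.tensor([[[[-2.0, 0.0], [0.3, 4.0]]]])
mask = logits_to_mask(logits)
print(mask)
```

With a trained model, the same helper could be applied to `model(x)` for a batch `x` from the validation loader, ideally inside `torch.no_grad()`.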

After training, we will examine the artifacts saved by Modlee, such as the model graph and various statistics. Modlee automatically preserves your training assets, ensuring that valuable insights are available for future reference and collaboration.

# Retrieve the path where Modlee saved the results of this run
last_run_path = modlee.last_run_path()
print(f"Run path: {last_run_path}")
artifacts_path = os.path.join(last_run_path, 'artifacts')
artifacts = sorted(os.listdir(artifacts_path))
print(f"Saved artifacts: {artifacts}")
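If the artifacts directory contains nested folders, a recursive listing with file sizes can be more informative than `os.listdir`. The sketch below uses only the standard library and a temporary directory with made-up file names, so it runs without a Modlee run; the helper `list_artifacts` is a hypothetical name.

```python
import os
import tempfile


def list_artifacts(artifacts_path):
    """Return (relative_path, size_in_bytes) for every file under artifacts_path."""
    found = []
    for dirpath, _dirnames, filenames in os.walk(artifacts_path):
        for name in filenames:
            full = os.path.join(dirpath, name)
            rel = os.path.relpath(full, artifacts_path)
            found.append((rel, os.path.getsize(full)))
    return sorted(found)


# Stand-in directory with made-up file names, used only for illustration
with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, "nested"))
    with open(os.path.join(tmp, "example.txt"), "w") as f:
        f.write("hello")
    with open(os.path.join(tmp, "nested", "notes.txt"), "w") as f:
        f.write("hi")
    listing = list_artifacts(tmp)

for rel, size in listing:
    print(f"{rel}: {size} bytes")
```

Calling `list_artifacts(artifacts_path)` on the path computed above would list whatever files the run actually saved.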