from modlee.model import ModleeModel
import torch
import torch.nn.functional as F
from torch.optim import lr_scheduler
class RecommendedModel(ModleeModel):
    """
    A ready-to-train ModleeModel that wraps around a recommended model
    from the recommender.

    Provides default Lightning-style training/validation steps and an
    end-of-epoch learning-rate-scheduler hook so the wrapped model can
    be trained without extra boilerplate.
    """

    def __init__(self, model, loss_fn=F.cross_entropy, *args, **kwargs):
        """
        Constructor for a recommended model.

        :param model: the torch module to wrap; invoked by :meth:`forward`.
        :param loss_fn: loss callable applied as ``loss_fn(output, target)``;
            defaults to :func:`torch.nn.functional.cross_entropy`.
        """
        super().__init__(*args, **kwargs)
        self.model = model
        self.loss_fn = loss_fn

    def forward(self, x):
        """Delegate the forward pass to the wrapped model."""
        return self.model(x)

    def training_step(self, batch, batch_idx, *args, **kwargs):
        """
        Run one training step: forward pass and loss computation.

        :param batch: an ``(input, target)`` pair.
        :param batch_idx: index of the batch within the epoch (unused).
        :return: dict containing the training ``"loss"``.
        """
        x, y = batch
        y_out = self(x)
        loss = self.loss_fn(y_out, y)
        return {"loss": loss}

    def validation_step(self, val_batch, batch_idx, *args, **kwargs):
        """
        Run one validation step: forward pass and loss computation.

        :param val_batch: an ``(input, target)`` pair.
        :param batch_idx: index of the batch within the epoch (unused).
        :return: dict containing the ``"val_loss"``.
        """
        x, y = val_batch
        y_out = self(x)
        loss = self.loss_fn(y_out, y)
        return {"val_loss": loss}

    def on_train_epoch_end(self) -> None:
        """
        Step the learning rate scheduler at the end of a training epoch
        and log its current learning rate.

        NOTE(review): ``self.scheduler`` is never assigned in this class —
        presumably it is attached by ``ModleeModel`` or by
        ``configure_optimizers``; confirm it exists before this hook fires.
        """
        sch = self.scheduler
        if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # ReduceLROnPlateau needs the monitored metric to decide
            # whether to reduce the learning rate.
            sch.step(self.trainer.callback_metrics["loss"])
        # Prefer the public get_last_lr() accessor; fall back to the
        # private _last_lr for torch versions where ReduceLROnPlateau
        # does not expose the public method.
        if hasattr(sch, "get_last_lr"):
            self.log("scheduler_last_lr", sch.get_last_lr()[0])
        else:
            self.log("scheduler_last_lr", sch._last_lr[0])
        return super().on_train_epoch_end()