"""Source code for modlee.model_text_converter.

Converter from model objects (``nn.Module`` / ``LightningModule``) into source-code text (deprecated?).
"""
import inspect
import re
from collections import deque

import lightning.pytorch as pl
import torch
import torch.nn as nn

# NOTE(review): nothing in this file reads this flag; presumably importers
# check it to confirm this module loaded. TODO confirm.
module_available = True

# Import preamble that get_code_text_for_model / save_code_text_for_model
# prepend to the generated code when ``include_header`` is True.
# It must itself be valid Python, since it becomes the top of the emitted source.
# NOTE(review): several imports (torch, nn, F) are repeated; harmless.
modlee_required_packages = """
import torch
import torch.nn as nn
from torch.nn import functional as F
import lightning
import lightning.pytorch as pl

import torch
import torch.nn as nn
import inspect

import numpy
import numpy as np

try:
    from torchmetrics.functional import accuracy
except ImportError:
    from pytorch_lightning.metrics.functional import accuracy

#ptorch out-of-box libs needed
from torchvision.utils import _log_api_usage_once

from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Literal

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from torchvision import models
import torchvision
import torchmetrics

# from torchvision.ops.misc import Conv2dNormActivation, Permute
# from torchvision.ops.stochastic_depth import StochasticDepth
# from torchvision.transforms._presets import ImageClassification
# from torchvision.utils import _log_api_usage_once
# from torchvision.models._api import register_model, Weights, WeightsEnum
# from torchvision.models._meta import _IMAGENET_CATEGORIES
# from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface

# from torchvision.models.convnext import *
# from torchvision.models.convnext import CNBlockConfig

# Modlee imports
import modlee

"""


# Template for a generated class statement; get_code_text formats it with
# (class name, dotted path of the direct parent class).
class_header = "class {}({}):\n"

# Class name given to the root model in the generated code (get_code_text
# renames the root class to this).
modlee_model_name = "ModleeModel"

def exhaust_sequence_branch(root_sequence_module, custom_history):
    """
    Exhaust a module as a tree to find custom modules.

    Walks the direct children of ``root_sequence_module`` breadth-first.
    A child is reported as custom when its name is neither in ``OPS_MERGED``
    nor already in ``custom_history``; its name is then added to
    ``custom_history``, so each name is reported at most once. Any other child
    named ``"Sequential"`` is queued and its own children are walked in turn.

    A module's "name" is the text of ``str(module)`` before the first ``"("``
    (presumably the class name for standard torch modules).

    :param root_sequence_module: The root node/module to walk.
    :param custom_history: A set of custom module names already collected;
        it is updated in place.
    :return: A tuple ``([custom module objects], custom_history)``.
    """
    # FIFO queue: deque gives O(1) appends/pops at both ends, unlike
    # list.insert(0, ...).
    pending = deque([root_sequence_module])
    custom_modules_list = []
    while pending:
        sequence = pending.popleft()
        for module in sequence.__dict__["_modules"].values():
            # ``_modules`` may contain None placeholders. str(None) would be
            # reported as a custom "None" module with no source to extract.
            if module is None:
                continue
            module_name = str(module).split("(")[0]
            if module_name not in OPS_MERGED and module_name not in custom_history:
                custom_modules_list.append(module)
                custom_history.add(module_name)
            elif module_name == "Sequential":
                pending.append(module)
    return custom_modules_list, custom_history
def get_code_text(code_text, module, custom_history):
    """
    Get a code text representation of a module as its class source.

    Appends to ``code_text`` a class header plus the source of ``__init__``
    and ``forward``. If the module's direct parent class is named
    ``LightningModule`` or ``ModleeModel``, it also appends ``training_step``,
    ``validation_step`` and ``configure_optimizers``. When ``code_text`` is
    empty, the module is treated as the root model: its class is emitted
    under ``modlee_model_name`` and whole-word references to its original
    class name inside ``__init__`` are renamed too.

    :param code_text: The current code text if there are other dependencies, can be empty ''.
    :param module: The module for which to get the code text.
    :param custom_history: A set of custom module names already collected;
        updated in place.
    :return: A tuple of code_text, custom_child_module_list, custom_history.
    """
    module_class_name = module.__class__.__name__
    # str(cls) has the form "<class 'pkg.mod.Name'>"; keep the quoted path.
    module_parent_class_name = str(module.__class__.__bases__[0]).split("'")[1]
    is_root = code_text == ""

    # Format the header with the final class name directly. A substring
    # replace on the formatted header would also rewrite the parent path
    # whenever it contains the class name (e.g. "Model" in
    # "modlee.model.ModleeModel").
    header_class_name = modlee_model_name if is_root else module_class_name
    class_header_code = class_header.format(header_class_name, module_parent_class_name)

    functions_to_save = ["__init__", "forward"]
    if module_parent_class_name.split(".")[-1] in ["LightningModule", "ModleeModel"]:
        functions_to_save += [
            "training_step",
            "validation_step",
            "configure_optimizers",
        ]

    function_code = {}
    for function_to_save in functions_to_save:
        # NOTE(review): getattr also finds inherited methods, so a method the
        # class does not override is emitted with its base-class source.
        _function = getattr(module, function_to_save, None)
        # Skip missing methods without mutating the list being iterated;
        # removing an item mid-iteration would skip the next one.
        if _function is None:
            continue
        function_code[function_to_save] = inspect.getsource(_function)

    if is_root:
        # Rename whole identifiers only (e.g. ``super(Net, self)``), leaving
        # longer names such as ``NetConfig`` untouched. A callable
        # replacement avoids backslash interpretation.
        function_code["__init__"] = re.sub(
            rf"\b{re.escape(module_class_name)}\b",
            lambda _match: modlee_model_name,
            function_code["__init__"],
        )

    code_text = "\n".join(
        [code_text, class_header_code] + list(function_code.values()) + ["\n"]
    )

    custom_child_module_list = []
    for child_module in module.__dict__["_modules"].values():
        # ``_modules`` may contain None placeholders; they have no source.
        if child_module is None:
            continue
        child_module_name = str(child_module).split("(")[0]
        if (
            child_module_name not in OPS_MERGED
            and child_module_name not in custom_history
        ):
            custom_child_module_list.append(child_module)
            custom_history.add(child_module_name)
        elif child_module_name == "Sequential":
            # Look through Sequential containers for custom modules nested
            # inside them.
            seq_custom_modules_list, custom_history = exhaust_sequence_branch(
                child_module, custom_history
            )
            custom_child_module_list.extend(seq_custom_modules_list)
    return code_text, custom_child_module_list, custom_history
def get_code_text_for_model(
    model: nn.modules.module.Module | pl.core.module.LightningModule,
    include_header=False,
):
    """
    Build the source text for a model and every custom submodule it uses.

    The model is rendered first (as the root class), then each custom child
    module reported by ``get_code_text`` is rendered in last-in, first-out
    order until none remain.

    :param model: The model for which to get the code text.
    :param include_header: If True, prefix the result with
        ``modlee_required_packages``.
    :return: The code text for the model.
    """
    text = ""
    seen_names = set()
    stack = [model]
    while stack:
        current = stack.pop()
        text, discovered, seen_names = get_code_text(text, current, seen_names)
        stack.extend(discovered)
    return (modlee_required_packages + text) if include_header else text
def save_code_text_for_model(
    code_text: str, include_header: bool = False, file_path: str = ""
):
    """
    Save the code text to a file.

    :param code_text: The code text to save.
    :param include_header: Whether to include the header of modlee imports in the text, defaults to False.
    :param file_path: Destination file; it is created or overwritten. It must
        be given. The default empty path cannot be opened and raises
        ``FileNotFoundError``, as this function always did before the
        parameter existed.
    """
    if include_header:
        code_text = modlee_required_packages + code_text
    # Explicit encoding so the output does not depend on the platform locale.
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(code_text)