Source code for modlee.model_text_converter

""" 
Convertor for model objects into text (deprecated?).
"""
import inspect
import torch.nn as nn
import torch
import lightning.pytorch as pl

module_available = True

modlee_required_packages = """
import torch
import torch.nn as nn
from torch.nn import functional as F
import lightning
import lightning.pytorch as pl

import torch
import torch.nn as nn
import inspect

import numpy
import numpy as np

try:
    from torchmetrics.functional import accuracy
except ImportError:
    from pytorch_lightning.metrics.functional import accuracy

#ptorch out-of-box libs needed
from torchvision.utils import _log_api_usage_once

from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Literal

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from torchvision import models
import torchvision
import torchmetrics

# from torchvision.ops.misc import Conv2dNormActivation, Permute
# from torchvision.ops.stochastic_depth import StochasticDepth
# from torchvision.transforms._presets import ImageClassification
# from torchvision.utils import _log_api_usage_once
# from torchvision.models._api import register_model, Weights, WeightsEnum
# from torchvision.models._meta import _IMAGENET_CATEGORIES
# from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface

# from torchvision.models.convnext import *
# from torchvision.models.convnext import CNBlockConfig

# Modlee imports
import modlee

"""

OPS_MERGED = {
    "utils",
    "HingeEmbeddingLoss",
    "ConstantPad2d",
    "efficientnet_v2_l",
    "convnext_large",
    "GRUCell",
    "CTCLoss",
    "maxvit",
    "FeatureAlphaDropout",
    "mnasnet0_75",
    "RegNet_Y_800MF_Weights",
    "AlphaDropout",
    "Conv1d",
    "LazyConvTranspose2d",
    "__spec__",
    "swin_b",
    "RNNCellBase",
    "mobilenet_v3_small",
    "Hardsigmoid",
    "__package__",
    "GoogLeNetOutputs",
    "Dropout1d",
    "vit_b_16",
    "ConvTranspose1d",
    "LPPool1d",
    "LazyLinear",
    "efficientnet_b6",
    "grad",
    "GELU",
    "SoftMarginLoss",
    "DenseNet169_Weights",
    "Conv2d",
    "parameter",
    "ELU",
    "ReLU",
    "wide_resnet101_2",
    "__path__",
    "regnet_x_400mf",
    "PixelShuffle",
    "vgg13",
    "efficientnet_b3",
    "efficientnet_b2",
    "MNASNet1_3_Weights",
    "Softmax",
    "LazyConv1d",
    "DenseNet161_Weights",
    "MNASNet0_75_Weights",
    "__doc__",
    "RReLU",
    "SqueezeNet",
    "mobilenet_v3_large",
    "DenseNet201_Weights",
    "TripletMarginLoss",
    "AdaptiveMaxPool1d",
    "mnasnet1_3",
    "Fold",
    "efficientnet_b7",
    "efficientnet_b5",
    "VGG13_BN_Weights",
    "Conv3d",
    "Parameter",
    "googlenet",
    "MultiMarginLoss",
    "convnext_small",
    "TripletMarginWithDistanceLoss",
    "get_model",
    "modules",
    "MobileNetV2",
    "ResNet101_Weights",
    "squeezenet1_1",
    "ConvNeXt_Tiny_Weights",
    "ResNeXt50_32X4D_Weights",
    "Swin_V2_T_Weights",
    "NLLLoss",
    "LazyConvTranspose1d",
    "get_weight",
    "vgg",
    "parallel",
    "SqueezeNet1_0_Weights",
    "inception",
    "resnet",
    "swin_v2_s",
    "factory_kwargs",
    "VGG16_BN_Weights",
    "LazyBatchNorm1d",
    "EfficientNet_B7_Weights",
    "DenseNet",
    "swin_t",
    "SmoothL1Loss",
    "Dropout2d",
    "Upsample",
    "TransformerDecoder",
    "Weights",
    "EfficientNet_B4_Weights",
    "ShuffleNet_V2_X1_5_Weights",
    "LazyBatchNorm3d",
    "resnet34",
    "ConstantPad1d",
    "BCEWithLogitsLoss",
    "RegNet",
    "CELU",
    "ModuleDict",
    "mobilenet_v2",
    "swin_s",
    "densenet161",
    "squeezenet",
    "BatchNorm3d",
    "AdaptiveAvgPool3d",
    "MultiLabelSoftMarginLoss",
    "regnet_x_3_2gf",
    "init",
    "__builtins__",
    "resnet18",
    "MobileNet_V2_Weights",
    "ShuffleNetV2",
    "Swin_V2_S_Weights",
    "ShuffleNet_V2_X0_5_Weights",
    "functional",
    "resnext50_32x4d",
    "optical_flow",
    "VisionTransformer",
    "AdaptiveLogSoftmaxWithLoss",
    "UninitializedBuffer",
    "MaxUnpool1d",
    "efficientnet_b1",
    "quantizable",
    "LazyInstanceNorm1d",
    "efficientnet",
    "VGG11_Weights",
    "LeakyReLU",
    "CosineEmbeddingLoss",
    "Hardswish",
    "RegNet_Y_16GF_Weights",
    "Linear",
    "ConvTranspose3d",
    "resnext101_64x4d",
    "RegNet_X_16GF_Weights",
    "HuberLoss",
    "inception_v3",
    "regnet_y_400mf",
    "vgg11_bn",
    "Tanhshrink",
    "Transformer",
    "GroupNorm",
    "efficientnet_v2_m",
    "ViT_L_16_Weights",
    "PoissonNLLLoss",
    "vgg16",
    "ParameterDict",
    "Softplus",
    "squeezenet1_0",
    "BatchNorm2d",
    "SELU",
    "PairwiseDistance",
    "efficientnet_b0",
    "LazyInstanceNorm3d",
    "LocalResponseNorm",
    "shufflenet_v2_x1_5",
    "vgg19",
    "__cached__",
    "swin_v2_t",
    "__name__",
    "mnasnet",
    "GLU",
    "MNASNet0_5_Weights",
    "VGG16_Weights",
    "_GoogLeNetOutputs",
    "RNNCell",
    "VGG",
    "SqueezeNet1_1_Weights",
    "ConvNeXt_Large_Weights",
    "common_types",
    "mobilenetv3",
    "GRU",
    "Mish",
    "mnasnet1_0",
    "MobileNet_V3_Large_Weights",
    "TransformerEncoder",
    "Threshold",
    "EfficientNet_V2_M_Weights",
    "Sigmoid",
    "regnet_y_800mf",
    "AdaptiveAvgPool2d",
    "Container",
    "EfficientNet_V2_L_Weights",
    "ViT_B_16_Weights",
    "SwinTransformer",
    "Softsign",
    "vit_l_16",
    "PixelUnshuffle",
    "MarginRankingLoss",
    "MaxPool2d",
    "regnet_y_16gf",
    "regnet_y_8gf",
    "densenet121",
    "ConvNeXt",
    "ReflectionPad3d",
    "mobilenetv2",
    "MSELoss",
    "RegNet_X_400MF_Weights",
    "MaxUnpool3d",
    "ShuffleNet_V2_X2_0_Weights",
    "list_models",
    "Embedding",
    "Softmin",
    "EfficientNet_B2_Weights",
    "PReLU",
    "shufflenet_v2_x0_5",
    "mobilenet",
    "ConvNeXt_Base_Weights",
    "Swin_S_Weights",
    "Identity",
    "InstanceNorm3d",
    "resnet152",
    "Unflatten",
    "_utils",
    "ViT_B_32_Weights",
    "efficientnet_v2_s",
    "ReflectionPad2d",
    "Hardshrink",
    "ResNeXt101_32X8D_Weights",
    "ResNeXt101_64X4D_Weights",
    "MultiLabelMarginLoss",
    "CrossMapLRN2d",
    "VGG13_Weights",
    "AvgPool3d",
    "RegNet_Y_32GF_Weights",
    "shufflenet_v2_x1_0",
    "LazyConv3d",
    "LSTM",
    "vgg11",
    "WeightsEnum",
    "MultiheadAttention",
    "swin_transformer",
    "get_model_builder",
    "regnet_x_800mf",
    "detection",
    "RegNet_X_32GF_Weights",
    "Wide_ResNet50_2_Weights",
    "MaxVit",
    "AlexNet",
    "MaxPool1d",
    "SiLU",
    "ShuffleNet_V2_X1_0_Weights",
    "GoogLeNet_Weights",
    "AdaptiveMaxPool2d",
    "maxvit_t",
    "BCELoss",
    "DataParallel",
    "Tanh",
    "EfficientNet_V2_S_Weights",
    "MobileNetV3",
    "EfficientNet_B3_Weights",
    "get_model_weights",
    "ReflectionPad1d",
    "intrinsic",
    "ParameterList",
    "ReplicationPad1d",
    "EfficientNet_B5_Weights",
    "ConvNeXt_Small_Weights",
    "qat",
    "densenet169",
    "vgg19_bn",
    "swin_v2_b",
    "RegNet_X_8GF_Weights",
    "EmbeddingBag",
    "ConvTranspose2d",
    "vit_b_32",
    "NLLLoss2d",
    "vision_transformer",
    "KLDivLoss",
    "__file__",
    "Module",
    "EfficientNet_B6_Weights",
    "GaussianNLLLoss",
    "TransformerEncoderLayer",
    "BatchNorm1d",
    "Flatten",
    "regnet_y_3_2gf",
    "ReplicationPad3d",
    "VGG19_BN_Weights",
    "resnet50",
    "L1Loss",
    "wide_resnet50_2",
    "convnext",
    "UpsamplingNearest2d",
    "vgg16_bn",
    "regnet_y_128gf",
    "_api",
    "LPPool2d",
    "convnext_tiny",
    "Swin_B_Weights",
    "ConstantPad3d",
    "RegNet_Y_8GF_Weights",
    "AvgPool2d",
    "resnext101_32x8d",
    "UninitializedParameter",
    "AvgPool1d",
    "video",
    "segmentation",
    "MNASNet",
    "mnasnet0_5",
    "regnet_x_32gf",
    "convnext_base",
    "regnet_x_8gf",
    "LogSoftmax",
    "VGG11_BN_Weights",
    "ViT_L_32_Weights",
    "ZeroPad2d",
    "regnet_x_1_6gf",
    "regnet_y_32gf",
    "RNNBase",
    "GoogLeNet",
    "Swin_T_Weights",
    "LayerNorm",
    "__loader__",
    "AlexNet_Weights",
    "RegNet_X_800MF_Weights",
    "Inception3",
    "RNN",
    "_InceptionOutputs",
    "FractionalMaxPool3d",
    "vgg13_bn",
    "Dropout",
    "CrossEntropyLoss",
    "vit_h_14",
    "InceptionOutputs",
    "RegNet_Y_128GF_Weights",
    "InstanceNorm2d",
    "MaxUnpool2d",
    "_reduction",
    "LogSigmoid",
    "efficientnet_b4",
    "quantization",
    "ResNet",
    "resnet101",
    "LazyBatchNorm2d",
    "ChannelShuffle",
    "Softmax2d",
    "TransformerDecoderLayer",
    "LazyConv2d",
    "MNASNet1_0_Weights",
    "FractionalMaxPool2d",
    "shufflenetv2",
    "EfficientNet",
    "MaxPool3d",
    "SyncBatchNorm",
    "Softshrink",
    "MaxVit_T_Weights",
    "RegNet_X_3_2GF_Weights",
    "ResNet18_Weights",
    "Unfold",
    "ViT_H_14_Weights",
    "AdaptiveAvgPool1d",
    "MobileNet_V3_Small_Weights",
    "VGG19_Weights",
    "RegNet_Y_400MF_Weights",
    "ModuleList",
    "_meta",
    "ResNet34_Weights",
    "Dropout3d",
    "LSTMCell",
    "AdaptiveMaxPool3d",
    "Inception_V3_Weights",
    "EfficientNet_B0_Weights",
    "Hardtanh",
    "RegNet_Y_3_2GF_Weights",
    "LazyInstanceNorm2d",
    "EfficientNet_B1_Weights",
    "DenseNet121_Weights",
    "regnet_x_16gf",
    "Wide_ResNet101_2_Weights",
    "RegNet_X_1_6GF_Weights",
    "quantized",
    "CosineSimilarity",
    "ReplicationPad2d",
    "Bilinear",
    "densenet",
    "RegNet_Y_1_6GF_Weights",
    "regnet_y_1_6gf",
    "ReLU6",
    "InstanceNorm1d",
    "densenet201",
    "LazyConvTranspose3d",
    "alexnet",
    "vit_l_32",
    "shufflenet_v2_x2_0",
    "Sequential",
    "Swin_V2_B_Weights",
    "regnet",
    "ResNet50_Weights",
    "ResNet152_Weights",
    "UpsamplingBilinear2d",
}
class_header = "class {}({}):\n"

modlee_model_name = "ModleeModel"

[docs] def exhaust_sequence_branch(root_sequence_module, custom_history): """ Exhaust a module as a tree to find custom modules, :param root_sequence_module: The root node/module of the model. :param custom_history: A list of custom module names. :return: A tuple of lists ([custom module objects], [custom module names]) """ sequences = [root_sequence_module] custom_modules_list = [] while len(sequences) != 0: sequence = sequences[-1] module_list = list(sequence.__dict__["_modules"].values()) for module in module_list: module_name = str(module).split("(")[0] if module_name not in OPS_MERGED and module_name not in custom_history: custom_modules_list.append(module) custom_history.add(module_name) elif module_name == "Sequential": sequences.insert(0, module) sequences.pop() return custom_modules_list, custom_history
[docs] def get_code_text(code_text, module, custom_history): """ Get a code text representation of a model as its __init__ and forward functions. :param code_text: The current code text if there are other dependencies, can be empty ''. :param module: The module for which to get the code text. :param custom_history: A list of the history of custom modules. :return: A tuple of code_text, custom_child_module_list, custom_history. """ # get current and parent class names module_class_name = module.__class__.__name__ module_parent_class_name = str(module.__class__.__bases__[0]).split("'")[1] # format the class header class_header_code = class_header.format(module_class_name, module_parent_class_name) forward_code = inspect.getsource(module.forward) init_code = inspect.getsource(module.__init__) functions_to_save = ["__init__", "forward"] if module_parent_class_name.split(".")[-1] in ["LightningModule", "ModleeModel"]: functions_to_save += [ "training_step", "validation_step", "configure_optimizers", ] function_code = {} for function_to_save in functions_to_save: _function = getattr(module, function_to_save, None) if _function is None: functions_to_save.remove(function_to_save) continue function_code.update({function_to_save: inspect.getsource(_function)}) if code_text == "": class_header_code = class_header_code.replace( module_class_name, modlee_model_name ) function_code["__init__"] = function_code["__init__"].replace( module_class_name, modlee_model_name ) code_text = "\n".join( [code_text, class_header_code] + list(function_code.values()) + ["\n"] ) child_module_list = list(module.__dict__["_modules"].values()) custom_child_module_list = [] for child_module in child_module_list: child_module_name = str(child_module).split("(")[0] if ( child_module_name not in OPS_MERGED and child_module_name not in custom_history ): custom_child_module_list.append(child_module) custom_history.add(child_module_name) elif child_module_name == "Sequential": seq_custom_modules_list, custom_history = exhaust_sequence_branch( child_module, custom_history ) custom_child_module_list = ( custom_child_module_list + seq_custom_modules_list ) return code_text, custom_child_module_list, custom_history
[docs] def get_code_text_for_model( model: nn.modules.module.Module | pl.core.module.LightningModule, include_header=False, ): """ Get the code text for a model. :param model: The model for which to get the code text. :return: The code text for the model. """ code_text = "" custom_history = set() module_queue = [model] while len(module_queue) != 0: code_text, custom_child_module_list, custom_history = get_code_text( code_text, module_queue[-1], custom_history ) module_queue.pop() module_queue = module_queue + custom_child_module_list if include_header: code_text = modlee_required_packages + code_text return code_text
[docs] def save_code_text_for_model(code_text: str, include_header: bool = False): """ Save the code text. :param code_text: The code text to save. :param include_header: Whether to include the header of modlee imports in the text, defaults to False. """ if include_header: code_text = modlee_required_packages + code_text file_path = "modlee_model.py" # Specify the file path # Open the file in write mode and write the text with open(file_path, "w") as file: file.write(code_text)