# Source code for modlee

""" 
Modlee package.
"""
import traceback
import importlib
import glob
from contextlib import contextmanager, redirect_stderr, redirect_stdout

import os
from os import devnull
from os.path import dirname, basename, isfile, join

import pathlib
from pathlib import Path
from urllib.parse import urlparse

from functools import partial
import logging, warnings

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

from .api_config import ModleeAPIConfig

import mlflow
from mlflow import start_run, last_active_run

from .retriever import *
from .utils import save_run, last_run_path, save_run_as_json
from .model_text_converter import get_code_text, get_code_text_for_model
from . import (
    model_text_converter,
    data_metafeatures,
    model,
    recommender,
    config,
)

# NOTE(review): not referenced elsewhere in the visible code; presumably the
# submodules that back the Modlee API.
api_modules = ["model_text_converter"]

# Advertise every sibling ``*.py`` file (except this ``__init__``) as a public
# submodule name in ``__all__``.
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(module_path)[: -len(".py")]
    for module_path in modules
    if not module_path.endswith("__init__.py") and isfile(module_path)
]

# Quiet chatty third-party loggers: only errors get through, and records are
# not propagated up to the root logger.
for _logger in (
    "pytorch_lightning",
    "lightning.pytorch.core",
    "mlflow",
    "torchvision",
    "torch.nn",
):
    pl_logger = logging.getLogger(_logger)
    pl_logger.setLevel(logging.ERROR)
    pl_logger.propagate = False

# Ignore known-noisy dependency warnings. Each entry is a message regex passed
# to ``warnings.filterwarnings``; filters are registered in this order.
for _ignored_message in (
    ".*does not have many workers.*",
    ".*turn shuffling off.*",
    ".*Arguments other than a weight enum or.*",
    ".*The parameter 'pretrained' is deprecated since.*",
    ".*Using a target size.*",
    ".*Implicit dimension choice.*",
    ".*divides the total loss by both.*",
    ".*To copy construct from a tensor, it is recommended.*",
    ".*NLLLoss2d has been deprecated.*",
    ".*The default value of the antialias parameter.*",
    ".*No names were found for specified dynamic axes.*",
    ".*Starting from v1.9.0.*",
    "Input data has range zero. The results may not be accurate.",
    "scipy.stats.shapiro: Input data has range zero.",
):
    warnings.filterwarnings("ignore", _ignored_message)
# Keep the module namespace free of the loop variable.
del _ignored_message

@contextmanager
def suppress_stdout_stderr():
    """Silence everything written to stdout and stderr inside the ``with`` body.

    Both streams are redirected to one handle opened on ``os.devnull``. The
    handle is closed, and the original streams restored, when the context
    exits.

    :yield: A ``(stderr_target, stdout_target)`` tuple; both elements are the
        same devnull file object.
    """
    with open(devnull, "w") as sink:
        with redirect_stderr(sink) as err_target:
            with redirect_stdout(sink) as out_target:
                yield err_target, out_target
def init(run_path=None, api_key=None):
    """
    Initialize package. Typically called at the beginning of a machine
    learning pipeline.

    Sets the run path where experiment assets will be stored, then calls
    :func:`auth` with the given API key.

    :param run_path: The path to the current run. If ``None`` or a path that
        does not exist, the current working directory is used instead.
    :param api_key: The user's API key, forwarded unchanged to :func:`auth`.
    """
    if run_path is not None and not os.path.exists(run_path):
        # Fall back rather than fail, but say so: a mistyped run_path would
        # otherwise silently redirect experiment logs to the working directory.
        logging.warning(
            "Run path %s does not exist, using the current working directory instead",
            run_path,
        )
        run_path = None
    if run_path is None:
        run_path = os.getcwd()
    set_run_path(run_path)
    auth(api_key)
def auth(api_key=None):
    """
    Fetch API functionality if the API key is valid.

    :param api_key: The user's API key, if it is not available as an
        environment variable. When a truthy key is given it is stored with
        ``ModleeAPIConfig.set_api_key``; otherwise
        ``ModleeAPIConfig.ensure_api_key`` is called (presumably it locates an
        already-configured key; confirm against ``api_config``).
    """
    # Named api_config so the module-level ``config`` submodule isn't shadowed.
    api_config = ModleeAPIConfig()
    if not api_key:
        api_config.ensure_api_key()
        return
    api_config.set_api_key(api_key)
# logging.warning("API key not provided. Functionality will be limited.")
def set_run_path(run_path):
    """
    Set the path to the current run. This is where the experiment assets will be saved.

    The path is made absolute and normalized. An ``mlruns`` component is
    appended unless the final path component already contains ``"mlruns"``.

    :param run_path: The path to the current run (``str`` or path-like).
    :raises FileNotFoundError: If the parent directory of the resulting
        ``mlruns`` path does not exist; parent directories are never created.
    :return: The tracking URI for the experiment.
    """
    # abspath also normalizes the path (collapses "..", duplicate separators
    # and any trailing separator). Without this, an absolute path such as
    # "/runs/mlruns/" has an empty last component and would get a second
    # "mlruns" appended. It also turns a pathlib.Path into a str.
    run_path = os.path.abspath(run_path)
    logging.debug(f"Setting run logs to abspath {run_path}")

    # basename splits on the platform separator; split("/") does not split
    # Windows paths at all.
    if "mlruns" not in os.path.basename(run_path):
        run_path = os.path.join(run_path, "mlruns")

    run_dir_base = os.path.dirname(run_path)
    if not os.path.exists(run_dir_base):
        raise FileNotFoundError(
            f"No base directory {run_dir_base}, cannot set tracking URI"
        )

    # as_uri() requires an absolute path, which abspath above guarantees.
    tracking_uri = pathlib.Path(run_path).as_uri()
    mlflow.set_tracking_uri(tracking_uri)
    return tracking_uri
def get_run_path():
    """
    Get the path to the current run.

    :return: The path component of the mlflow tracking URI. For ``file:``
        URIs (as produced by :func:`set_run_path`), percent-escapes such as
        ``%20`` are decoded back into a local filesystem path.
    """
    from urllib.request import url2pathname

    parsed_uri = urlparse(mlflow.get_tracking_uri())
    if parsed_uri.scheme == "file":
        # Path.as_uri() percent-encodes characters like spaces (and prefixes
        # Windows drives with "/"); convert back to a usable filesystem path.
        return url2pathname(parsed_uri.path)
    return parsed_uri.path
def get_api_key():
    """
    Get the current API key.

    :return: The value of the module-level ``API_KEY`` global, or ``None`` if
        it has not been assigned.
    """
    # Nothing in the visible module assigns API_KEY, so a bare
    # ``return API_KEY`` raised NameError; look it up without failing instead.
    return globals().get("API_KEY")