modlee.utils module

Utility functions.

class modlee.utils.DummyDataset(num_samples=100, input_channels=10, sequence_length=20)[source]

Bases: Dataset

class modlee.utils.TimeseriesDataset(data, target, input_seq: int, output_seq: int, time_column: str, encoder_column: list)[source]

Bases: Dataset

Class to handle data loading of the time series dataset.

get_dataset()[source]
to_dataloader(batch_size: int = 32, shuffle: bool = False)[source]
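
The input_seq and output_seq arguments suggest that the series is cut into windows of input steps followed by output steps. The following is a minimal sketch of that windowing idea in plain Python, not the library's implementation; the helper name sliding_windows is hypothetical, and the handling of time_column and encoder_column is left out because it is not documented here.

```python
def sliding_windows(series, input_seq, output_seq):
    """Split a sequence into (input window, output window) pairs."""
    windows = []
    # Index of the last position where a full input and output window still fit.
    last_start = len(series) - input_seq - output_seq
    for start in range(last_start + 1):
        split = start + input_seq
        x = series[start:split]
        y = series[split:split + output_seq]
        windows.append((x, y))
    return windows


pairs = sliding_windows([1, 2, 3, 4, 5, 6], input_seq=3, output_seq=2)
for x, y in pairs:
    print(x, "->", y)
# [1, 2, 3] -> [4, 5]
# [2, 3, 4] -> [5, 6]
```

A series shorter than input_seq + output_seq yields no windows in this sketch.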
modlee.utils.apply_discretize_to_summary(text, info)[source]

Discretize a summary.

Parameters:
  • text – The text to discretize.

  • info – An object that contains different separators.

Returns:

The discretized summary.

modlee.utils.class_from_modality_task(modality, task, _class, *args, **kwargs)[source]

Return a Model or Recommender object, as selected by _class, for the given modality and task. Currently supports:

  • image: classification, segmentation

  • text: classification

Parameters:
  • _class – The class to return, either “Model” or “Recommender”

  • modality – The modality as a string, e.g. “image”, “text”.

  • task – The task as a string, e.g. “classification”, “segmentation”.

Returns:

The requested Model or Recommender object, if it exists.
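
One way to picture this lookup is a table keyed by modality, task, and class kind. The sketch below is an illustration under stated assumptions, not modlee's code: the placeholder classes, the REGISTRY table, and the function name build_from_modality_task are all hypothetical, and returning None for an unsupported combination is an assumption (the library may raise instead).

```python
class _Placeholder:
    """Hypothetical stand-in for a real model or recommender class."""

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


class ImageClassificationModel(_Placeholder):
    pass


class ImageSegmentationModel(_Placeholder):
    pass


class TextClassificationModel(_Placeholder):
    pass


# Hypothetical lookup table covering the supported combinations listed above.
REGISTRY = {
    ("image", "classification", "Model"): ImageClassificationModel,
    ("image", "segmentation", "Model"): ImageSegmentationModel,
    ("text", "classification", "Model"): TextClassificationModel,
}


def build_from_modality_task(modality, task, _class, *args, **kwargs):
    """Instantiate the registered class, or return None if none is registered."""
    cls = REGISTRY.get((modality.lower(), task.lower(), _class))
    if cls is None:
        return None
    return cls(*args, **kwargs)


model = build_from_modality_task("image", "classification", "Model", lr=0.01)
print(type(model).__name__)  # ImageClassificationModel
print(build_from_modality_task("audio", "classification", "Model"))  # None
```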

modlee.utils.closest_power_of_2(x)[source]

Round a number down to the nearest power of 2, i.e. y = 2**floor(log_2(x)).

Parameters:

x – The number.

Returns:

The largest power of 2 less than or equal to the number.
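
The formula y = 2**floor(log_2(x)) can be checked with a few values. This is a minimal sketch of the formula only, written independently of modlee; the name power_of_2_floor is hypothetical, and raising on non-positive input is an assumption, since the documented behavior for such input is not stated.

```python
import math


def power_of_2_floor(x):
    """Return 2**floor(log2(x)) for a positive number x."""
    if x <= 0:
        raise ValueError("x must be positive")
    return 2 ** math.floor(math.log2(x))


print(power_of_2_floor(100))  # 64
print(power_of_2_floor(64))   # 64
print(power_of_2_floor(0.3))  # 0.25
```

Because of the floor, 100 maps to 64 rather than 128, even though 128 is numerically nearer.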

modlee.utils.convert_to_scientific(x)[source]

Convert a number to scientific notation.

Parameters:

x – The number to convert.

Returns:

The number in scientific notation as a string.
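
For reference, Python's format specification can produce such a string directly. The sketch below assumes two digits after the decimal point; the precision used by convert_to_scientific is not documented here, and the name to_scientific is hypothetical.

```python
def to_scientific(x, digits=2):
    """Format a number in scientific notation with a fixed number of decimals."""
    return f"{x:.{digits}e}"


print(to_scientific(12345.678))  # 1.23e+04
print(to_scientific(0.00042))    # 4.20e-04
```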

modlee.utils.discretize(n: list[float, int]) → list[float, int][source]

Discretize a list of inputs.

Parameters:

n – The list of inputs to discretize.

Returns:

The list of discretized inputs.

modlee.utils.get_dataloader(dataset, batch_size=16, shuffle=True, *args, **kwargs)[source]
modlee.utils.get_fashion_mnist(batch_size=64, num_output_channels=1)[source]

Get the Fashion MNIST dataset from torchvision.

Parameters:
  • batch_size – The batch size, defaults to 64.

  • num_output_channels – Passed to torchvision.transforms.Grayscale. 1 = grayscale, 3 = RGB. Defaults to 1.

Returns:

A tuple of train and test dataloaders.

modlee.utils.get_imagenette_dataloader()[source]

Get a small validation dataloader for Imagenette (https://pytorch.org/vision/stable/generated/torchvision.datasets.Imagenette.html#torchvision.datasets.Imagenette).

modlee.utils.get_modality_task(obj)[source]

Get the modality and task of a given object, e.g. “image” and “classification” from an ImageClassificationModleeModel.

Parameters:

obj – The item to parse.
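
The example above, where an ImageClassificationModleeModel yields “image” and “classification”, can be reproduced by splitting a CamelCase class name. This sketch assumes the modality and task are the first two words of the name; the library may derive them differently, and split_modality_task is a hypothetical name.

```python
import re


def split_modality_task(class_name):
    """Return the first two CamelCase words of a class name, lowercased."""
    words = re.findall(r"[A-Z][a-z0-9]*", class_name)
    if len(words) < 2:
        raise ValueError(f"cannot parse modality and task from {class_name!r}")
    return words[0].lower(), words[1].lower()


print(split_modality_task("ImageClassificationModleeModel"))
# ('image', 'classification')
```

For an object rather than a string, type(obj).__name__ gives the class name to pass in.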

modlee.utils.get_model_size(model, as_MB=True)[source]

Get the size of a model, as estimated from the number and size of its parameters.

Parameters:
  • model – The model for which to get the size.

  • as_MB – Whether to return the size in MB, defaults to True.

Returns:

The model size.
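
An estimate of this kind sums, over all parameters, the number of elements times the bytes per element. The sketch below works on parameter shapes in plain Python so that it runs without a deep learning framework. It assumes 4 bytes per element (32-bit floats) and 1 MB = 1024**2 bytes; the library's conventions may differ, and estimate_size is a hypothetical name.

```python
import math


def estimate_size(param_shapes, bytes_per_element=4, as_mb=True):
    """Estimate model size from parameter shapes."""
    total_bytes = sum(math.prod(shape) * bytes_per_element for shape in param_shapes)
    # Convert to MB only when requested, mirroring the as_MB flag above.
    return total_bytes / 1024**2 if as_mb else total_bytes


# A single linear layer: weight of shape (256, 128) and bias of shape (256,).
shapes = [(256, 128), (256,)]
print(estimate_size(shapes, as_mb=False))  # 132096
print(estimate_size([(1024, 256)]))        # 1.0
```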

class modlee.utils.image_loaders[source]

Bases: object

ctr = 4
get_cifar10_dataloader(*args, **kwargs)
get_dtd_dataloader(*args, **kwargs)
get_eurosat_dataloader(*args, **kwargs)
get_fashionmnist_dataloader(*args, **kwargs)
image_module = 'FashionMNIST'
image_modules = {'CIFAR10': {'download': True, 'train': False}, 'DTD': {'download': True, 'split': 'test'}, 'EuroSAT': {'download': True}, 'FGVCAircraft': {'download': True, 'split': 'test'}, 'FashionMNIST': {'download': True, 'train': False}, 'Flowers102': {'download': True, 'split': 'test'}, 'GTSRB': {'download': True, 'split': 'test'}, 'Imagenette': {'size': 'full', 'split': 'val'}, 'KMNIST': {'download': True, 'train': False}, 'MNIST': {'download': True, 'train': False}, 'Omniglot': {'download': True}, 'OxfordIIITPet': {'download': True, 'split': 'test'}, 'Places365': {'small': True, 'split': 'val'}, 'QMNIST': {'download': True, 'what': 'test10k'}, 'RenderedSST2': {'download': True, 'split': 'test'}, 'SEMEION': {'download': True}, 'STL10': {'download': True, 'split': 'test'}, 'SUN397': {'download': True}, 'SVHN': {'download': True, 'split': 'test'}, 'USPS': {'download': True, 'train': False}}
kwargs = {'download': True, 'train': False}
modlee.utils.is_cacheable(x)[source]

Check if an object is cacheable / serializable.

Parameters:

x – The object to check for cacheability, typically a dictionary.

Returns:

A boolean of whether the object is cacheable or not.
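
A common way to test serializability is to attempt the serialization and catch the failure. The sketch below uses JSON as the target format, which is an assumption; is_cacheable may rely on a different serializer, and looks_serializable is a hypothetical name.

```python
import json


def looks_serializable(x):
    """Return True if x can be dumped to JSON, False otherwise."""
    try:
        json.dumps(x)
    except (TypeError, ValueError):
        # TypeError: unsupported type; ValueError: e.g. circular reference.
        return False
    return True


print(looks_serializable({"lr": 0.01, "layers": [64, 32]}))  # True
print(looks_serializable({"callback": lambda: None}))        # False
```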

modlee.utils.last_run_path(*args, **kwargs)[source]

Return the path to the last (most recent) run.

Returns:

The path to the last run.

modlee.utils.quantize(x)[source]

Quantize an object.

Parameters:

x – The object to quantize.

Returns:

The object, quantized.

modlee.utils.quantize_dict(base_dict, quantize_fn=<function quantize>)[source]

Quantize a dictionary.

Parameters:
  • base_dict – The dictionary to quantize.

  • quantize_fn – The function to use for quantization, defaults to quantize.

Returns:

The quantized dictionary.
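
The quantize_fn parameter implies that a function is applied across the dictionary's values. The sketch below shows that pattern with a stand-in function that rounds numbers to one decimal place. The stand-in is not modlee's quantize, whose behavior is not described here; recursing into nested dictionaries is also an assumption, and the names map_dict_values and round_to_tenth are hypothetical.

```python
def round_to_tenth(value):
    """Stand-in quantizer: round numbers to one decimal, pass others through."""
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return round(value, 1)
    return value


def map_dict_values(base_dict, fn=round_to_tenth):
    """Apply fn to every leaf value, descending into nested dictionaries."""
    result = {}
    for key, value in base_dict.items():
        if isinstance(value, dict):
            result[key] = map_dict_values(value, fn)
        else:
            result[key] = fn(value)
    return result


config = {"lr": 0.123, "optimizer": {"momentum": 0.91, "name": "sgd"}}
print(map_dict_values(config))
# {'lr': 0.1, 'optimizer': {'momentum': 0.9, 'name': 'sgd'}}
```

The input dictionary is left unchanged; a new dictionary is returned.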

modlee.utils.safe_mkdir(target_path)[source]

Safely make a directory.

Parameters:

target_path – The path to the target directory.
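
In the standard library, creating a directory without failing when it already exists is usually done with os.makedirs and exist_ok=True. The sketch below assumes that this is what "safely" means here; the actual behavior of safe_mkdir may differ, and make_dir_if_missing is a hypothetical name.

```python
import os
import tempfile


def make_dir_if_missing(target_path):
    """Create target_path and any missing parents; do nothing if it exists."""
    os.makedirs(target_path, exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "runs", "latest")
    make_dir_if_missing(target)
    make_dir_if_missing(target)  # second call does not raise
    print(os.path.isdir(target))  # True
```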

modlee.utils.save_run(*args, **kwargs)[source]

Save the current run.

Parameters:

modlee_client – The client object that is tracking the current run.

modlee.utils.save_run_as_json(*args, **kwargs)[source]

Save the current run as JSON.

Parameters:

modlee_client – The client object that is tracking the current run.

class modlee.utils.tabular_loaders[source]

Bases: object

class TabularDataset(data, target)[source]

Bases: Dataset

static get_adult_dataloader(batch_size=4, shuffle=True, root=None)[source]
static get_diabetes_dataloader(batch_size=4, shuffle=True, root=None)[source]
static get_housing_dataloader(batch_size=4, shuffle=True, root=None)[source]
class modlee.utils.text_loaders[source]

Bases: object

dataset_len = 71
dataset_module = 'WNLI'
get_cola_dataloader(split='dev')
static get_mnli_dataloader(*args, **kwargs)[source]
get_qnli_dataloader(split='dev')
get_rte_dataloader(split='dev')
get_sst2_dataloader(split='dev')
get_stsb_dataloader(split='dev')
get_wnli_dataloader(split='dev')
text_modules_lengths = {'CoLA': 527, 'QNLI': 5463, 'RTE': 277, 'SST2': 872, 'STSB': 1500, 'WNLI': 71}
class modlee.utils.timeseries_loader[source]

Bases: object

static get_timeseries_dataloader(data, target, input_seq: int, output_seq: int, time_column: str, encoder_column: list)[source]
class modlee.utils.timeseries_loaders[source]

Bases: object

static get_finance_dataloader(root=None)[source]
modlee.utils.typewriter_print(text, sleep_time=0.001, max_line_length=150, max_lines=20)[source]

Print a string letter-by-letter, like a typewriter.

Parameters:
  • text – The text to print.

  • sleep_time – The time to sleep between letters, defaults to 0.001.

  • max_line_length – The maximum line length to truncate to, defaults to 150.

  • max_lines – The maximum number of lines to print, defaults to 20.
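
The interaction of the four parameters can be sketched as follows. This is an illustration, not the library's implementation: the name typewriter and the stream argument are hypothetical (the latter is added so the output can be captured), and truncating silently, without an ellipsis or other marker, is an assumption.

```python
import io
import sys
import time


def typewriter(text, sleep_time=0.001, max_line_length=150, max_lines=20, stream=None):
    """Write text one character at a time, truncating long or excess lines."""
    stream = sys.stdout if stream is None else stream
    for line in text.splitlines()[:max_lines]:
        for char in line[:max_line_length]:
            stream.write(char)
            stream.flush()
            time.sleep(sleep_time)
        stream.write("\n")


# Capture the output to show the effect of the two truncation limits.
buffer = io.StringIO()
typewriter("abcdef\nxyz\nmore", sleep_time=0, max_line_length=3, max_lines=2, stream=buffer)
print(repr(buffer.getvalue()))  # 'abc\nxyz\n'
```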

modlee.utils.uri_to_path(uri)[source]

Convert a URI to a path.

Parameters:

uri – The URI to convert.

Returns:

The converted path.
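
For file URIs, the conversion can be done with urllib.parse from the standard library. The sketch below handles POSIX-style file:// URIs only and decodes percent-escapes; which URI schemes uri_to_path accepts is not documented here, and file_uri_to_path is a hypothetical name.

```python
from urllib.parse import unquote, urlparse


def file_uri_to_path(uri):
    """Extract the percent-decoded path component of a URI."""
    parsed = urlparse(uri)
    return unquote(parsed.path)


print(file_uri_to_path("file:///tmp/runs/0/artifacts"))  # /tmp/runs/0/artifacts
print(file_uri_to_path("file:///tmp/my%20runs/0"))       # /tmp/my runs/0
```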