modlee.utils module
Utility functions.
- modlee.utils.apply_discretize_to_summary(text, info)
Discretize a summary.
- Parameters:
text – The text to discretize.
info – An object that contains different separators.
- Returns:
The discretized summary.
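- modlee.utils.closest_power_of_2(x)
Round a number to a power of 2, i.e. y = 2**floor(log_2(x)). Note that this formula gives the largest power of 2 not exceeding x, which is not always the closest one: for x = 7 it gives 4, whereas 8 is closer.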
- Parameters:
x – The number.
- Returns:
The power of 2 closest to the number.
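Rounding down with 2**floor(log_2(x)) and rounding to the nearest power of 2 can give different results. The sketch below contrasts the two. It is not modlee's implementation: the names `floor_power_of_2` and `nearest_power_of_2` are hypothetical, inputs are assumed positive, and exact ties (such as x = 6) are assumed to round up.

```python
import math


def floor_power_of_2(x):
    """Largest power of 2 not exceeding x, i.e. 2**floor(log2(x)). Assumes x > 0."""
    if x <= 0:
        raise ValueError("x must be positive")
    return 2 ** math.floor(math.log2(x))


def nearest_power_of_2(x):
    """Power of 2 nearest to x; ties (e.g. x = 6) round up. Assumes x > 0."""
    lower = floor_power_of_2(x)
    upper = lower * 2
    # Compare the distances to the neighboring powers of 2 below and above x.
    return lower if x - lower < upper - x else upper


print(floor_power_of_2(7), nearest_power_of_2(7))  # 4 8
```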
- modlee.utils.convert_to_scientific(x)
Convert a number to scientific notation.
- Parameters:
x – The number to convert.
- Returns:
The number in scientific notation as a string.
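A minimal sketch using Python's built-in `e` format specifier. The `precision` parameter and its default of two digits after the decimal point are assumptions for illustration; the exact format modlee returns is not documented here.

```python
def convert_to_scientific(x, precision=2):
    """Format a number in scientific notation, e.g. 12345 -> '1.23e+04'.

    `precision` (digits after the decimal point) is a hypothetical parameter.
    """
    return f"{x:.{precision}e}"


print(convert_to_scientific(12345))    # 1.23e+04
print(convert_to_scientific(0.00042))  # 4.20e-04
```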
- modlee.utils.discretize(n: list[float, int]) → list[float, int]
Discretize a list of inputs.
- Parameters:
n – The list of inputs to discretize.
- Returns:
The list of discretized inputs.
- modlee.utils.get_fashion_mnist(batch_size=64, num_output_channels=1)
Get the Fashion MNIST dataset from torchvision.
- Parameters:
batch_size – The batch size, defaults to 64.
num_output_channels – Passed to torchvision.transforms.Grayscale. 1 = grayscale, 3 = RGB. Defaults to 1.
- Returns:
A tuple of train and test dataloaders.
- modlee.utils.get_imagenette_dataloader()
Get a small validation dataloader for Imagenette (https://pytorch.org/vision/stable/generated/torchvision.datasets.Imagenette.html#torchvision.datasets.Imagenette).
- modlee.utils.get_model_size(model, as_MB=True)
Get the size of a model, as estimated from the number and size of its parameters.
- Parameters:
model – The model for which to get the size.
as_MB – Whether to return the size in MB, defaults to True.
- Returns:
The model size.
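Estimating size "from the number and size of its parameters" amounts to summing, over all parameters, the element count times the bytes per element. The sketch below assumes a PyTorch-style model whose `parameters()` yields tensors with `numel()` and `element_size()` (both exist on `torch.Tensor`), and assumes 1 MB = 2**20 bytes, which may differ from modlee's convention.

```python
def get_model_size(model, as_MB=True):
    """Estimate model size from its parameters (sketch, not modlee's code).

    Assumes `model.parameters()` yields tensor-like objects exposing
    `numel()` (number of elements) and `element_size()` (bytes per element),
    as torch.Tensor does.
    """
    n_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    # Assumed convention: 1 MB = 1024**2 bytes.
    return n_bytes / 1024**2 if as_MB else n_bytes
```

With PyTorch installed, `get_model_size(torch.nn.Linear(1024, 1024))` counts 1024 * 1024 weights plus 1024 biases at 4 bytes each (float32), which comes to roughly 4.0 MB under this convention.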
- class modlee.utils.image_loaders
Bases: object
- get_cifar10_dataloader(*args, **kwargs)
- get_dtd_dataloader(*args, **kwargs)
- get_eurosat_dataloader(*args, **kwargs)
- get_fashionmnist_dataloader(*args, **kwargs)
- get_fgvcaircraft_dataloader(*args, **kwargs)
- get_flowers102_dataloader(*args, **kwargs)
- get_gtsrb_dataloader(*args, **kwargs)
- get_imagenette_dataloader(*args, **kwargs)
- get_kmnist_dataloader(*args, **kwargs)
- get_mnist_dataloader(*args, **kwargs)
- get_omniglot_dataloader(*args, **kwargs)
- get_oxfordiiitpet_dataloader(*args, **kwargs)
- get_places365_dataloader(*args, **kwargs)
- get_qmnist_dataloader(*args, **kwargs)
- get_renderedsst2_dataloader(*args, **kwargs)
- get_semeion_dataloader(*args, **kwargs)
- get_stl10_dataloader(*args, **kwargs)
- get_sun397_dataloader(*args, **kwargs)
- get_svhn_dataloader(*args, **kwargs)
- get_usps_dataloader(*args, **kwargs)
- image_module = 'USPS'
- image_modules = {'CIFAR10': {'download': True, 'train': False}, 'DTD': {'download': True, 'split': 'test'}, 'EuroSAT': {'download': True}, 'FGVCAircraft': {'download': True, 'split': 'test'}, 'FashionMNIST': {'download': True, 'train': False}, 'Flowers102': {'download': True, 'split': 'test'}, 'GTSRB': {'download': True, 'split': 'test'}, 'Imagenette': {'size': 'full', 'split': 'val'}, 'KMNIST': {'download': True, 'train': False}, 'MNIST': {'download': True, 'train': False}, 'Omniglot': {'download': True}, 'OxfordIIITPet': {'download': True, 'split': 'test'}, 'Places365': {'small': True, 'split': 'val'}, 'QMNIST': {'download': True, 'what': 'test10k'}, 'RenderedSST2': {'download': True, 'split': 'test'}, 'SEMEION': {'download': True}, 'STL10': {'download': True, 'split': 'test'}, 'SUN397': {'download': True}, 'SVHN': {'download': True, 'split': 'test'}, 'USPS': {'download': True, 'train': False}}
- kwargs = {'download': True, 'train': False}
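The class attributes `image_module = 'USPS'` and `kwargs` match the last entry of `image_modules`. That is what Python leaves behind when a `for` loop runs directly in a class body: the loop variables become class attributes. The sketch below assumes the `get_*_dataloader` methods are generated by such a loop. The class name `ImageLoadersSketch`, the helper `_make_loader`, and the returned `(name, kwargs)` tuple are hypothetical stand-ins; a real loader would construct the torchvision dataset and wrap it in a `DataLoader`.

```python
def _make_loader(name, dataset_kwargs):
    # Factory so each generated method captures its own name and kwargs
    # (avoids the late-binding pitfall of closures created inside a loop).
    def loader(*args, **kwargs):
        # Hypothetical stand-in: return what a real loader would pass to
        # torchvision.datasets.<name>(...) instead of building a DataLoader.
        return name, {**dataset_kwargs, **kwargs}
    return staticmethod(loader)


class ImageLoadersSketch:
    image_modules = {
        "DTD": {"download": True, "split": "test"},
        "USPS": {"download": True, "train": False},
    }
    _generated = {}
    # This loop runs in the class body, so `image_module` and `kwargs`
    # remain afterwards as class attributes holding the last entry ('USPS').
    for image_module, kwargs in image_modules.items():
        _generated[f"get_{image_module.lower()}_dataloader"] = _make_loader(image_module, kwargs)


# Attach the generated methods once the class object exists.
for _method_name, _method in ImageLoadersSketch._generated.items():
    setattr(ImageLoadersSketch, _method_name, _method)

print(ImageLoadersSketch.image_module)  # USPS
```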
- modlee.utils.is_cacheable(x)
Check if an object is cacheable / serializable.
- Parameters:
x – The object to check, typically a dictionary.
- Returns:
True if the object is cacheable, False otherwise.
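One plausible reading of "cacheable / serializable" is "can be written as JSON". The sketch below assumes that definition; modlee's actual check may differ (it could, for example, test pickling instead).

```python
import json


def is_cacheable(x):
    """Return True if `x` can be serialized to JSON (assumed meaning of cacheable)."""
    try:
        json.dumps(x)
    except (TypeError, ValueError):
        # TypeError: unsupported type (e.g. a function or arbitrary object);
        # ValueError: e.g. a circular reference.
        return False
    return True


print(is_cacheable({"lr": 1e-3, "layers": [64, 32]}))  # True
print(is_cacheable({"fn": print}))                      # False
```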
- modlee.utils.last_run_path(*args, **kwargs)
Return the path to the most recent run.
- Returns:
The path to the last run.
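Finding the most recent run can be as simple as picking the newest subdirectory by modification time. In this sketch, `runs_dir` (a directory holding one subdirectory per run) is a hypothetical parameter; the documented function takes `*args, **kwargs`, and its run layout is not described here.

```python
import os


def last_run_path(runs_dir):
    """Return the most recently modified run directory in `runs_dir`, or None."""
    run_dirs = [entry.path for entry in os.scandir(runs_dir) if entry.is_dir()]
    if not run_dirs:
        return None
    # The newest run is the directory with the largest modification time.
    return max(run_dirs, key=os.path.getmtime)
```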
- modlee.utils.quantize(x)
Quantize an object.
- Parameters:
x – The object to quantize.
- Returns:
The object, quantized.
- modlee.utils.quantize_dict(base_dict, quantize_fn=<function quantize>)
Quantize a dictionary.
- Parameters:
base_dict – The dictionary to quantize.
quantize_fn – The function to use for quantization, defaults to quantize.
- Returns:
The quantized dictionary.
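How `quantize_dict` composes with a per-value `quantize_fn` can be sketched as a recursive walk over the dictionary. Assumptions: nested dictionaries are traversed, only non-boolean numbers are passed to `quantize_fn`, and all other values are kept unchanged. The placeholder `round_2_sig` (rounding to two significant digits) is hypothetical and stands in for modlee's `quantize`, whose rule is not documented here.

```python
import math


def round_2_sig(x):
    """Hypothetical quantizer: round a number to two significant digits."""
    if x == 0:
        return 0
    return round(x, 1 - math.floor(math.log10(abs(x))))


def quantize_dict(base_dict, quantize_fn=round_2_sig):
    """Return a copy of `base_dict` with every numeric value quantized."""
    out = {}
    for key, value in base_dict.items():
        if isinstance(value, dict):
            out[key] = quantize_dict(value, quantize_fn)  # recurse into nested dicts
        elif isinstance(value, (int, float)) and not isinstance(value, bool):
            out[key] = quantize_fn(value)
        else:
            out[key] = value  # strings, lists, booleans, etc. are left as-is
    return out


print(quantize_dict({"lr": 0.00123, "hidden": {"width": 517}, "act": "relu"}))
# {'lr': 0.0012, 'hidden': {'width': 520}, 'act': 'relu'}
```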
- modlee.utils.safe_mkdir(target_path)
Safely make a directory.
- Parameters:
target_path – The path to the target directory.
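A common interpretation of "safely" is: create any missing parent directories and do not fail if the directory already exists. A sketch under that assumption, using only the standard library:

```python
import os


def safe_mkdir(target_path):
    """Create `target_path` (and any missing parents) unless it already exists."""
    # exist_ok=True turns a repeated call on an existing directory into a no-op
    # instead of raising FileExistsError.
    os.makedirs(target_path, exist_ok=True)
```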
- modlee.utils.save_run(*args, **kwargs)
Save the current run.
- Parameters:
modlee_client – The client object that is tracking the current run.
- modlee.utils.save_run_as_json(*args, **kwargs)
Save the current run as JSON.
- Parameters:
modlee_client – The client object that is tracking the current run.
- class modlee.utils.text_loaders
Bases: object
- dataset_len = 71
- dataset_module = 'WNLI'
- get_cola_dataloader(split='dev')
- get_qnli_dataloader(split='dev')
- get_rte_dataloader(split='dev')
- get_sst2_dataloader(split='dev')
- get_stsb_dataloader(split='dev')
- get_wnli_dataloader(split='dev')
- text_modules_lengths = {'CoLA': 527, 'QNLI': 5463, 'RTE': 277, 'SST2': 872, 'STSB': 1500, 'WNLI': 71}
- modlee.utils.typewriter_print(text, sleep_time=0.001, max_line_length=150, max_lines=20)
Print a string letter-by-letter, like a typewriter.
- Parameters:
text – The text to print.
sleep_time – The time to sleep between letters, defaults to 0.001.
max_line_length – The maximum line length to truncate to, defaults to 150.
max_lines – The maximum number of lines to print, defaults to 20.
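A sketch of the behavior described above, not modlee's code. The truncation rules are assumptions: each line is cut to `max_line_length` characters, and only the first `max_lines` lines are printed.

```python
import sys
import time


def typewriter_print(text, sleep_time=0.001, max_line_length=150, max_lines=20):
    """Print `text` one character at a time, like a typewriter."""
    # Assumed truncation: keep only the first `max_lines` lines,
    # and cut each line to at most `max_line_length` characters.
    for line in text.splitlines()[:max_lines]:
        for char in line[:max_line_length]:
            sys.stdout.write(char)
            sys.stdout.flush()  # show each character immediately
            time.sleep(sleep_time)
        sys.stdout.write("\n")
```

Passing `sleep_time=0` prints without delay, which is convenient in tests.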