Source code for mlhub.utils

# Utilities
"""
    The ``utils`` module contains common utilities for MLHub.
    
    Configurations
    ---------------
    
    Download Directory
    ^^^^^^^^^^^^^^^^^^^
    
    This is the directory where all the files (datasets, checkpoints,
    etc.) are downloaded. It's used for dataloaders, saving 
    checkpoints during training, loading checkpoints for models, etc.
    
    It is ``/tmp`` by default. However, it can be set through the 
    following means
    
    -   By setting the environment variable ``MLHUB_DOWNLOAD_DIR`` to
        the desired directory.
    -   Using the :py:func:`mlhub.utils.set_download_dir` function.
    
    You can get the current download directory using the
    :py:func:`mlhub.utils.get_download_dir` function. The value of
    this variable is referred as ``DOWNLOAD_DIR`` in the docs.
    
    .. autofunction:: mlhub.utils.get_download_dir
    .. autofunction:: mlhub.utils.set_download_dir
    
    File Management
    ----------------
    
    .. autofunction:: mlhub.utils.ex
    .. autofunction:: mlhub.utils.download_and_extract_archive
    .. autofunction:: mlhub.utils.check_md5
    
    Images
    -------
    
    .. autofunction:: mlhub.utils.norm_img
    
    Miscellaneous
    --------------
    
    .. autofunction:: mlhub.utils.random_alnum_str
    
"""

# %%
import os
import sys
import torch
import string
import random
import hashlib
import numpy as np
from pathlib import Path
from datetime import datetime
from typing import Optional, Union
from torchvision.datasets.utils import \
        extract_archive as _extract_archive, \
        download_url as _download_url, \
        download_and_extract_archive as _download_and_extract_archive


# %%
# ----------- File Management -----------
# Expand the path fully
[docs] def ex(x: str) -> str: r""" Expand a path fully (to realpath). Also expands ``~`` (tilde) to home. :param x: A path :return: A fully resolved (absolute) path """ return os.path.realpath(os.path.expanduser(x))
# Download and extract an file from a URL
[docs] def download_and_extract_archive(url: str, download_root: Optional[str] = None, extract_root: Optional[str] = None, filename: Optional[str] = None, md5: Optional[str] = None, remove_finished: bool = False) -> None: r""" A wrapper to PyTorch's download and extract function with documentation. If the file is already downloaded, then the download is not done again (after an MD5 integrity check). However, the downloaded file is always extracted (files are overwritten if they already exist). :param url: The download URL (to obtain the file from) :param download_root: Root folder where downloaded items must be stored. If None, then it is inferred from the function :py:func:`mlhub.utils.get_download_dir`. :param extract_root: Root folder where the downloaded items are extracted. If None, then it is the same as the ``download_root`` :param filename: The filename to use for saving. It is the basename of the URL if None. :param md5: The checksum to check the downloaded file against (before extracting anything). No check is done if ``None``. :param remove_finished: If True, remove the downloaded file after extracting it. """ if download_root is None: download_root = get_download_dir() _download_and_extract_archive(url, download_root, extract_root, filename, md5, remove_finished)
# Cached download and extract (wrapper) def cached_download_and_extract_archive(url: str, download_root: Optional[str] = None, extract_root: Optional[str] = None, flag_root: Optional[str] = None, filename: Optional[str] = None, md5: Optional[str] = None, remove_finished: bool = False) -> None: r""" 1. Check if file to download already exists 2. If file exists, then check if already unzipped. If it doesn't exist then download it. 3. If file not unzipped, then unzip (extract) it. :param download_root: Root folder where downloaded items must be stored. If None, then it is inferred from the function :py:func:`mlhub.utils.get_download_dir`. """ # Defaults and preprocess if download_root is None: download_root = get_download_dir() if filename is None: filename = os.path.basename(url) if extract_root is None: extract_root = download_root if flag_root is None: flag_root = extract_root download_root = ex(download_root) download_file = True extract_file = True file_path = os.path.join(download_root, filename) flag_path = os.path.join(flag_root, f"{filename}.flag") # Check if the proper file already exists and download if doesn't if os.path.isfile(file_path): print(f"File already exists: {file_path}") if md5 is not None: if not check_md5(file_path, md5): print("MD5 doesn't match, file will be downloaded") run_command(f"mv {file_path} {file_path}.backup") else: print("MD5 matches, not downloading it") download_file = False if download_file: print(f"Downloading {url} to {file_path}") _download_url(url, download_root, filename, md5) else: print(f"Skipping download of {url}") # Check if we have already downloaded and extracted the file if os.path.isfile(flag_path): with open(flag_path, "r") as f: ts = f.read() print(f"File already extracted at timestamp {ts}") extract_file = False # Extract the file if extract_file: print(f"Extracting {file_path} to {extract_root}") _extract_archive(file_path, extract_root, remove_finished) # Create a flag file with timestamp ts = datetime.now().strftime(f"%Y-%m-%dT%H-%M-%S") with open(flag_path, "w") as f: f.write(ts) print(f"Data extraction completed at {ts}") else: print(f"Skipping extraction of {file_path}") # Run a system command def run_command(cmd: str): print(f">>> {cmd}") os.system(cmd) # Check the MD5 checksum of a file
[docs] def check_md5(file: str, true_md5: Optional[str] = None) \ -> Union[str, bool]: """ Returns the MD5 checksum of the given file :param file: The file to check (should exist) :param true_md5: The true MD5 checksum of the file. If None, then the checksum is not checked and the function returns the MD5 checksum of the file. If an expected (true) hash is passed then the function returns ``True`` if the MD5 matches ( ``False`` otherwise) :return: The MD5 checksum of the file if ``true_md5`` is None. Else a bool comparing ``true_md5`` with the MD5 of ``file``. """ if not os.path.isfile(file): raise FileNotFoundError(file) # Get hash of file with open(ex(file), "rb") as f: md5_hash = hashlib.md5(f.read()).hexdigest() if true_md5 is not None: return md5_hash == true_md5 return md5_hash
# %% # ----------- Download directory ----------- _download_dir = os.getenv("MLHUB_DOWNLOAD_DIR", "/tmp") # Get download directory
[docs] def get_download_dir() -> str: r""" Get the download directory (as absolute/resolved path). Only use :py:func:`mlhub.utils.set_download_dir` to set the download directory. :return: The fully resolved download directory """ return ex(_download_dir)
# Set download directory
[docs] def set_download_dir(path: str) -> str: r""" Set the download directory. If directory doesn't exist, it is created. Use :py:func:`mlhub.utils.get_download_dir` to get the current download directory. .. note:: By default, the download directory is set by the environment variable ``MLHUB_DOWNLOAD_DIR``. If it's not set, then the default is ``/tmp``. :param path: The download directory :return: The download directory """ global _download_dir if not os.path.isdir(path): os.makedirs(path, exist_ok=True) _download_dir = path return _download_dir
# %% # ----------- Images ----------- # Normalize an image T_IMG = Union[torch.Tensor, np.ndarray]
[docs] def norm_img(img: T_IMG, eps: float = 1e-12) -> T_IMG: r""" Normalize an image (uniformly map [min, max]) to range [0, 1]. :param img: The image to normalize. This is not modified. :param eps: A small value to avoid division by zero :return: The normalized image. """ return (img - img.min()) / (img.max() - img.min() + eps)
# %%
[docs] def random_alnum_str(n: int = 4) -> str: """ Generate a random alphanumeric string of length ``n``. Characters could be repeated. :param n: The length of alphanumeric string """ characters = string.ascii_letters + string.digits return "".join(random.choices(characters, k=n))