torch.hub

PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.

Publishing models

PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights) to a GitHub repository by adding a simple hubconf.py file.

hubconf.py can have multiple entrypoints. Each entrypoint is defined as a Python function with the following signature:

def entrypoint_name(pretrained=False, *args, **kwargs):
    ...

How to implement an entrypoint?

Here is a code snippet from the pytorch/vision repository that specifies an entrypoint for the resnet18 model. The full script is available in the pytorch/vision repo.

dependencies = ['torch', 'math']

# model_zoo provides load_url(), used below to fetch remote weights
import torch.utils.model_zoo as model_zoo

def resnet18(pretrained=False, *args, **kwargs):
    """
    Resnet18 model
    pretrained (bool): a recommended kwargs for all entrypoints
    args & kwargs are arguments for the function
    """
    ######## Call the model in the repo ###############
    from torchvision.models.resnet import resnet18 as _resnet18
    model = _resnet18(*args, **kwargs)
    ######## End of call ##############################
    # The following logic is REQUIRED
    if pretrained:
        # For weights saved in local repo
        # model.load_state_dict(<path_to_saved_file>)

        # For weights saved elsewhere
        checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
        model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
    return model

  • The dependencies variable is a list of package names required to run the model.
  • Pretrained weights can either be stored locally in the GitHub repo or loaded with model_zoo.load_url().
  • pretrained controls whether to load the pre-trained weights provided by the repo owners.
  • args and kwargs are passed along to the real callable function.
  • The docstring of the function works as a help message, explaining what the model does and which arguments are allowed.
  • An entrypoint function should ALWAYS return a model (nn.Module).
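
To make the role of the dependencies list concrete, here is a minimal sketch of how a loader could check it before calling an entrypoint. This is an illustration only, not Hub's actual implementation; missing_dependencies is a hypothetical helper, and the second package name in the demo is deliberately made up.

```python
import importlib.util


def missing_dependencies(dependencies):
    """Return the names in `dependencies` that cannot be imported here."""
    # find_spec() returns None when a top-level package is not installed.
    return [name for name in dependencies
            if importlib.util.find_spec(name) is None]


# 'math' ships with Python; the second name is intentionally bogus.
print(missing_dependencies(['math', 'surely_not_an_installed_pkg_0x1']))
```

A loader following this approach would likely raise an error listing the missing packages rather than fail later with an ImportError inside the entrypoint.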

Important Notice

  • Published models must be on a branch or tag. They can’t be referenced by an arbitrary commit.

Loading models from Hub

Users can load pre-trained models using the torch.hub.load() API.

torch.hub.load(github, model, force_reload=False, *args, **kwargs)

Load a model from a github repo, with pretrained weights.

Parameters:
  • github – Required, a string with format “repo_owner/repo_name[:tag_name]” with an optional tag/branch. The default branch is master if not specified. Example: ‘pytorch/vision[:hub]’
  • model – Required, a string of entrypoint name defined in repo’s hubconf.py
  • force_reload – Optional, whether to discard the existing cache and force a fresh download. Default is False.
  • *args – Optional, the corresponding args for callable model.
  • **kwargs – Optional, the corresponding kwargs for callable model.
Returns:

a single model with corresponding pretrained weights.
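
The “repo_owner/repo_name[:tag_name]” format of the github argument can be illustrated with a small parser. This is a sketch under the assumption that the string is split on ':' and '/'; parse_github is a hypothetical helper and not part of the torch.hub API.

```python
def parse_github(github, default_branch='master'):
    """Split 'repo_owner/repo_name[:tag_name]' into its three parts."""
    repo_info, sep, branch = github.partition(':')
    if not sep or not branch:
        # No tag/branch given: fall back to the documented default.
        branch = default_branch
    repo_owner, slash, repo_name = repo_info.partition('/')
    if not slash or not repo_owner or not repo_name or '/' in repo_name:
        raise ValueError(
            "expected 'repo_owner/repo_name[:tag_name]', got %r" % github)
    return repo_owner, repo_name, branch


print(parse_github('pytorch/vision'))      # ('pytorch', 'vision', 'master')
print(parse_github('pytorch/vision:hub'))  # ('pytorch', 'vision', 'hub')
```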

Here’s an example that loads the resnet18 entrypoint from the pytorch/vision repo.

import torch.hub as hub

hub_model = hub.load(
    'pytorch/vision:master', # repo_owner/repo_name:branch
    'resnet18', # entrypoint
    # a third positional value would bind to force_reload, so none is passed;
    # resnet18 takes no positional args
    pretrained=True) # kwargs for callable

Where are my downloaded model & weights saved?

The locations are tried in the following order:

  • hub_dir: user specified path. It can be set in the following ways:
      - Setting the environment variable TORCH_HUB_DIR
      - Calling hub.set_dir(<PATH_TO_HUB_DIR>)
  • ~/.torch/hub

torch.hub.set_dir(d)

Optionally set hub_dir to a local dir to save downloaded models & weights.

If this argument is not set, the environment variable TORCH_HUB_DIR is checked first. If that is also unset, ~/.torch/hub is created and used as a fallback.

Parameters:
  • d – path to a local folder to save downloaded models & weights.
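
The lookup order described above can be sketched as follows. This is an illustrative approximation, not the library's code: resolve_hub_dir is a hypothetical helper, it only computes the path (it does not create the directory), and the environ parameter exists solely so the behavior can be demonstrated without touching the real environment.

```python
import os


def resolve_hub_dir(explicit_dir=None, environ=None):
    """Pick the hub directory: explicit path, then TORCH_HUB_DIR, then ~/.torch/hub."""
    environ = os.environ if environ is None else environ
    if explicit_dir:
        # Equivalent to a path passed to set_dir().
        chosen = explicit_dir
    elif environ.get('TORCH_HUB_DIR'):
        chosen = environ['TORCH_HUB_DIR']
    else:
        chosen = os.path.join('~', '.torch', 'hub')
    return os.path.expanduser(chosen)


print(resolve_hub_dir('/tmp/my_hub', environ={}))
print(resolve_hub_dir(environ={'TORCH_HUB_DIR': '/data/hub'}))
print(resolve_hub_dir(environ={}))
```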

Caching logic

By default, we don’t clean up files after loading them. Hub uses the cached copy if one already exists in hub_dir.

Users can force a reload by calling hub.load(..., force_reload=True). This deletes the existing GitHub folder and downloaded weights, then starts a fresh download. This is useful when updates are published to the same branch, because users can then keep up with the latest release.
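
The cache-or-reload behavior can be modeled with a short sketch. The names here (fetch_repo, fake_download, the 'vision_master' folder name) are hypothetical stand-ins chosen for illustration, and the download is simulated locally rather than fetched from GitHub.

```python
import os
import shutil
import tempfile


def fetch_repo(hub_dir, repo_name, download, force_reload=False):
    """Return the cached repo folder, downloading only when needed."""
    repo_dir = os.path.join(hub_dir, repo_name)
    if force_reload and os.path.exists(repo_dir):
        shutil.rmtree(repo_dir)  # discard the stale cached copy
    if not os.path.exists(repo_dir):
        os.makedirs(repo_dir)
        download(repo_dir)       # stand-in for the real fetch
    return repo_dir


calls = []


def fake_download(dest):
    # Simulated fetch: record the call and write a placeholder file.
    calls.append(dest)
    with open(os.path.join(dest, 'hubconf.py'), 'w') as f:
        f.write('# placeholder\n')


hub_dir = tempfile.mkdtemp()
fetch_repo(hub_dir, 'vision_master', fake_download)                     # downloads
fetch_repo(hub_dir, 'vision_master', fake_download)                     # cache hit
fetch_repo(hub_dir, 'vision_master', fake_download, force_reload=True)  # downloads again
print(len(calls))  # 2
shutil.rmtree(hub_dir)
```

Only the first and the forced call trigger a download, which mirrors the default of reusing whatever already exists in hub_dir.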
