.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/pretrained_models.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_pretrained_models.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_pretrained_models.py:


Using pretrained models
=======================

This tutorial explains how to use pretrained models in TorchRL.

.. GENERATED FROM PYTHON SOURCE LINES 7-15

At the end of this tutorial, you will be capable of using pretrained models
for efficient image representation, and of fine-tuning them.

TorchRL provides pretrained models that can be used either as transforms or as
components of the policy. As the semantics are the same, they can be used
interchangeably in one context or the other. In this tutorial, we will be using
R3M (https://arxiv.org/abs/2203.12601), but other models (e.g. VIP) will work
equally well.

.. GENERATED FROM PYTHON SOURCE LINES 15-31

.. code-block:: Python

    import multiprocessing

    import torch.cuda
    from tensordict.nn import TensorDictSequential
    from torch import nn
    from torchrl.envs import R3MTransform, TransformedEnv
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.modules import Actor

    # Use the GPU only when it is available and the start method is safe for CUDA.
    is_fork = multiprocessing.get_start_method() == "fork"
    device = (
        torch.device(0)
        if torch.cuda.is_available() and not is_fork
        else torch.device("cpu")
    )

.. GENERATED FROM PYTHON SOURCE LINES 52-56

Let us first create an environment. For the sake of simplicity, we will be
using a common gym environment. In practice, this will work in more
challenging, embodied AI contexts (e.g. have a look at our Habitat wrappers).

.. GENERATED FROM PYTHON SOURCE LINES 56-58

.. code-block:: Python

    base_env = GymEnv("Ant-v4", from_pixels=True, device=device)

.. GENERATED FROM PYTHON SOURCE LINES 59-66

Let us fetch our pretrained model. We ask for the pretrained version of the
model through the ``download=True`` flag; by default, this is turned off.
Next, we will append our transform to the environment. In practice, each
batch of collected data will go through the transform and be mapped onto an
``"r3m_vec"`` entry in the output tensordict. Our policy, consisting of a
single-layer MLP, will then read this vector and compute the corresponding
action.

.. GENERATED FROM PYTHON SOURCE LINES 66-79

.. code-block:: Python

    r3m = R3MTransform(
        "resnet50",
        in_keys=["pixels"],
        download=True,
    )
    env_transformed = TransformedEnv(base_env, r3m)
    net = nn.Sequential(
        nn.LazyLinear(128, device=device),
        nn.Tanh(),
        nn.Linear(128, base_env.action_spec.shape[-1], device=device),
    )
    policy = Actor(net, in_keys=["r3m_vec"])

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading: "https://pytorch.s3.amazonaws.com/models/rl/r3m/r3m_50.pt" to /root/.cache/torch/hub/checkpoints/r3m_50.pt
      0%|          | 0.00/374M [00:00<?, ?B/s]
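With the environment, the transform and the policy in place, it is worth
checking that the pieces fit together before training anything. The snippet
below is a minimal sketch, not part of the generated example; it only assumes
the objects defined above and the standard ``rollout`` method of TorchRL
environments.

.. code-block:: Python

    # Minimal sketch: run the policy for a few steps in the transformed
    # environment. The rollout length (3) is arbitrary. The first forward
    # pass also materializes the LazyLinear layer of the policy network.
    rollout = env_transformed.rollout(3, policy)
    # The resulting tensordict carries the "r3m_vec" entries read by the actor.
    print(rollout)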
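Since transforms and policy components share the same semantics, the same
pretrained weights can also live inside the policy, which is what fine-tuning
requires: gradients must flow through the R3M encoder. The following is a
hedged sketch, assuming that ``R3MTransform`` (reading ``"pixels"`` and
writing ``"r3m_vec"``) can be chained with other modules through
``TensorDictSequential``; it illustrates the idea rather than reproducing the
original example verbatim.

.. code-block:: Python

    # Hedged sketch: embed the R3M encoder in the policy so that its weights
    # can receive gradients and be fine-tuned end-to-end. We assume the
    # transform behaves like a TensorDictModule with in_keys=["pixels"] and
    # out_keys=["r3m_vec"].
    r3m_module = R3MTransform("resnet50", in_keys=["pixels"], download=True)
    policy_with_r3m = TensorDictSequential(r3m_module, policy)

    # The combined policy now reads raw pixels, so it runs directly in the
    # untransformed environment:
    rollout = base_env.rollout(3, policy_with_r3m)

Conversely, to use R3M purely as a fixed feature extractor, keep it on the
environment side as above, where its output is simply treated as part of the
observation.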
.. _sphx_glr_download_tutorials_pretrained_models.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: pretrained_models.py <pretrained_models.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: pretrained_models.zip <pretrained_models.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_