.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/pretrained_models.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_pretrained_models.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_pretrained_models.py:

Using pretrained models
=======================

This tutorial explains how to use pretrained models in TorchRL.

.. GENERATED FROM PYTHON SOURCE LINES 6-8

.. code-block:: Python

    import tempfile

.. GENERATED FROM PYTHON SOURCE LINES 9-17

At the end of this tutorial, you will be able to use pretrained models for
efficient image representation, and to fine-tune them.

TorchRL provides pretrained models that can be used either as transforms or as
components of the policy. As their semantics are the same, they can be used
interchangeably in either context. In this tutorial, we will be using R3M
(https://arxiv.org/abs/2203.12601), but other models (e.g. VIP) will work
equally well.

.. GENERATED FROM PYTHON SOURCE LINES 17-33

.. code-block:: Python

    import torch.cuda
    from tensordict.nn import TensorDictSequential
    from torch import multiprocessing, nn
    from torchrl.envs import Compose, R3MTransform, TransformedEnv
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.modules import Actor

    is_fork = multiprocessing.get_start_method() == "fork"
    device = (
        torch.device(0)
        if torch.cuda.is_available() and not is_fork
        else torch.device("cpu")
    )

.. GENERATED FROM PYTHON SOURCE LINES 54-58

Let us first create an environment. For the sake of simplicity, we will be
using a common gym environment. In practice, this will work in more
challenging, embodied AI contexts (e.g. have a look at our Habitat wrappers).

.. GENERATED FROM PYTHON SOURCE LINES 58-60

.. code-block:: Python

    base_env = GymEnv("Ant-v4", from_pixels=True, device=device)

.. GENERATED FROM PYTHON SOURCE LINES 61-68

Let us fetch our pretrained model. We ask for the pretrained version of the
model through the ``download=True`` flag; by default, downloading is turned
off. Next, we append our transform to the environment. In practice, each batch
of data collected will go through the transform and be mapped onto an
``"r3m_vec"`` entry in the output tensordict. Our policy, consisting of a
single-layer MLP, will then read this vector and compute the corresponding
action.

.. GENERATED FROM PYTHON SOURCE LINES 68-81

.. code-block:: Python

    r3m = R3MTransform(
        "resnet50",
        in_keys=["pixels"],
        download=False,  # turn to True for real-life testing
    )
    env_transformed = TransformedEnv(base_env, r3m)
    net = nn.Sequential(
        nn.LazyLinear(128, device=device),
        nn.Tanh(),
        nn.Linear(128, base_env.action_spec.shape[-1], device=device),
    )
    policy = Actor(net, in_keys=["r3m_vec"])

.. GENERATED FROM PYTHON SOURCE LINES 82-84

Let's check the number of parameter tensors of the policy (note that
``policy.parameters()`` yields one tensor per weight or bias, not a total
parameter count):

.. GENERATED FROM PYTHON SOURCE LINES 84-86

.. code-block:: Python

    print("number of params:", len(list(policy.parameters())))

.. GENERATED FROM PYTHON SOURCE LINES 87-89

We collect a rollout of 32 steps and print its output:

.. GENERATED FROM PYTHON SOURCE LINES 89-92

.. code-block:: Python

    rollout = env_transformed.rollout(32, policy)
    print("rollout with transform:", rollout)

.. GENERATED FROM PYTHON SOURCE LINES 93-97

For fine-tuning, we integrate the transform into the policy after making its
parameters trainable. In practice, it may be wiser to restrict this to a
subset of the parameters (say, the last layer of the MLP), as sketched below.
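Such selective fine-tuning is not part of this example, but a minimal sketch
could look as follows. It assumes the ``r3m``, ``net`` and ``policy`` objects
defined above (with the lazy layers already initialized by the previous
rollout); the ``requires_grad_`` calls are standard PyTorch, and the choice of
``torch.optim.Adam`` is purely illustrative:

.. code-block:: Python

    # Sketch only: freeze the pretrained R3M weights and every MLP layer
    # except the last one, so that fine-tuning only touches a small subset
    # of the parameters.
    for p in r3m.parameters():
        p.requires_grad_(False)  # keep the pretrained backbone frozen
    for p in net[:-1].parameters():
        p.requires_grad_(False)  # freeze all MLP layers but the last
    # Hand only the remaining trainable parameters to the optimizer.
    optim = torch.optim.Adam(p for p in policy.parameters() if p.requires_grad)

In this tutorial, however, we keep things simple and train everything: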
.. GENERATED FROM PYTHON SOURCE LINES 97-101

.. code-block:: Python

    r3m.train()
    policy = TensorDictSequential(r3m, policy)
    print("number of params after r3m is integrated:", len(list(policy.parameters())))

.. GENERATED FROM PYTHON SOURCE LINES 102-106

Again, we collect a rollout with R3M. The structure of the output has changed
slightly: the environment now returns pixels (and not an embedding), and the
embedding ``"r3m_vec"`` is an intermediate result of our policy.

.. GENERATED FROM PYTHON SOURCE LINES 106-109

.. code-block:: Python

    rollout = base_env.rollout(32, policy)
    print("rollout, fine tuning:", rollout)

.. GENERATED FROM PYTHON SOURCE LINES 110-117

The ease with which we have swapped the transform from the environment to the
policy is due to the fact that both behave like ``TensorDictModule``: they
have a set of ``"in_keys"`` and ``"out_keys"`` that make it easy to read and
write output in different contexts.

To conclude this tutorial, let's have a look at how we could use R3M to read
images stored in a replay buffer (e.g. in an offline RL context). First,
let's build our dataset:

.. GENERATED FROM PYTHON SOURCE LINES 117-123

.. code-block:: Python

    from torchrl.data import LazyMemmapStorage, ReplayBuffer

    buffer_scratch_dir = tempfile.TemporaryDirectory().name
    storage = LazyMemmapStorage(1000, scratch_dir=buffer_scratch_dir)
    rb = ReplayBuffer(
        storage=storage,
        transform=Compose(lambda td: td.to(device), r3m),
    )

.. GENERATED FROM PYTHON SOURCE LINES 124-127

We can now collect the data (random rollouts for our purpose) and fill the
replay buffer with it:

.. GENERATED FROM PYTHON SOURCE LINES 127-133

.. code-block:: Python

    total = 0
    while total < 1000:
        tensordict = base_env.rollout(1000)
        rb.extend(tensordict)
        total += tensordict.numel()

.. GENERATED FROM PYTHON SOURCE LINES 134-136

Let's check what our replay buffer storage looks like. It should not contain
an ``"r3m_vec"`` entry, since the R3M transform is part of the sampling
pipeline, not of the storage:

.. GENERATED FROM PYTHON SOURCE LINES 136-138

.. code-block:: Python

    print("stored data:", storage._storage)

.. GENERATED FROM PYTHON SOURCE LINES 139-142

When sampling, the data will go through the R3M transform, giving us the
processed data that we wanted. In this way, we can train an algorithm offline
on a dataset made of images:

.. GENERATED FROM PYTHON SOURCE LINES 142-145

.. code-block:: Python

    batch = rb.sample(32)
    print("data after sampling:", batch)


.. _sphx_glr_download_tutorials_pretrained_models.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: pretrained_models.ipynb <pretrained_models.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: pretrained_models.py <pretrained_models.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: pretrained_models.zip <pretrained_models.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_