.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/export.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_export.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_export.py:

Exporting TorchRL modules
=========================

**Author**: `Vincent Moens `_

.. _export_tuto:

.. note::

    To run this tutorial in a notebook, add an installation cell
    at the beginning containing:

    .. code-block::

        !pip install tensordict
        !pip install torchrl
        !pip install "gymnasium[atari,accept-rom-license]"<1.0.0

Introduction
------------

Learning a policy has little value if that policy cannot be deployed in real-world
settings. As shown in other tutorials, TorchRL has a strong focus on modularity and
composability: thanks to ``tensordict``, the components of the library can be written
in the most generic way possible by abstracting their signature to a mere set of
operations on an input ``TensorDict``.

This may give the impression that the library is bound to be used only for training,
as typical low-level execution hardware (edge devices, robots, Arduino, Raspberry Pi)
does not execute Python code, let alone with PyTorch, tensordict or TorchRL installed.

Fortunately, PyTorch provides a full ecosystem of solutions to export code and trained
models to devices and hardware, and TorchRL is fully equipped to interact with it.
You can choose from a varied set of backends, including ONNX and AOTInductor, both
exemplified in this tutorial. This tutorial gives a quick overview of how a trained
model can be isolated and shipped as a standalone artifact to be deployed on hardware.

Key learnings:

- Exporting any TorchRL module after training;
- Using various backends;
- Testing your exported model.

Fast recap: a simple TorchRL training loop
------------------------------------------

In this section, we reproduce the training loop from the last Getting Started tutorial,
slightly adapted to be used with Atari games as they are rendered by the gymnasium
library. We will stick to the DQN example and show later how a policy that outputs a
distribution over values can be used instead.

.. GENERATED FROM PYTHON SOURCE LINES 50-143

.. code-block:: Python

    import time
    from pathlib import Path

    import numpy as np
    import torch

    from tensordict.nn import (
        TensorDictModule as Mod,
        TensorDictSequential,
        TensorDictSequential as Seq,
    )
    from torch.optim import Adam

    from torchrl._utils import timeit
    from torchrl.collectors import SyncDataCollector
    from torchrl.data import LazyTensorStorage, ReplayBuffer
    from torchrl.envs import (
        Compose,
        GrayScale,
        GymEnv,
        Resize,
        set_exploration_type,
        StepCounter,
        ToTensorImage,
        TransformedEnv,
    )
    from torchrl.modules import ConvNet, EGreedyModule, QValueModule
    from torchrl.objectives import DQNLoss, SoftUpdate

    torch.manual_seed(0)

    env = TransformedEnv(
        GymEnv("ALE/Pong-v5", categorical_action_encoding=True),
        Compose(
            ToTensorImage(), Resize(84, interpolation="nearest"), GrayScale(), StepCounter()
        ),
    )
    env.set_seed(0)

    value_mlp = ConvNet.default_atari_dqn(num_actions=env.action_spec.space.n)
    value_net = Mod(value_mlp, in_keys=["pixels"], out_keys=["action_value"])
    policy = Seq(value_net, QValueModule(spec=env.action_spec))
    exploration_module = EGreedyModule(
        env.action_spec, annealing_num_steps=100_000, eps_init=0.5
    )
    policy_explore = Seq(policy, exploration_module)

    init_rand_steps = 5000
    frames_per_batch = 100
    optim_steps = 10
    collector = SyncDataCollector(
        env,
        policy_explore,
        frames_per_batch=frames_per_batch,
        total_frames=-1,
        init_random_frames=init_rand_steps,
    )
    rb = ReplayBuffer(storage=LazyTensorStorage(100_000))

    loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True)
    optim = Adam(loss.parameters())
    updater = SoftUpdate(loss, eps=0.99)

    total_count = 0
    total_episodes = 0
    t0 = time.time()
    for data in collector:
        # Write data in replay buffer
        rb.extend(data)
        max_length = rb[:]["next", "step_count"].max()
        if len(rb) > init_rand_steps:
            # Optim loop (we do several optim steps
            # per batch collected for efficiency)
            for _ in range(optim_steps):
                sample = rb.sample(128)
                loss_vals = loss(sample)
                loss_vals["loss"].backward()
                optim.step()
                optim.zero_grad()
                # Update exploration factor
                exploration_module.step(data.numel())
                # Update target params
                updater.step()
                total_count += data.numel()
                total_episodes += data["next", "done"].sum()
        if max_length > 200:
            break
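
Before moving on to export, it can be useful to sanity-check the trained greedy policy.
The snippet below is a minimal sketch (not part of the original training script): it
runs a short rollout with exploration disabled, reusing only the ``env``, ``policy`` and
``set_exploration_type`` objects defined above, and prints a couple of summary
statistics.

.. code-block:: Python

    # Illustrative sanity check: run the greedy policy (no exploration)
    # for a few hundred steps and inspect the outcome.
    with torch.no_grad(), set_exploration_type("DETERMINISTIC"):
        eval_rollout = env.rollout(500, policy)
    print("Evaluation return:", eval_rollout["next", "reward"].sum().item())
    print("Longest episode step count:", eval_rollout["next", "step_count"].max().item())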

.. GENERATED FROM PYTHON SOURCE LINES 144-160

Exporting a TensorDictModule-based policy
-----------------------------------------

``TensorDict`` allowed us to build a policy with great flexibility: starting from a
regular :class:`~torch.nn.Module` that outputs action values from an observation, we
added a :class:`~torchrl.modules.QValueModule` that reads these values and computes an
action using some heuristic (e.g., an argmax call).

However, there's a small technical catch in our case: the environment (the actual Atari
game) doesn't return grayscale 84x84 images but raw, screen-sized color ones. The
transforms we appended to the environment make sure that the images can be read by the
model. We can see that, from the training perspective, the boundary between environment
and model is blurry, but at execution time things are much clearer: the model should
take care of transforming the input data (images) to the format that can be processed
by our CNN.

Here again, the magic of tensordict will unblock us: it happens that most of TorchRL's
local (non-recursive) transforms can be used both as environment transforms and as
preprocessing blocks within a :class:`~torch.nn.Module` instance. Let's see how we can
prepend them to our policy:

.. GENERATED FROM PYTHON SOURCE LINES 160-169

.. code-block:: Python

    policy_transform = TensorDictSequential(
        env.transform[
            :-1
        ],  # the last transform is a step counter which we don't need for preproc
        policy_explore.requires_grad_(
            False
        ),  # Using the explorative version of the policy for didactic purposes, see below.
    )

.. GENERATED FROM PYTHON SOURCE LINES 170-176

We create a fake input and pass it to :func:`~torch.export.export` along with the
policy. This will give us a "raw" Python function that reads our input tensor and
outputs an action without any reference to TorchRL or tensordict modules.
A good practice is to call :meth:`~tensordict.nn.TensorDictSequential.select_out_keys`
to let the model know that we only want a certain set of outputs (in case the policy
returns more than one tensor).

.. GENERATED FROM PYTHON SOURCE LINES 176-188

.. code-block:: Python

    fake_td = env.base_env.fake_tensordict()
    pixels = fake_td["pixels"]
    with set_exploration_type("DETERMINISTIC"):
        exported_policy = torch.export.export(
            # Select only the "action" output key
            policy_transform.select_out_keys("action"),
            args=(),
            kwargs={"pixels": pixels},
            strict=False,
        )

.. GENERATED FROM PYTHON SOURCE LINES 189-192

Printing the exported policy can be quite insightful: we can see that the first
operations are a permute, a div, an unsqueeze and a resize, followed by the
convolutional and MLP layers.

.. GENERATED FROM PYTHON SOURCE LINES 192-195

.. code-block:: Python

    print("Deterministic policy")
    exported_policy.graph_module.print_readable()

.. GENERATED FROM PYTHON SOURCE LINES 196-198

As a final check, we can execute the policy with a dummy input. The output (for a
single image) should be an integer between 0 and 5 representing the action to be
executed in the game (Pong uses six discrete actions).

.. GENERATED FROM PYTHON SOURCE LINES 198-202

.. code-block:: Python

    output = exported_policy.module()(pixels=pixels)
    print("Exported module output", output)
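
Because the graph was captured under the ``"DETERMINISTIC"`` exploration type, the
exported program should agree with the eager module evaluated in the same mode. The
check below is a small illustrative sketch (not part of the original tutorial); it
assumes ``policy_transform`` still has ``"action"`` selected as its only output key.

.. code-block:: Python

    # Illustrative check: exported program and eager module should pick
    # the same greedy action for the same input frame.
    with torch.no_grad(), set_exploration_type("DETERMINISTIC"):
        eager_output = policy_transform(pixels=pixels)
    torch.testing.assert_close(output, eager_output)
    print("Exported and eager outputs match:", output)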

.. GENERATED FROM PYTHON SOURCE LINES 203-229

Further details on exporting :class:`~tensordict.nn.TensorDictModule` instances can be
found in the tensordict `documentation `_.

.. note::

    Exporting modules that take and output nested keys is perfectly fine. The
    corresponding kwargs will be the ``"_".join(key)`` version of the key, i.e., the
    ``("group0", "agent0", "obs")`` key will correspond to the ``"group0_agent0_obs"``
    keyword argument (see the short example below). Colliding keys (e.g.,
    ``("group0_agent0", "obs")`` and ``("group0", "agent0_obs")``) may lead to
    undefined behaviour and should be avoided at all costs. Obviously, key names
    should also always produce valid keyword arguments, i.e., they should not contain
    special characters such as spaces or commas.
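
To make the key-flattening rule concrete, here is a short, self-contained sketch
(illustrative names, not part of the original tutorial): a module reading the nested
key ``("agent0", "obs")`` is exported and then called with the flattened
``"agent0_obs"`` keyword argument.

.. code-block:: Python

    # Illustrative only: nested in/out keys become underscore-joined kwargs.
    nested_module = Mod(
        torch.nn.Linear(3, 1),
        in_keys=[("agent0", "obs")],
        out_keys=[("agent0", "value")],
    )
    exported_nested = torch.export.export(
        nested_module,
        args=(),
        kwargs={"agent0_obs": torch.randn(3)},
        strict=False,
    )
    print(exported_nested.module()(agent0_obs=torch.randn(3)))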

``torch.export`` has many other features that we will explore further below. Before
that, let us take a small digression through exploration and stochastic policies in the
context of test-time inference, as well as recurrent policies.

Working with stochastic policies
--------------------------------

As you probably noticed, above we used the :class:`~torchrl.envs.set_exploration_type`
context manager to control the behaviour of the policy. If the policy is stochastic
(e.g., it outputs a distribution over the action space, as is the case in PPO and
similar on-policy algorithms) or explorative (with an appended exploration module such
as E-Greedy, additive Gaussian or Ornstein-Uhlenbeck), we may or may not want to keep
that exploration strategy in the exported version.
Fortunately, the export utilities understand that context manager: as long as the
export happens within the right context manager, the behaviour of the exported policy
will match the requested mode. To demonstrate this, let us try with another exploration
type:

.. GENERATED FROM PYTHON SOURCE LINES 229-238

.. code-block:: Python

    with set_exploration_type("RANDOM"):
        exported_stochastic_policy = torch.export.export(
            policy_transform.select_out_keys("action"),
            args=(),
            kwargs={"pixels": pixels},
            strict=False,
        )

.. GENERATED FROM PYTHON SOURCE LINES 239-243

Our exported policy should now have a random module at the end of the call stack,
unlike the previous version. Indeed, the last three operations are: generate a random
integer between 0 and 5, draw a random mask, and select either the network output or
the random action depending on the value of the mask.

.. GENERATED FROM PYTHON SOURCE LINES 243-246

.. code-block:: Python

    print("Stochastic policy")
    exported_stochastic_policy.graph_module.print_readable()

.. GENERATED FROM PYTHON SOURCE LINES 247-257

Working with recurrent policies
-------------------------------

Another typical use case is a recurrent policy that outputs an action as well as one or
more recurrent states. LSTM and GRU are CuDNN-based modules, which means that they
behave differently from regular :class:`~torch.nn.Module` instances (the export
utilities may not trace them well). Fortunately, TorchRL provides a Python
implementation of these modules that can be swapped in for the CuDNN version when
desired.

To show this, let us write a prototypical policy that relies on an RNN:

.. GENERATED FROM PYTHON SOURCE LINES 257-269

.. code-block:: Python

    from tensordict.nn import TensorDictModule
    from torchrl.envs import BatchSizeTransform
    from torchrl.modules import LSTMModule, MLP

    lstm = LSTMModule(
        input_size=32,
        num_layers=2,
        hidden_size=256,
        in_keys=["observation", "hidden0", "hidden1"],
        out_keys=["intermediate", "hidden0", "hidden1"],
    )

.. GENERATED FROM PYTHON SOURCE LINES 270-273

If the LSTM module is CuDNN-based (:class:`~torch.nn.LSTM`) rather than Python-based,
the :meth:`~torchrl.modules.LSTMModule.make_python_based` method can be used to switch
to the Python version.

.. GENERATED FROM PYTHON SOURCE LINES 273-275

.. code-block:: Python

    lstm = lstm.make_python_based()

.. GENERATED FROM PYTHON SOURCE LINES 276-279

Let's now create the policy. We combine two layers that modify the shape of the inputs
(unsqueeze/squeeze operations) with the LSTM and an MLP.

.. GENERATED FROM PYTHON SOURCE LINES 279-293

.. code-block:: Python

    recurrent_policy = TensorDictSequential(
        # Unsqueeze the first dim of all tensors to make LSTMCell happy
        BatchSizeTransform(reshape_fn=lambda x: x.unsqueeze(0)),
        lstm,
        TensorDictModule(
            MLP(in_features=256, out_features=5, num_cells=[64, 64]),
            in_keys=["intermediate"],
            out_keys=["action"],
        ),
        # Squeeze the first dim of all tensors to get the original shape back
        BatchSizeTransform(reshape_fn=lambda x: x.squeeze(0)),
    )

.. GENERATED FROM PYTHON SOURCE LINES 294-296

As before, we select the relevant keys:

.. GENERATED FROM PYTHON SOURCE LINES 296-301

.. code-block:: Python

    recurrent_policy.select_out_keys("action", "hidden0", "hidden1")
    print("recurrent policy input keys:", recurrent_policy.in_keys)
    print("recurrent policy output keys:", recurrent_policy.out_keys)

.. GENERATED FROM PYTHON SOURCE LINES 302-304

We are now ready to export. To do this, we build fake inputs and pass them to
:func:`~torch.export.export`:

.. GENERATED FROM PYTHON SOURCE LINES 304-326

.. code-block:: Python

    fake_obs = torch.randn(32)
    fake_hidden0 = torch.randn(2, 256)
    fake_hidden1 = torch.randn(2, 256)

    # Tensor indicating whether the state is the first of a sequence
    fake_is_init = torch.zeros((), dtype=torch.bool)

    exported_recurrent_policy = torch.export.export(
        recurrent_policy,
        args=(),
        kwargs={
            "observation": fake_obs,
            "hidden0": fake_hidden0,
            "hidden1": fake_hidden1,
            "is_init": fake_is_init,
        },
        strict=False,
    )
    print("Recurrent policy graph:")
    exported_recurrent_policy.graph_module.print_readable()
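
As with the feed-forward policy, the exported program can be executed through
``exported_recurrent_policy.module()``. The loop below is an illustrative sketch
(random observations stand in for real ones, and it assumes the outputs are returned in
the same order as the selected keys, i.e. ``action``, ``hidden0``, ``hidden1``): it
shows how the recurrent state would typically be carried from one call to the next,
starting from zeroed hidden states.

.. code-block:: Python

    # Illustrative step-by-step loop: feed the recurrent state back in at
    # every call. Starting from zeroed hidden tensors plays the role of a
    # state reset at the beginning of the sequence.
    recurrent_module = exported_recurrent_policy.module()
    hidden0 = torch.zeros(2, 256)
    hidden1 = torch.zeros(2, 256)
    is_init = torch.zeros((), dtype=torch.bool)
    for step in range(5):
        observation = torch.randn(32)  # replace with a real observation
        action, hidden0, hidden1 = recurrent_module(
            observation=observation,
            hidden0=hidden0,
            hidden1=hidden1,
            is_init=is_init,
        )
        print("step", step, "action:", action)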

.. GENERATED FROM PYTHON SOURCE LINES 327-336

AOTInductor: Export your policy to PyTorch-free C++ binaries
--------------------------------------------------------------

AOTInductor is a PyTorch feature that allows you to export your model (policy or
otherwise) to PyTorch-free C++ binaries. This is particularly useful when you need to
deploy your model on devices or platforms where PyTorch is not available.

Here's an example of how you can use AOTInductor to export your policy, inspired by the
`AOTI documentation `_:

.. GENERATED FROM PYTHON SOURCE LINES 336-355

.. code-block:: Python

    from tempfile import TemporaryDirectory

    from torch._inductor import aoti_compile_and_package, aoti_load_package

    with TemporaryDirectory() as tmpdir:
        path = str(Path(tmpdir) / "model.pt2")
        with torch.no_grad():
            pkg_path = aoti_compile_and_package(
                exported_policy,
                # Specify the generated shared library path
                package_path=path,
            )
        print("pkg_path", pkg_path)

        compiled_module = aoti_load_package(pkg_path)

        print(compiled_module(pixels=pixels))

.. GENERATED FROM PYTHON SOURCE LINES 356-385

Exporting TorchRL models with ONNX
----------------------------------

.. note::

    To execute this part of the script, make sure the ONNX dependencies are installed:

    .. code-block::

        !pip install onnx-pytorch
        !pip install onnxruntime

You can also find more information about using ONNX in the PyTorch ecosystem `here `_.
The following example is based on this documentation.

In this section, we are going to showcase how we can export our model in such a way
that it can be executed in a PyTorch-free setting.

There are plenty of resources on the web explaining how ONNX can be used to deploy
PyTorch models on various hardware and devices, including `Raspberry Pi `_,
`NVIDIA TensorRT `_, `iOS `_ and `Android `_.

The Atari game we trained on can be run without TorchRL or gymnasium through the
`ALE library `_, and therefore provides a good example of what we can achieve with
ONNX.

Let us see what this API looks like:

.. GENERATED FROM PYTHON SOURCE LINES 386-407

.. code-block:: Python

    from ale_py import ALEInterface, roms

    # Create the interface
    ale = ALEInterface()

    # Load the pong environment
    ale.loadROM(roms.Pong)
    ale.reset_game()

    # Make a step in the simulator
    action = 0
    reward = ale.act(action)
    screen_obs = ale.getScreenRGB()
    print("Observation from ALE simulator:", type(screen_obs), screen_obs.shape)

    from matplotlib import pyplot as plt

    plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    plt.imshow(screen_obs)
    plt.title("Screen rendering of Pong game.")

.. GENERATED FROM PYTHON SOURCE LINES 408-410

Exporting to ONNX is quite similar to the export/AOTI workflow above:

.. GENERATED FROM PYTHON SOURCE LINES 410-418

.. code-block:: Python

    import onnxruntime

    with set_exploration_type("DETERMINISTIC"):
        # We use torch.onnx.dynamo_export to capture the computation graph from our policy
        pixels = torch.as_tensor(screen_obs)
        onnx_policy_export = torch.onnx.dynamo_export(policy_transform, pixels=pixels)

.. GENERATED FROM PYTHON SOURCE LINES 419-420

We can now save the program on disk and load it:

.. GENERATED FROM PYTHON SOURCE LINES 420-431

.. code-block:: Python

    with TemporaryDirectory() as tmpdir:
        onnx_file_path = str(Path(tmpdir) / "policy.onnx")
        onnx_policy_export.save(onnx_file_path)

        ort_session = onnxruntime.InferenceSession(
            onnx_file_path, providers=["CPUExecutionProvider"]
        )

        onnxruntime_input = {ort_session.get_inputs()[0].name: screen_obs}
        onnx_policy = ort_session.run(None, onnxruntime_input)
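
Before timing anything, it is worth checking that the ONNX session and the eager
TorchRL policy agree on the same frame. This is a small illustrative check (not part of
the original tutorial) that reuses the ``onnx_policy`` outputs computed just above.

.. code-block:: Python

    # Illustrative consistency check: both runtimes should select the same
    # greedy action for the same screen observation.
    with torch.no_grad(), set_exploration_type("DETERMINISTIC"):
        torch_action = policy_transform(pixels=torch.as_tensor(screen_obs))
    print("ONNX action:", int(onnx_policy[0]), "| TorchRL action:", int(torch_action))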

.. GENERATED FROM PYTHON SOURCE LINES 432-437

Running a rollout with ONNX
~~~~~~~~~~~~~~~~~~~~~~~~~~~

We now have an ONNX model that runs our policy. Let's compare it to the original
TorchRL instance: because it is more lightweight, the ONNX version should be faster
than the TorchRL one.

.. GENERATED FROM PYTHON SOURCE LINES 437-459

.. code-block:: Python

    def onnx_policy(screen_obs: np.ndarray) -> int:
        onnxruntime_input = {ort_session.get_inputs()[0].name: screen_obs}
        onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
        action = int(onnxruntime_outputs[0])
        return action


    with timeit("ONNX rollout"):
        num_steps = 1000
        ale.reset_game()
        for _ in range(num_steps):
            screen_obs = ale.getScreenRGB()
            action = onnx_policy(screen_obs)
            reward = ale.act(action)


    with timeit("TorchRL version"), torch.no_grad(), set_exploration_type("DETERMINISTIC"):
        env.rollout(num_steps, policy_explore)

    print(timeit.print())

.. GENERATED FROM PYTHON SOURCE LINES 460-486

Note that ONNX also offers the possibility of optimizing models directly, but this is
beyond the scope of this tutorial.

Conclusion
----------

In this tutorial, we learned how to export TorchRL modules using various backends such
as PyTorch's built-in export functionality, ``AOTInductor``, and ``ONNX``. We
demonstrated how to export a policy trained on an Atari game and run it in a
PyTorch-free setting using the ``ALE`` library. We also compared the performance of the
original TorchRL instance with the exported ONNX model.

Key takeaways:

- Exporting TorchRL modules allows for deployment on devices without PyTorch installed.
- AOTInductor and ONNX provide alternative backends for exporting models.
- Optimizing ONNX models can improve performance.

Further reading and learning steps:

- Check out the official documentation for PyTorch's `export functionality `_,
  `AOTInductor `_, and `ONNX `_ for more information.
- Experiment with deploying exported models on different devices.
- Explore optimization techniques for ONNX models to improve performance.

.. _sphx_glr_download_tutorials_export.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: export.ipynb <export.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: export.py <export.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: export.zip <export.zip>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_