.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/_rendered_examples/dynamo/autocast_example.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials__rendered_examples_dynamo_autocast_example.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials__rendered_examples_dynamo_autocast_example.py:

.. _autocast_example:

An example of using Torch-TensorRT Autocast
===========================================

This example demonstrates how to use Torch-TensorRT Autocast together with PyTorch Autocast to compile a mixed-precision model.

.. GENERATED FROM PYTHON SOURCE LINES 9-14

.. code-block:: python

    import torch
    import torch.nn as nn

    import torch_tensorrt

.. GENERATED FROM PYTHON SOURCE LINES 15-17

We define a mixed-precision model that consists of a few layers, a ``log`` operation, and an ``abs`` operation. Among them, the ``fc1``, ``log``, and ``abs`` operations run within a PyTorch Autocast context with ``dtype=torch.float16``.

.. GENERATED FROM PYTHON SOURCE LINES 18-52

.. code-block:: python

    class MixedPytorchAutocastModel(nn.Module):
        def __init__(self):
            super(MixedPytorchAutocastModel, self).__init__()
            self.conv1 = nn.Conv2d(
                in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1
            )
            self.relu1 = nn.ReLU()
            self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.conv2 = nn.Conv2d(
                in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1
            )
            self.relu2 = nn.ReLU()
            self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.flatten = nn.Flatten()
            self.fc1 = nn.Linear(16 * 8 * 8, 10)

        def forward(self, x):
            out1 = self.conv1(x)
            out2 = self.relu1(out1)
            out3 = self.pool1(out2)
            out4 = self.conv2(out3)
            out5 = self.relu2(out4)
            out6 = self.pool2(out5)
            out7 = self.flatten(out6)
            with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):
                out8 = self.fc1(out7)
                out9 = torch.log(
                    torch.abs(out8) + 1
                )  # log is fp32 due to PyTorch Autocast requirements
            return x, out1, out2, out3, out4, out5, out6, out7, out8, out9

.. GENERATED FROM PYTHON SOURCE LINES 53-54

We define the model, inputs, and calibration dataloader for Autocast, and then run the original PyTorch model to get reference outputs.

.. GENERATED FROM PYTHON SOURCE LINES 54-64

.. code-block:: python

    model = MixedPytorchAutocastModel().cuda().eval()
    inputs = (torch.randn((8, 3, 32, 32), dtype=torch.float32, device="cuda"),)
    ep = torch.export.export(model, inputs)
    calibration_dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(*inputs), batch_size=2, shuffle=False
    )

    pytorch_outs = model(*inputs)
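Before compiling, it can be useful to confirm what PyTorch Autocast alone does to the eager outputs. The check below is not part of the original script; it is a minimal sketch that inspects ``pytorch_outs`` from the previous step. Per the model's comments, ``fc1``'s output (index 8) should be FP16 because it runs inside the autocast context, while ``log`` (index 9) stays FP32 due to PyTorch Autocast requirements.

.. code-block:: python

    # Illustrative sanity check (an assumption-based sketch, not part of the
    # generated tutorial): print the dtype of each eager output produced
    # under PyTorch Autocast.
    for i, out in enumerate(pytorch_outs):
        print(f"pytorch_outs[{i}]: dtype={out.dtype}, shape={tuple(out.shape)}")

    # fc1 runs inside the autocast(dtype=torch.float16) context, so its
    # output is FP16; torch.log is kept in FP32 by PyTorch Autocast.
    assert pytorch_outs[8].dtype == torch.float16
    assert pytorch_outs[9].dtype == torch.float32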
.. GENERATED FROM PYTHON SOURCE LINES 65-70

We compile the model with Torch-TensorRT Autocast by setting ``enable_autocast=True``, ``use_explicit_typing=True``, and ``autocast_low_precision_type=torch.bfloat16``. To illustrate, we exclude the ``conv1`` node, all nodes whose names contain ``relu``, and the ``torch.ops.aten.flatten.using_ints`` ATen op from Autocast. In addition, we set ``autocast_max_output_threshold``, ``autocast_max_depth_of_reduction``, and ``autocast_calibration_dataloader``. Please refer to the documentation for more details.

.. GENERATED FROM PYTHON SOURCE LINES 71-89

.. code-block:: python

    trt_autocast_mod = torch_tensorrt.compile(
        ep.module(),
        arg_inputs=inputs,
        min_block_size=1,
        use_python_runtime=True,
        use_explicit_typing=True,
        enable_autocast=True,
        autocast_low_precision_type=torch.bfloat16,
        autocast_excluded_nodes={"^conv1$", "relu"},
        autocast_excluded_ops={"torch.ops.aten.flatten.using_ints"},
        autocast_max_output_threshold=512,
        autocast_max_depth_of_reduction=None,
        autocast_calibration_dataloader=calibration_dataloader,
    )

    autocast_outs = trt_autocast_mod(*inputs)

.. GENERATED FROM PYTHON SOURCE LINES 90-94

We verify that both the dtypes and values of the model outputs are correct. As expected, ``fc1`` is in FP16 because of PyTorch Autocast; ``pool1``, ``conv2``, and ``pool2`` are in BF16 because of Torch-TensorRT Autocast; the rest remain in FP32. Note that ``log`` is in FP32 because of PyTorch Autocast requirements.

.. GENERATED FROM PYTHON SOURCE LINES 95-123

.. code-block:: python

    should_be_fp32 = [
        autocast_outs[0],
        autocast_outs[1],
        autocast_outs[2],
        autocast_outs[5],
        autocast_outs[7],
        autocast_outs[9],
    ]
    should_be_fp16 = [
        autocast_outs[8],
    ]
    should_be_bf16 = [autocast_outs[3], autocast_outs[4], autocast_outs[6]]
    assert all(
        a.dtype == torch.float32 for a in should_be_fp32
    ), "Some Autocast outputs are not float32!"
    assert all(
        a.dtype == torch.float16 for a in should_be_fp16
    ), "Some Autocast outputs are not float16!"
    assert all(
        a.dtype == torch.bfloat16 for a in should_be_bf16
    ), "Some Autocast outputs are not bfloat16!"

    for i, (a, w) in enumerate(zip(autocast_outs, pytorch_outs)):
        assert torch.allclose(
            a.to(torch.float32), w.to(torch.float32), atol=1e-2, rtol=1e-2
        ), f"Autocast and PyTorch outputs do not match! autocast_outs[{i}] = {a}, pytorch_outs[{i}] = {w}"

    print("All dtypes and values match!")


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)


.. _sphx_glr_download_tutorials__rendered_examples_dynamo_autocast_example.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: autocast_example.py <autocast_example.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: autocast_example.ipynb <autocast_example.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_