{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# An example of using Torch-TensorRT Autocast\n\nThis example demonstrates how to use Torch-TensorRT Autocast with PyTorch Autocast to compile a mixed-precision model.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torch.nn as nn\nimport torch_tensorrt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define a mixed-precision model that consists of a few layers, a ``log`` operation, and an ``abs`` operation.\nAmong them, the ``fc1``, ``log``, and ``abs`` operations are within a PyTorch Autocast context with ``dtype=torch.float16``.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class MixedPytorchAutocastModel(nn.Module):\n def __init__(self):\n super(MixedPytorchAutocastModel, self).__init__()\n self.conv1 = nn.Conv2d(\n in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1\n )\n self.relu1 = nn.ReLU()\n self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)\n self.conv2 = nn.Conv2d(\n in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1\n )\n self.relu2 = nn.ReLU()\n self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)\n self.flatten = nn.Flatten()\n self.fc1 = nn.Linear(16 * 8 * 8, 10)\n\n def forward(self, x):\n out1 = self.conv1(x)\n out2 = self.relu1(out1)\n out3 = self.pool1(out2)\n out4 = self.conv2(out3)\n out5 = self.relu2(out4)\n out6 = self.pool2(out5)\n out7 = self.flatten(out6)\n with torch.autocast(x.device.type, enabled=True, dtype=torch.float16):\n out8 = self.fc1(out7)\n out9 = torch.log(\n torch.abs(out8) + 1\n ) # log runs in fp32 because PyTorch Autocast requires it\n return x, out1, out2, out3, out4, out5, out6, out7, out8, out9" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define the model, inputs, and calibration dataloader for Autocast, and then run the 
original PyTorch model to get the reference outputs.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "model = MixedPytorchAutocastModel().cuda().eval()\ninputs = (torch.randn((8, 3, 32, 32), dtype=torch.float32, device=\"cuda\"),)\nep = torch.export.export(model, inputs)\ncalibration_dataloader = torch.utils.data.DataLoader(\n torch.utils.data.TensorDataset(*inputs), batch_size=2, shuffle=False\n)\n\npytorch_outs = model(*inputs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We compile the model with Torch-TensorRT Autocast by setting ``enable_autocast=True``, ``use_explicit_typing=True``, and\n``autocast_low_precision_type=torch.bfloat16``. To illustrate, we exclude the ``conv1`` node, all nodes with names\ncontaining ``relu``, and the ``torch.ops.aten.flatten.using_ints`` ATen op from Autocast. In addition, we set\n``autocast_max_output_threshold``, ``autocast_max_depth_of_reduction``, and ``autocast_calibration_dataloader``. 
Please refer to\nthe documentation for more details.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "trt_autocast_mod = torch_tensorrt.compile(\n ep.module(),\n arg_inputs=inputs,\n min_block_size=1,\n use_python_runtime=True,\n use_explicit_typing=True,\n enable_autocast=True,\n autocast_low_precision_type=torch.bfloat16,\n autocast_excluded_nodes={\"^conv1$\", \"relu\"},\n autocast_excluded_ops={\"torch.ops.aten.flatten.using_ints\"},\n autocast_max_output_threshold=512,\n autocast_max_depth_of_reduction=None,\n autocast_calibration_dataloader=calibration_dataloader,\n)\n\nautocast_outs = trt_autocast_mod(*inputs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We verify that both the dtypes and the values of the model outputs are correct.\nAs expected, ``fc1`` is in FP16 because of PyTorch Autocast;\n``pool1``, ``conv2``, and ``pool2`` are in BF16 because of Torch-TensorRT Autocast;\nthe rest remain in FP32. 
Note that ``log`` is in FP32 because of PyTorch Autocast requirements.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "should_be_fp32 = [\n autocast_outs[0],\n autocast_outs[1],\n autocast_outs[2],\n autocast_outs[5],\n autocast_outs[7],\n autocast_outs[9],\n]\nshould_be_fp16 = [\n autocast_outs[8],\n]\nshould_be_bf16 = [autocast_outs[3], autocast_outs[4], autocast_outs[6]]\n\nassert all(\n a.dtype == torch.float32 for a in should_be_fp32\n), \"Some Autocast outputs are not float32!\"\nassert all(\n a.dtype == torch.float16 for a in should_be_fp16\n), \"Some Autocast outputs are not float16!\"\nassert all(\n a.dtype == torch.bfloat16 for a in should_be_bf16\n), \"Some Autocast outputs are not bfloat16!\"\nfor i, (a, w) in enumerate(zip(autocast_outs, pytorch_outs)):\n assert torch.allclose(\n a.to(torch.float32), w.to(torch.float32), atol=1e-2, rtol=1e-2\n ), f\"Autocast and PyTorch outputs do not match! autocast_outs[{i}] = {a}, pytorch_outs[{i}] = {w}\"\nprint(\"All dtypes and values match!\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 0 }