{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# Using Custom Kernels within TensorRT Engines with Torch-TensorRT\n\nWe are going to demonstrate how a developer could include a custom kernel in a TensorRT engine using Torch-TensorRT.\n\nTorch-TensorRT supports falling back to PyTorch implementations of operations in the case that Torch-TensorRT\ndoes not know how to compile them in TensorRT. However, this comes at the cost of a graph break and will reduce the performance of the model.\nThe easiest way to fix a lack of support for an op is to add either a decomposition (see:\n[Writing lowering passes for the Dynamo frontend](https://pytorch.org/TensorRT/contributors/writing_dynamo_aten_lowering_passes.html)), which defines the operator\nin terms of PyTorch ops that are supported in Torch-TensorRT, or a converter (see:\n[Writing converters for the Dynamo frontend](https://pytorch.org/TensorRT/contributors/dynamo_converters.html)), which defines the operator in terms of TensorRT operators.\n\nIn some cases there isn't a great way to do either of these, perhaps because the operator is a custom kernel that is not part of standard PyTorch or\nbecause TensorRT cannot support it natively.\n\nFor these cases, it is possible to use a TensorRT plugin to replace the operator **inside** the TensorRT engine, thereby avoiding\nthe performance and resource overhead from a graph break.\nFor the sake of demonstration, consider the operation circular padding. Circular padding is useful for ops like circular convolution in deep learning.\nThe following image denotes how the original image (red) is circular padded once (green) and twice (blue):\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Writing Custom Operators in PyTorch\n\nAssume for whatever reason we would like to use a custom implementation of circular padding. 
In this case, it is implemented using a kernel written in [OpenAI Triton](https://openai.com/index/triton).\n\nWhen using custom kernels with PyTorch, it is recommended to take the additional step of registering them as formal operators in PyTorch. This will both make it easier to handle\nthe operation in Torch-TensorRT and simplify its use in PyTorch. This can be done either as part of a C++ library or in Python (see [Custom ops in C++](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html) and [Python custom ops](https://pytorch.org/docs/stable/library.html) for more details).\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from typing import Any, Sequence\n\nimport numpy as np\nimport torch\nimport triton\nimport triton.language as tl\nfrom torch.library import custom_op\n\n\n# Defining the kernel to be run on the GPU\n# Note: the index math below assumes 4D tensors\n@triton.jit  # type: ignore\ndef circ_pad_kernel(\n    X: torch.Tensor,\n    all_pads_0: tl.int32,\n    all_pads_2: tl.int32,\n    all_pads_4: tl.int32,\n    all_pads_6: tl.int32,\n    orig_dims_0: tl.int32,\n    orig_dims_1: tl.int32,\n    orig_dims_2: tl.int32,\n    orig_dims_3: tl.int32,\n    Y: torch.Tensor,\n    Y_shape_1: tl.int32,\n    Y_shape_2: tl.int32,\n    Y_shape_3: tl.int32,\n    X_len: tl.int32,\n    Y_len: tl.int32,\n    BLOCK_SIZE: tl.constexpr,\n) -> None:\n    # Flat output indices handled by this program instance\n    pid = tl.program_id(0)\n    i = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n\n    mask_y = i < Y_len\n\n    # Unravel the flat output index into 4D output coordinates\n    i3 = i % Y_shape_3\n    i2 = (i // Y_shape_3) % Y_shape_2\n    i1 = (i // Y_shape_3 // Y_shape_2) % Y_shape_1\n    i0 = i // Y_shape_3 // Y_shape_2 // Y_shape_1\n\n    # Map each output coordinate back to its wrapped input coordinate\n    j0 = (i0 - all_pads_0 + orig_dims_0) % orig_dims_0\n    j1 = (i1 - all_pads_2 + orig_dims_1) % orig_dims_1\n    j2 = (i2 - all_pads_4 + orig_dims_2) % orig_dims_2\n    j3 = (i3 - all_pads_6 + orig_dims_3) % orig_dims_3\n\n    # Flat index into the input tensor\n    load_idx = (\n        orig_dims_3 * orig_dims_2 * orig_dims_1 * j0\n        + orig_dims_3 * orig_dims_2 * j1\n        + orig_dims_3 * j2\n        + j3\n    )\n    mask_x = load_idx < X_len\n\n    x = tl.load(X + 
load_idx, mask=mask_x)\n\n    tl.store(Y + i, x, mask=mask_y)\n\n\n# The launch code wrapped to expose it as a custom operator in our namespace\n@custom_op(\"torchtrt_ex::triton_circular_pad\", mutates_args=())  # type: ignore[misc]\ndef triton_circular_pad(x: torch.Tensor, padding: Sequence[int]) -> torch.Tensor:\n    # Start from the input shape so that unpadded dimensions keep their size\n    out_dims = np.array(x.shape, dtype=np.int32)\n    for i in range(np.size(padding) // 2):\n        out_dims[len(out_dims) - i - 1] = (\n            x.shape[len(out_dims) - i - 1] + padding[i * 2] + padding[i * 2 + 1]\n        )\n\n    # Allocate the output on the same device and with the same dtype as the input\n    y = torch.empty(tuple(out_dims.tolist()), device=x.device, dtype=x.dtype)\n\n    N = len(x.shape)\n    all_pads = np.zeros((N * 2,), dtype=np.int32)\n    orig_dims = np.array(x.shape, dtype=np.int32)\n    out_dims = np.array(x.shape, dtype=np.int32)\n\n    for i in range(len(padding) // 2):\n        out_dims[N - i - 1] += padding[i * 2] + padding[i * 2 + 1]\n        all_pads[N * 2 - 2 * i - 2] = padding[i * 2]\n        all_pads[N * 2 - 2 * i - 1] = padding[i * 2 + 1]\n\n    blockSize = 256\n    numBlocks = (int((np.prod(out_dims) + blockSize - 1) // blockSize),)\n\n    circ_pad_kernel[numBlocks](\n        x,\n        all_pads[0],\n        all_pads[2],\n        all_pads[4],\n        all_pads[6],\n        orig_dims[0],\n        orig_dims[1],\n        orig_dims[2],\n        orig_dims[3],\n        y,\n        out_dims[1],\n        out_dims[2],\n        out_dims[3],\n        int(np.prod(orig_dims)),\n        int(np.prod(out_dims)),\n        BLOCK_SIZE=256,\n    )\n\n    return y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Above is all that is required to create a custom operator for PyTorch. 
We can now call it directly as ``torch.ops.torchtrt_ex.triton_circular_pad``\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Testing our custom op\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The native PyTorch implementation\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ex_input = torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3).to(\"cuda\")\npadding = (1, 1, 2, 0)\ntorch.nn.functional.pad(ex_input, padding, \"circular\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```none\ntensor([[[[5., 3., 4., 5., 3.],\n [8., 6., 7., 8., 6.],\n [2., 0., 1., 2., 0.],\n [5., 3., 4., 5., 3.],\n [8., 6., 7., 8., 6.]]]], device='cuda:0')\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our custom implementation\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.ops.torchtrt_ex.triton_circular_pad(ex_input, padding)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```none\ntensor([[[[5., 3., 4., 5., 3.],\n [8., 6., 7., 8., 6.],\n [2., 0., 1., 2., 0.],\n [5., 3., 4., 5., 3.],\n [8., 6., 7., 8., 6.]]]], device='cuda:0')\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have defined the minimum to start using our custom op in PyTorch, but to take the extra step of making this operator traceable by Dynamo (a prerequisite for being supported in Torch-TensorRT),\nwe need to define a \"Fake Tensor\" implementation of the op. This function defines the effect that our kernel would have on input tensors in terms of native PyTorch ops.\nIt allows Dynamo to calculate tensor properties like sizes, stride, device etc. 
without needing to use real data (more information [here](https://pytorch.org/docs/main/library.html#torch.library.register_fake)).\nIn our case we can just use the native circular pad operation as our FakeTensor implementation.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch.library.register_fake(\"torchtrt_ex::triton_circular_pad\")  # type: ignore[misc]\ndef _(x: torch.Tensor, padding: Sequence[int]) -> torch.Tensor:\n    return torch.nn.functional.pad(x, padding, \"circular\")\n\n\n# Additionally one may want to define an autograd implementation for the backwards pass to round out the custom op implementation but that is beyond the scope of this tutorial (see https://pytorch.org/docs/main/library.html#torch.library.register_autograd for more)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using the Custom Operator in a Model\nWe can now create models using our custom op. Here is a small example that uses both natively supported operators (Convolution) and our custom op.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from typing import Sequence\n\nfrom torch import nn\n\n\nclass MyModel(nn.Module):  # type: ignore[misc]\n    def __init__(self, padding: Sequence[int]):\n        super().__init__()\n\n        self.padding = padding\n        self.conv = nn.Conv2d(1, 5, kernel_size=3)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        padded_x = torch.ops.torchtrt_ex.triton_circular_pad(x, self.padding)\n        y = self.conv(padded_x)\n\n        return y\n\n\nmy_model = MyModel((1, 1, 2, 0)).to(\"cuda\").eval()\nwith torch.no_grad():\n    my_model(ex_input)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```none\ntensor([[[[-0.2604, -0.4232, -0.3041],\n [-3.0833, -3.2461, -3.1270],\n [-0.2450, -0.4079, -0.2887]],\n\n [[ 0.2828, -0.0373, 1.0332],\n [-2.3143, -2.6344, -1.5638],\n [-1.1867, -1.5068, -0.4363]],\n\n [[ 
1.7937, 1.3488, 2.1350],\n [ 0.7966, 0.3517, 1.1379],\n [ 3.5537, 3.1088, 3.8950]],\n\n [[-1.0550, -0.6163, -1.0109],\n [ 0.5245, 0.9632, 0.5686],\n [ 0.3775, 0.8162, 0.4216]],\n\n [[-0.4311, -0.1649, -1.2091],\n [-4.3668, -4.1006, -5.1447],\n [-5.0352, -4.7689, -5.8131]]]], device='cuda:0')\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we try to compile this model with Torch-TensorRT, we can see that (as of Torch-TensorRT 2.4.0) a number of subgraphs are created to run the custom op in PyTorch and the convolution in TensorRT\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch_tensorrt as torchtrt\n\ntorchtrt.compile(\n my_model,\n inputs=[ex_input],\n dryrun=True, # Check the support of the model without having to build the engines\n min_block_size=1,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```none\nGraphModule(\n (_run_on_gpu_0): GraphModule()\n (_run_on_acc_1): GraphModule(\n (conv): Module()\n )\n)\n\n++++++++++++++ Dry-Run Results for Graph +++++++++++++++++\n\nThe graph consists of 2 Total Operators, of which 1 operators are supported, 50.0% coverage\n\nThe following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph:\n torch.ops.torchtrt_ex.triton_circular_pad.default: 1\n\nThe following nodes are currently set to run in Torch:\nNode: torch.ops.torchtrt_ex.triton_circular_pad.default, with layer location: __/triton_circular_pad\nNote: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner\n\nCompiled with: CompilationSettings(enabled_precisions={}, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, 
enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=True, hardware_compatible=False)\n\n Graph Structure:\n\n Inputs: List[Tensor: (1, 1, 3, 3)@float32]\n ...\n TRT Engine #1 - Submodule name: _run_on_acc_1\n Engine Inputs: List[Tensor: (1, 1, 5, 5)@float32]\n Number of Operators in Engine: 1\n Engine Outputs: Tensor: (1, 5, 3, 3)@float32\n ...\n Outputs: List[Tensor: (1, 5, 3, 3)@float32]\n\n --------- Aggregate Stats ---------\n\n Average Number of Operators per TRT Engine: 1.0\n Most Operators in a TRT Engine: 1\n\n ********** Recommendations **********\n\n - For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)\n - The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)\n```\nWe see that there are going to be two subgraphs: one that will run through PyTorch for our custom op and one that will run through TensorRT for the convolution. This graph break is going to account for a significant portion of the latency of this model.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Wrapping Custom Kernels to use in TensorRT\n\nTo address this graph break, the first step is to make our kernel implementation available in TensorRT. Again this can be done in either C++ or Python. For the actual details on how to implement\nTensorRT plugins, refer to the samples [here](https://github.com/NVIDIA/TensorRT/tree/release/10.0/samples/python/python_plugin). 
From a high level, similar to PyTorch, you will need to\ndefine systems to handle setting up the operator, calculating the effect of the operation abstractly, serializing the op, and the actual mechanics of calling the implementation of the op in the engine.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import pickle as pkl\nfrom typing import Any, List, Optional, Self\n\nimport cupy as cp  # Needed to work around API gaps in PyTorch to build torch.Tensors around preallocated CUDA memory\nimport numpy as np\nimport tensorrt as trt\n\n\nclass CircularPaddingPlugin(trt.IPluginV2DynamicExt):  # type: ignore[misc]\n    def __init__(\n        self, field_collection: Optional[List[trt.PluginFieldCollection]] = None\n    ):\n        super().__init__()\n        self.pads = []\n        self.X_shape: List[int] = []\n\n        self.num_outputs = 1\n        self.plugin_namespace = \"\"\n        self.plugin_type = \"CircularPaddingPlugin\"\n        self.plugin_version = \"1\"\n\n        if field_collection is not None:\n            assert field_collection[0].name == \"pads\"\n            self.pads = field_collection[0].data\n\n    def get_output_datatype(\n        self, index: int, input_types: List[trt.DataType]\n    ) -> trt.DataType:\n        return input_types[0]\n\n    def get_output_dimensions(\n        self,\n        output_index: int,\n        inputs: List[trt.DimsExprs],\n        exprBuilder: trt.IExprBuilder,\n    ) -> trt.DimsExprs:\n        output_dims = trt.DimsExprs(inputs[0])\n\n        for i in range(np.size(self.pads) // 2):\n            output_dims[len(output_dims) - i - 1] = exprBuilder.operation(\n                trt.DimensionOperation.SUM,\n                inputs[0][len(output_dims) - i - 1],\n                exprBuilder.constant(self.pads[i * 2] + self.pads[i * 2 + 1]),\n            )\n\n        return output_dims\n\n    def configure_plugin(\n        self,\n        inp: List[trt.DynamicPluginTensorDesc],\n        out: List[trt.DynamicPluginTensorDesc],\n    ) -> None:\n        X_dims = inp[0].desc.dims\n        self.X_shape = np.zeros((len(X_dims),))\n        for i in range(len(X_dims)):\n            self.X_shape[i] = X_dims[i]\n\n    def serialize(self) -> bytes:\n
       return pkl.dumps({\"pads\": self.pads})\n\n    def supports_format_combination(\n        self, pos: int, in_out: List[trt.PluginTensorDesc], num_inputs: int\n    ) -> bool:\n        assert num_inputs == 1\n        assert pos < len(in_out)\n\n        desc = in_out[pos]\n        if desc.format != trt.TensorFormat.LINEAR:\n            return False\n\n        # first input should be float16 or float32\n        if pos == 0:\n            return bool(\n                desc.type == trt.DataType.FLOAT or desc.type == trt.DataType.HALF\n            )\n\n        # output should have the same type as the input\n        if pos == 1:\n            return bool(in_out[0].type == desc.type)\n\n        return False\n\n    def enqueue(\n        self,\n        input_desc: List[trt.PluginTensorDesc],\n        output_desc: List[trt.PluginTensorDesc],\n        inputs: List[int],\n        outputs: List[int],\n        workspace: int,\n        stream: int,\n    ) -> None:\n        # Host code is slightly different as this will be run as part of the TRT execution\n        in_dtype = torchtrt.dtype.try_from(input_desc[0].type).to(np.dtype)\n\n        a_mem = cp.cuda.UnownedMemory(\n            inputs[0], np.prod(input_desc[0].dims) * cp.dtype(in_dtype).itemsize, self\n        )\n        c_mem = cp.cuda.UnownedMemory(\n            outputs[0],\n            np.prod(output_desc[0].dims) * cp.dtype(in_dtype).itemsize,\n            self,\n        )\n\n        a_ptr = cp.cuda.MemoryPointer(a_mem, 0)\n        c_ptr = cp.cuda.MemoryPointer(c_mem, 0)\n\n        a_d = cp.ndarray((np.prod(input_desc[0].dims)), dtype=in_dtype, memptr=a_ptr)\n        c_d = cp.ndarray((np.prod(output_desc[0].dims)), dtype=in_dtype, memptr=c_ptr)\n\n        a_t = torch.as_tensor(a_d, device=\"cuda\")\n        c_t = torch.as_tensor(c_d, device=\"cuda\")\n\n        N = len(self.X_shape)\n        all_pads = np.zeros((N * 2,), dtype=np.int32)\n        orig_dims = np.array(self.X_shape, dtype=np.int32)\n        out_dims = np.array(self.X_shape, dtype=np.int32)\n\n        for i in range(np.size(self.pads) // 2):\n            out_dims[N - i - 1] += self.pads[i * 2] + self.pads[i * 2 + 1]\n            all_pads[N * 2 - 2 * i - 2] = self.pads[i * 2]\n            all_pads[N * 2 - 2 * i - 1] = self.pads[i * 2 + 1]\n\n        all_pads = all_pads.tolist()\n        orig_dims = orig_dims.tolist()\n        out_dims = 
out_dims.tolist()\n\n        blockSize = 256\n        numBlocks = (int((np.prod(out_dims) + blockSize - 1) // blockSize),)\n\n        # Call the same kernel implementation we use in PyTorch\n        circ_pad_kernel[numBlocks](\n            a_t,\n            all_pads[0],\n            all_pads[2],\n            all_pads[4],\n            all_pads[6],\n            orig_dims[0],\n            orig_dims[1],\n            orig_dims[2],\n            orig_dims[3],\n            c_t,\n            out_dims[1],\n            out_dims[2],\n            out_dims[3],\n            int(np.prod(orig_dims)),\n            int(np.prod(out_dims)),\n            BLOCK_SIZE=256,\n        )\n\n    def clone(self) -> Self:\n        cloned_plugin = CircularPaddingPlugin()\n        cloned_plugin.__dict__.update(self.__dict__)\n        return cloned_plugin\n\n\nclass CircularPaddingPluginCreator(trt.IPluginCreator):  # type: ignore[misc]\n    def __init__(self):\n        super().__init__()\n\n        self.name = \"CircularPaddingPlugin\"\n        self.plugin_namespace = \"\"\n        self.plugin_version = \"1\"\n        self.field_names = trt.PluginFieldCollection(\n            [trt.PluginField(\"pads\", np.array([]), trt.PluginFieldType.INT32)]\n        )\n\n    def create_plugin(\n        self, name: str, field_collection: trt.PluginFieldCollection_\n    ) -> CircularPaddingPlugin:\n        return CircularPaddingPlugin(field_collection)\n\n    def deserialize_plugin(self, name: str, data: bytes) -> CircularPaddingPlugin:\n        pads_dict = pkl.loads(data)\n        print(pads_dict)\n        deserialized = CircularPaddingPlugin()\n        deserialized.__dict__.update(pads_dict)\n        print(deserialized.pads)\n        return deserialized\n\n\n# Register the plugin creator in the TensorRT Plugin Registry\nTRT_PLUGIN_REGISTRY = trt.get_plugin_registry()\nTRT_PLUGIN_REGISTRY.register_creator(CircularPaddingPluginCreator(), \"\")  # type: ignore[no-untyped-call]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using Torch-TensorRT to Insert the Kernel\nNow with our TensorRT plugin, we can create a converter so that Torch-TensorRT knows to insert our plugin in place of our custom circular padding operator.\nMore information on writing converters can be found [here](https://pytorch.org/TensorRT/contributors/dynamo_converters.html).\n\n" ] 
}, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from typing import Dict, Tuple\n\nfrom torch.fx.node import Argument, Target\nfrom torch_tensorrt.dynamo.conversion import (\n    ConversionContext,\n    dynamo_tensorrt_converter,\n)\nfrom torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor\n\n\n@dynamo_tensorrt_converter(torch.ops.torchtrt_ex.triton_circular_pad.default)  # type: ignore\n# Recall the schema defined above:\n# torch.ops.torchtrt_ex.triton_circular_pad.default(Tensor x, IntList padding) -> Tensor\ndef circular_padding_converter(\n    ctx: ConversionContext,\n    target: Target,\n    args: Tuple[Argument, ...],\n    kwargs: Dict[str, Argument],\n    name: str,\n):\n    # How to retrieve a plugin if it is defined elsewhere (e.g. linked library)\n    plugin_registry = trt.get_plugin_registry()\n    plugin_creator = plugin_registry.get_plugin_creator(\n        type=\"CircularPaddingPlugin\", version=\"1\", plugin_namespace=\"\"\n    )\n    assert plugin_creator, \"Unable to find CircularPaddingPlugin creator\"\n\n    # Pass configurations to the plugin implementation\n    field_configs = trt.PluginFieldCollection(\n        [\n            trt.PluginField(\n                \"pads\",\n                np.array(\n                    args[1], dtype=np.int32\n                ),  # Arg 1 of `torch.ops.torchtrt_ex.triton_circular_pad` is the int list containing the padding settings. 
Note: the dtype matters as you are eventually passing this as a c-like buffer\n                trt.PluginFieldType.INT32,\n            ),\n        ]\n    )\n\n    plugin = plugin_creator.create_plugin(name=name, field_collection=field_configs)\n    assert plugin, \"Unable to create CircularPaddingPlugin\"\n\n    input_tensor = args[\n        0\n    ]  # Arg 0 of `torch.ops.torchtrt_ex.triton_circular_pad` is the input tensor\n    if not isinstance(input_tensor, trt.ITensor):\n        # Freeze input tensor if not TensorRT Tensor already\n        input_tensor = get_trt_tensor(ctx, input_tensor, f\"{name}_input\")\n\n    layer = ctx.net.add_plugin_v2(\n        [input_tensor], plugin\n    )  # Add the plugin to the network being constructed\n    layer.name = f\"circular_padding_plugin-{name}\"\n    return layer.get_output(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we are now able to fully compile our model.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "trt_model = torchtrt.compile(\n    my_model,\n    inputs=[ex_input],\n    min_block_size=1,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```none\nGraphModule(\n (_run_on_acc_0): TorchTensorRTModule()\n)\n\n+++++++++++++++ Dry-Run Results for Graph ++++++++++++++++\n\nThe graph consists of 2 Total Operators, of which 2 operators are supported, 100.0% coverage\n\nCompiled with: CompilationSettings(enabled_precisions={}, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, 
hardware_compatible=False)\n\n Graph Structure:\n\n Inputs: List[Tensor: (1, 1, 3, 3)@float32]\n ...\n TRT Engine #1 - Submodule name: _run_on_acc_0\n Engine Inputs: List[Tensor: (1, 1, 3, 3)@float32]\n Number of Operators in Engine: 2\n Engine Outputs: Tensor: (1, 5, 3, 3)@float32\n ...\n Outputs: List[Tensor: (1, 5, 3, 3)@float32]\n\n ---------- Aggregate Stats -------------\n\n Average Number of Operators per TRT Engine: 2.0\n Most Operators in a TRT Engine: 2\n\n ********** Recommendations **********\n\n - For minimal graph segmentation, select min_block_size=2 which would generate 1 TRT engine(s)\n - The current level of graph segmentation is equivalent to selecting min_block_size=2 which generates 1 TRT engine(s)\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, now there is only one subgraph created for the TensorRT engine that contains both our custom kernel and the native convolution operator.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "with torch.no_grad():\n print(trt_model(ex_input))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```none\ntensor([[[[-0.2604, -0.4232, -0.3041],\n [-3.0833, -3.2461, -3.1270],\n [-0.2450, -0.4079, -0.2887]],\n\n [[ 0.2828, -0.0373, 1.0332],\n [-2.3143, -2.6344, -1.5638],\n [-1.1867, -1.5068, -0.4363]],\n\n [[ 1.7937, 1.3488, 2.1350],\n [ 0.7966, 0.3517, 1.1379],\n [ 3.5537, 3.1088, 3.8950]],\n\n [[-1.0550, -0.6163, -1.0109],\n [ 0.5245, 0.9632, 0.5686],\n [ 0.3775, 0.8162, 0.4216]],\n\n [[-0.4311, -0.1649, -1.2091],\n [-4.3668, -4.1006, -5.1447],\n [-5.0352, -4.7689, -5.8131]]]], device='cuda:0')\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can verify our implementation is run correctly by both TensorRT and PyTorch\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "with torch.no_grad():\n 
print(my_model(ex_input) - trt_model(ex_input))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```none\ntensor([[[[0., 0., 0.],\n [0., 0., 0.],\n [0., 0., 0.]],\n\n [[0., 0., 0.],\n [0., 0., 0.],\n [0., 0., 0.]],\n\n [[0., 0., 0.],\n [0., 0., 0.],\n [0., 0., 0.]],\n\n [[0., 0., 0.],\n [0., 0., 0.],\n [0., 0., 0.]],\n\n [[0., 0., 0.],\n [0., 0., 0.],\n [0., 0., 0.]]]], device='cuda:0', grad_fn=)\n```\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 0 }