{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# Hierarchical Partitioner Example\n\nBasic example on how to use the hierarchical adjacency partitioner function and manually compile the partitioned model.\nNot yet available in the compile API.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from typing import Any, Callable\n\nimport torch\nimport torch.nn as nn\nimport torch_tensorrt\nfrom torch_tensorrt._enums import dtype\nfrom torch_tensorrt.dynamo import partitioning\nfrom torch_tensorrt.dynamo._compiler import convert_module\nfrom torch_tensorrt.dynamo.conversion._ConverterRegistry import (\n DYNAMO_CONVERTERS as CONVERTERS,\n)\nfrom torch_tensorrt.dynamo.lowering import (\n get_decompositions,\n pre_export_lowering,\n)\nfrom torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import (\n hierarchical_adjacency_partition,\n)\nfrom torch_tensorrt.dynamo.utils import (\n get_output_metadata,\n)\nfrom torchvision import models\n\n\nclass InductorModule(torch.nn.Module): # type: ignore[misc]\n \"\"\"Wrapper module for inductor compiled function.\"\"\"\n\n def __init__(self, func: Callable[..., Any]) -> None:\n super().__init__()\n self.func = func\n\n def forward(self, *args: Any, **kwargs: Any) -> Any:\n return self.func(*args, **kwargs)\n\n\nclass SimpleModel(nn.Module):\n def __init__(self):\n super().__init__()\n self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)\n self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)\n self.bn1 = nn.BatchNorm2d(64)\n self.bn2 = nn.BatchNorm2d(128)\n\n def forward(self, x):\n x = self.conv1(x)\n x = self.bn1(x)\n x = torch.relu(x)\n x = self.conv2(x)\n x = self.bn2(x)\n x = torch.relu(x)\n return x\n\n\ndef main():\n # Create model\n model = SimpleModel().cuda()\n # model = models.efficientnet_b0(pretrained=True).cuda()\n model = model.eval()\n\n # Create example input\n example_input = torch.randn(1, 3, 224, 224).cuda()\n\n exported_program = torch.export.export(model, (example_input,))\n exported_program = pre_export_lowering(exported_program)\n exported_program = exported_program.run_decompositions(get_decompositions())\n\n gm = exported_program.module()\n\n print(\"Original Model Structure:\\n\", gm)\n\n with torch.no_grad():\n original_output = model(example_input)\n\n # 1. Partition the model into blocks that can be executed by different backends\n partitioned_model, op_support = hierarchical_adjacency_partition(\n gm,\n min_block_size=1,\n backend_priority=[\"inductor\", \"tensorrt\"],\n backend_support_map={\n \"inductor\": {\n \"torch.ops.aten.convolution.default\",\n },\n \"tensorrt\": CONVERTERS.keys(),\n },\n torch_executed_ops={\n \"torch.ops.aten._native_batch_norm_legit_no_training.default\"\n },\n require_full_compilation=False,\n skip_fusion=True,\n )\n\n print(\"1. Partitioned Model Structure:\\n\", partitioned_model)\n\n # 2. 
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from typing import Any, Callable\n\nimport torch\nimport torch.nn as nn\nimport torch_tensorrt\nfrom torch_tensorrt._enums import dtype\nfrom torch_tensorrt.dynamo import partitioning\nfrom torch_tensorrt.dynamo._compiler import convert_module\nfrom torch_tensorrt.dynamo.conversion._ConverterRegistry import (\n    DYNAMO_CONVERTERS as CONVERTERS,\n)\nfrom torch_tensorrt.dynamo.lowering import (\n    get_decompositions,\n    pre_export_lowering,\n)\nfrom torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import (\n    hierarchical_adjacency_partition,\n)\nfrom torch_tensorrt.dynamo.utils import (\n    get_output_metadata,\n)\nfrom torchvision import models\n\n\nclass InductorModule(torch.nn.Module):  # type: ignore[misc]\n    \"\"\"Wrapper module for inductor compiled function.\"\"\"\n\n    def __init__(self, func: Callable[..., Any]) -> None:\n        super().__init__()\n        self.func = func\n\n    def forward(self, *args: Any, **kwargs: Any) -> Any:\n        return self.func(*args, **kwargs)\n\n\nclass SimpleModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)\n        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)\n        self.bn1 = nn.BatchNorm2d(64)\n        self.bn2 = nn.BatchNorm2d(128)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = torch.relu(x)\n        x = self.conv2(x)\n        x = self.bn2(x)\n        x = torch.relu(x)\n        return x\n\n\ndef main():\n    # Create model\n    model = SimpleModel().cuda()\n    # model = models.efficientnet_b0(pretrained=True).cuda()\n    model = model.eval()\n\n    # Create example input\n    example_input = torch.randn(1, 3, 224, 224).cuda()\n\n    exported_program = torch.export.export(model, (example_input,))\n    exported_program = pre_export_lowering(exported_program)\n    exported_program = exported_program.run_decompositions(get_decompositions())\n\n    gm = exported_program.module()\n\n    print(\"Original Model Structure:\\n\", gm)\n\n    with torch.no_grad():\n        original_output = model(example_input)\n\n    # 1. Partition the model into blocks that can be executed by different backends\n    partitioned_model, op_support = hierarchical_adjacency_partition(\n        gm,\n        min_block_size=1,\n        backend_priority=[\"inductor\", \"tensorrt\"],\n        backend_support_map={\n            \"inductor\": {\n                \"torch.ops.aten.convolution.default\",\n            },\n            \"tensorrt\": CONVERTERS.keys(),\n        },\n        torch_executed_ops={\n            \"torch.ops.aten._native_batch_norm_legit_no_training.default\"\n        },\n        require_full_compilation=False,\n        skip_fusion=True,\n    )\n\n    print(\"1. Partitioned Model Structure:\\n\", partitioned_model)\n\n    # 2. Compile each submodule with the corresponding backend\n    submodule_node_dict = {}\n    for node in partitioned_model.graph.nodes:\n        if \"_run_on_acc\" not in node.name:\n            continue\n        submodule_node_dict[node.name] = node\n\n    # Store compiled replicas of Torch subgraphs\n    compiled_modules = {}\n\n    for name, _ in partitioned_model.named_children():\n        submodule = getattr(partitioned_model, name)\n        if not isinstance(submodule, torch.fx.graph_module.GraphModule):\n            continue\n\n        if \"_run_on_acc\" not in name:\n            submodule.to(\"cuda\")\n            continue\n\n        if name not in submodule_node_dict:\n            raise ValueError(\n                f\"node_name: {name} does not exist in the submodule node dictionary\"\n            )\n\n        # set the submodule metadata back to the parent module_node\n        metadata_list = get_output_metadata(submodule)\n        assert len(metadata_list) > 0\n        metadata_keys = [\"val\", \"tensor_meta\"]\n        for key in metadata_keys:\n            if key not in submodule_node_dict[name].meta:\n                meta_val_list = [\n                    metadata[key] for metadata in metadata_list if key in metadata\n                ]\n                submodule_node_dict[name].meta[key] = meta_val_list\n                break\n\n        # Get the submodule inputs for min, opt, max shapes of the graph inputs\n        submodule_inputs = partitioning.construct_submodule_inputs(submodule)\n        assert submodule_inputs is not None\n\n        # compile submodule with pytorch inductor backend\n        if \"_run_on_acc_inductor\" in name:\n            sub_inputs = []\n            for input in submodule_inputs:\n                sub_input = input.torch_tensor.to(\n                    dtype.to(input.dtype, t=torch.dtype)\n                ).cuda()\n                sub_inputs.append(sub_input)\n\n            compiled_func = torch._inductor.compile(\n                submodule,\n                sub_inputs,\n            )\n            # Wrap the compiled function to be a torch.nn.Module\n            compiled_submodule = InductorModule(compiled_func)\n\n        # compile submodule with tensorrt backend\n        elif \"_run_on_acc_tensorrt\" in name:\n            compiled_submodule = convert_module(\n                submodule,\n                submodule_inputs,\n                name=name,\n            )\n        else:\n            raise ValueError(f\"Unknown backend for submodule: {name}\")\n\n        compiled_modules[name] = compiled_submodule\n\n    # Replace all FX Modules with compiled Modules\n    for name, compiled_module in compiled_modules.items():\n        setattr(partitioned_model, name, compiled_module)\n\n    print(\"2. Compiled Model Structure:\\n\", partitioned_model)\n\n    with torch.no_grad():\n        partitioned_output = partitioned_model(example_input)\n        print(\n            \"3. Verify that Partitioned output == Original output:\",\n            torch.allclose(partitioned_output, original_output, 1e-2, 1e-2),\n        )\n\n\nif __name__ == \"__main__\":\n    main()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 0 }