Low CPU Memory Compilation Example
This example demonstrates compiling a model with a bounded CPU (host) memory budget using Torch-TensorRT Dynamo. Limiting host RAM use is helpful on memory-constrained machines or when compiling very large models.
Key notes:

- The toy model below has roughly 430 MB of parameters (the sketch after the imports shows how to measure this). We set the CPU memory budget to 2 GiB. At compile time, only about 900 MB of host RAM may remain available, while the model is expected to use at most 430 * 4 = 1720 MB, so it is partitioned into two subgraphs to fit the memory budget.
- Performance impact varies by model. When the number of TensorRT engines created is small, the impact is typically minimal.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
from torch_tensorrt.dynamo.conversion import CompilationSettings
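The "roughly 430 MB" figure in the key notes can be verified by measuring the parameter footprint of the module directly. A minimal sketch (param_megabytes is a hypothetical helper, not part of the original example):

def param_megabytes(mod: nn.Module) -> float:
    # Sum the byte size of every parameter tensor and convert to megabytes.
    return sum(p.numel() * p.element_size() for p in mod.parameters()) / 1e6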
class Net(nn.Module):
def __init__(self):
super().__init__()
# Intentionally large layers to stress host memory during compilation.
self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1)
self.bn1 = nn.BatchNorm2d(4096)
self.conv2 = nn.Conv2d(4096, 1024, 3, padding=1)
self.bn2 = nn.BatchNorm2d(1024)
self.fc1 = nn.Linear(1024 * 56 * 56, 10)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = F.max_pool2d(x, (2, 2))
x = self.conv2(x)
x = self.bn2(x)
x = F.relu(x)
x = F.max_pool2d(x, (2, 2))
x = torch.flatten(x, 1)
return self.fc1(x)
model = Net().eval()
model.to("cuda")
inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
use_python_runtime = False
compilation_options = {
"use_python_runtime": use_python_runtime,
"enabled_precisions": enabled_precisions,
"min_block_size": 1,
"immutable_weights": True,
"reuse_cached_engines": False,
"enable_resource_partitioning": True,
"cpu_memory_budget": 2 * 1024 * 1024 * 1024, # 2 GiB in bytes
}
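The budget above is hard-coded. Alternatively, it can be derived from the host RAM actually free at compile time; a sketch, assuming the third-party psutil package is installed (it is not required by Torch-TensorRT):

import psutil

# Illustrative policy (an assumption, not a Torch-TensorRT default): cap
# compilation at half of the currently available host RAM.
dynamic_budget = int(psutil.virtual_memory().available * 0.5)
# compilation_options["cpu_memory_budget"] = dynamic_budget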
# The same options can be packaged into a CompilationSettings object; the
# compile call below passes them as keyword arguments instead, so this line
# is shown for reference only.
settings = CompilationSettings(**compilation_options)
with torchtrt.dynamo.Debugger(
log_level="debug",
    logging_dir="/tmp/torch_tensorrt/logging",  # any writable directory
engine_builder_monitor=False,
):
exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torchtrt.dynamo.compile(
exp_program,
inputs=inputs,
**compilation_options,
)
# Expect two back-to-back TensorRT engines due to partitioning under the memory budget.
print(trt_gm)
"""
You should see two back-to-back TensorRT engines in the graph:
Graph Structure:
Inputs: List[Tensor: (1, 1024, 224, 224)@float32]
...
TRT Engine #1 - Submodule name: _run_on_acc_0_resource_split_0
Engine Inputs: List[Tensor: (1, 1024, 224, 224)@float32]
Number of Operators in Engine: 9
Engine Outputs: List[Tensor: (1, 1024, 112, 112)@float32]
...
TRT Engine #2 - Submodule name: _run_on_acc_0_resource_split_1
Engine Inputs: List[Tensor: (1, 1024, 112, 112)@float32]
Number of Operators in Engine: 3
Engine Outputs: List[Tensor: (1, 10)@float32]
...
Outputs: List[Tensor: (1, 10)@float32]
------------------------- Aggregate Stats -------------------------
Average Number of Operators per TRT Engine: 6.0
Most Operators in a TRT Engine: 9
********** Recommendations **********
- For minimal graph segmentation, select min_block_size=9 which would generate 1 TRT engine(s)
- For moderate graph segmentation, select min_block_size=6 which would generate 1 TRT engine(s)
- The current level of graph segmentation is equivalent to selecting min_block_size=3 which generates 2 TRT engine(s)
GraphModule(
(_run_on_acc_0_resource_split_0): TorchTensorRTModule()
(_run_on_acc_0_resource_split_1): TorchTensorRTModule()
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_run_on_acc_0_resource_split_0 = self._run_on_acc_0_resource_split_0(x); x = None
_run_on_acc_0_resource_split_1 = self._run_on_acc_0_resource_split_1(_run_on_acc_0_resource_split_0); _run_on_acc_0_resource_split_0 = None
return pytree.tree_unflatten((_run_on_acc_0_resource_split_1,), self._out_spec)
)
"""