{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# Weight Streaming\n\nWeight streaming in TensorRT is a powerful feature designed to overcome GPU memory limitations\nwhen working with large models. It enables running models larger than available GPU memory\nby streaming weight data from host (CPU) memory to GPU memory during inference.\n\nStreaming larger amounts of memory will likely result in lower performance. However, if\nstreaming weights allows the user to run larger batch sizes, it can lead to higher throughput.\nThis increased throughput can sometimes outweigh the slowdown caused by streaming weights.\nThe optimal amount of memory to stream varies depending on the specific model and hardware.\nExperimenting with different memory limits can help find the best balance between streaming\noverhead and batch size benefits.\n\nThis example uses a pre-trained Llama-2 model and shows how to use the weight streaming feature with\nTorch-TensorRT:\n\n1. Compile option - build the TensorRT engine with the weight streaming feature\n2. Runtime API - control the weight streaming budget with a context manager\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports and Model Definition\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import copy\nimport timeit\n\nimport numpy as np\nimport torch\nimport torch_tensorrt\nfrom transformers import AutoModelForCausalLM\n\n\ndef export_llm(model, inputs, min_seq_len=1, max_seq_len=16):\n    \"\"\"\n    Exports the LLM model into an ExportedProgram with dynamic shapes.\n    In the case of guard failures caused by some PyTorch kernel implementations, we also\n    try to re-export the graph by expressing the guards as runtime assert nodes.\n    \"\"\"\n    with torch.no_grad():\n        # max=1024 has constraint violation error. 
https://github.com/pytorch/pytorch/issues/125604\n        seq_len = torch.export.Dim(\"seq_len\", min=min_seq_len, max=max_seq_len)\n        position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device)\n        try:\n            print(\"Trying to export the model using torch.export.export()..\")\n            # strict=False only enables aotautograd tracing and excludes dynamo.\n            ep = torch.export.export(\n                model,\n                args=(inputs,),\n                kwargs={\"position_ids\": position_ids},\n                dynamic_shapes=({1: seq_len}, {1: seq_len}),\n                strict=False,\n            )\n        except Exception:\n            print(\n                \"Trying torch.export._trace._export to trace the graph since torch.export.export() failed\"\n            )\n            # This API is used to express the constraint violation guards as asserts in the graph.\n            ep = torch.export._trace._export(\n                model,\n                args=(inputs,),\n                kwargs={\"position_ids\": position_ids},\n                dynamic_shapes=({1: seq_len}, {1: seq_len}),\n                strict=False,\n                prefer_deferred_runtime_asserts_over_guards=True,\n            )\n\n    return ep\n\n\ndef time_generate(model, inputs, output_seq_length, iterations=10):\n    \"\"\"\n    Measure the mean time to generate a sequence over a given number of iterations\n    \"\"\"\n    # We only support single input (B x seq_len) for LLMs now\n    input_seq = inputs[0]\n    with torch.no_grad():\n        timings = []\n        for _ in range(iterations):\n            start_time = timeit.default_timer()\n            inputs_copy = copy.copy(input_seq)\n            # Greedy decoding of the model. 
This generates tokens until the sequence length exceeds output_seq_length.\n            while inputs_copy.shape[1] <= output_seq_length:\n                outputs = model(inputs_copy)\n                logits = outputs.logits\n                next_token_logits = logits[:, -1, :]\n                next_tokens = torch.argmax(next_token_logits, dim=-1)\n                inputs_copy = torch.cat([inputs_copy, next_tokens[:, None]], dim=-1)\n            torch.cuda.synchronize()\n            end_time = timeit.default_timer()\n            timings.append(end_time - start_time)\n\n    times = np.array(timings)\n    time_mean_ms = np.mean(times) * 1000\n\n    return time_mean_ms\n\n\n# Load the LLaMA-2 model\nDEVICE = torch.device(\"cuda:0\")\nllama_path = \"meta-llama/Llama-2-7b-chat-hf\"\nwith torch.no_grad():\n    model = AutoModelForCausalLM.from_pretrained(\n        llama_path, use_cache=False, attn_implementation=\"eager\"\n    ).eval()\n\n# Set input and output sequence lengths\nisl = 128\nosl = 256\n\n# Create random input tensors\ninput_tensors = [torch.randint(0, 5, (1, isl), dtype=torch.int64).cuda()]\n# Convert the model to half precision (FP16)\nmodel = model.half()\n# Exports the LLM model into an ExportedProgram with dynamic shapes.\nllama2_ep = export_llm(model, input_tensors[0], max_seq_len=osl)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compiler option\n\nThe enable_weight_streaming=True and use_explicit_typing=True options are both required to build\nan engine with the weight streaming feature. 
The use_explicit_typing=True option creates a\n[strongly typed network](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#strongly-typed-networks), and only float32 precision is allowed in the enabled_precisions option.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Create a TensorRT-compiled model\ntrt_model = torch_tensorrt.dynamo.compile(\n    llama2_ep,\n    inputs=input_tensors,\n    enabled_precisions={torch.float32},\n    truncate_double=True,\n    device=DEVICE,\n    use_explicit_typing=True,\n    enable_weight_streaming=True,\n)\n\n# Warm up for 3 iterations\n_ = time_generate(trt_model, input_tensors, osl, 3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Running with automatic budget size\n\nOnce you specify the enable_weight_streaming compile option, an automatic budget size is configured.\nThis automatic size may not always provide the optimal solution because the automatically determined\nbudget lacks insight into the user's specific memory constraints and usage patterns.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Weight streaming context to get current weight budget information\nweight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(trt_model)\n# Measure the mean latency of the model with weight streaming\nmean_latency = time_generate(trt_model, input_tensors, osl, 1)\n# Calculate the percentage of current weight budget used\nweight_budget_pct = (\n    weight_streaming_ctx.device_budget / weight_streaming_ctx.total_device_budget * 100\n)\nprint(\n    f\"Set weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. 
mean latency = {mean_latency} ms\"\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Running with weight streaming context manager\n\nThe weight streaming budget can be limited by using the weight streaming context manager.\nThe permissible range for the budget size is from 0 to ctx.total_device_budget.\nA budget of 0 gives the maximum memory savings by using the minimum amount of GPU memory. A value\nequal to ctx.total_device_budget disables weight streaming.\nIf multiple TensorRT engines are created, the budget is distributed among them proportionally.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Use a context manager for weight streaming\nwith torch_tensorrt.runtime.weight_streaming(trt_model) as weight_streaming_ctx:\n    # Get the total size of streamable weights in the engine\n    streamable_budget = weight_streaming_ctx.total_device_budget\n\n    # Scenario 1: Automatic weight streaming budget\n    # Get the automatically determined weight streaming budget\n    requested_budget = weight_streaming_ctx.get_automatic_weight_streaming_budget()\n    # Set the device budget to the automatically determined value\n    weight_streaming_ctx.device_budget = requested_budget\n    # Measure the mean latency with automatic budget\n    mean_latency = time_generate(trt_model, input_tensors, osl, 1)\n    # Calculate the percentage of the weight budget used\n    weight_budget_pct = (\n        weight_streaming_ctx.device_budget\n        / weight_streaming_ctx.total_device_budget\n        * 100\n    )\n    print(\n        f\"Set auto weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. 
mean latency = {mean_latency} ms\"\n    )\n\n    # Scenario 2: Manual 10% weight streaming budget\n    # Set the budget to 10% of the total streamable weights\n    requested_budget = int(streamable_budget * 0.1)\n    weight_streaming_ctx.device_budget = requested_budget\n    # Measure the mean latency with 10% budget\n    mean_latency = time_generate(trt_model, input_tensors, osl, 1)\n    # Calculate the percentage of the weight budget used\n    weight_budget_pct = (\n        weight_streaming_ctx.device_budget\n        / weight_streaming_ctx.total_device_budget\n        * 100\n    )\n    print(\n        f\"Set weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms\"\n    )" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 0 }