{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Engine Caching (BERT)\n",
    "\n",
    "A small engine-caching example on BERT: the model is compiled several times with the\n",
    "`torch_tensorrt` backend of `torch.compile`, first without and then with engine caching,\n",
    "and the compilation time of each round is reported.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch_tensorrt  # not referenced by name below; presumably registers the torch_tensorrt backend - TODO confirm\n",
    "from engine_caching_example import remove_timing_cache\n",
    "from transformers import BertModel\n",
    "\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Where built engines are stored and how much disk space the cache may use.\n",
    "ENGINE_CACHE_DIR = \"/tmp/torch_trt_bert_engine_cache\"\n",
    "ENGINE_CACHE_SIZE = 1 << 30  # 1GB\n",
    "\n",
    "model = BertModel.from_pretrained(\"bert-base-uncased\", return_dict=False).cuda().eval()\n",
    "inputs = [\n",
    "    torch.randint(0, 2, (1, 14), dtype=torch.int32).to(\"cuda\"),\n",
    "    torch.randint(0, 2, (1, 14), dtype=torch.int32).to(\"cuda\"),\n",
    "]\n",
    "\n",
    "\n",
    "def compile_bert(iterations=3):\n",
    "    \"\"\"Compile and run BERT ``iterations`` times, timing every round.\n",
    "\n",
    "    The first round compiles with engine caching disabled; every later round\n",
    "    enables both saving built engines and reusing cached ones.\n",
    "\n",
    "    Args:\n",
    "        iterations: Number of compile-and-run rounds to time.\n",
    "\n",
    "    Returns:\n",
    "        A list with the elapsed time of each round in milliseconds. The same\n",
    "        list is also printed.\n",
    "    \"\"\"\n",
    "    times = []\n",
    "    start = torch.cuda.Event(enable_timing=True)\n",
    "    end = torch.cuda.Event(enable_timing=True)\n",
    "\n",
    "    # The 1st iteration measures the compilation time without engine caching.\n",
    "    # The 2nd and 3rd iterations measure the compilation time with engine caching.\n",
    "    # Since the 2nd iteration needs to compile and save the engine, it will be slower\n",
    "    # than the 1st iteration (assuming the cache directory holds no matching engine yet).\n",
    "    # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.\n",
    "    for i in range(iterations):\n",
    "        # Remove the timing cache and reset dynamo for the engine caching measurement.\n",
    "        remove_timing_cache()\n",
    "        torch._dynamo.reset()\n",
    "\n",
    "        # Only the first (baseline) round runs with caching switched off.\n",
    "        use_engine_cache = i > 0\n",
    "\n",
    "        # Build the options before starting the timer so only compile + run is measured.\n",
    "        compilation_kwargs = {\n",
    "            \"use_python_runtime\": False,\n",
    "            \"enabled_precisions\": {torch.float},\n",
    "            \"truncate_double\": True,\n",
    "            \"min_block_size\": 1,\n",
    "            \"immutable_weights\": False,\n",
    "            \"cache_built_engines\": use_engine_cache,\n",
    "            \"reuse_cached_engines\": use_engine_cache,\n",
    "            \"engine_cache_dir\": ENGINE_CACHE_DIR,\n",
    "            \"engine_cache_size\": ENGINE_CACHE_SIZE,\n",
    "        }\n",
    "\n",
    "        start.record()\n",
    "        optimized_model = torch.compile(\n",
    "            model,\n",
    "            backend=\"torch_tensorrt\",\n",
    "            options=compilation_kwargs,\n",
    "        )\n",
    "        # The first call on the compiled model is inside the timed region.\n",
    "        with torch.no_grad():\n",
    "            optimized_model(*inputs)\n",
    "        end.record()\n",
    "        # Wait for the recorded events to complete before reading the elapsed time.\n",
    "        torch.cuda.synchronize()\n",
    "        times.append(start.elapsed_time(end))\n",
    "\n",
    "    print(\"-----compile bert-----> compilation time:\\n\", times, \"milliseconds\")\n",
    "    return times\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    compile_bert()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}