{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# Torch-TensorRT Distributed Inference\n\nThis interactive script is intended as a sample of distributed inference using data\nparallelism with the Accelerate library, combined with the Torch-TensorRT workflow, on the GPT-2 model.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports and Model Definition\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torch_tensorrt\nfrom accelerate import PartialState\nfrom transformers import AutoTokenizer, GPT2LMHeadModel\n\ntokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n\n# Set input prompts for the different devices: one prompt per process.\n# Both are open-ended sentence fragments ending in \"by \" so each device\n# performs the same kind of completion task.\nprompt1 = \"GPT2 is a model developed by \"\nprompt2 = \"Llama is a model developed by \"\n\ninput_id1 = tokenizer(prompt1, return_tensors=\"pt\").input_ids\ninput_id2 = tokenizer(prompt2, return_tensors=\"pt\").input_ids\n\n# PartialState discovers the distributed environment set up by the launcher\n# (e.g. `accelerate launch` / torchrun) and exposes this process's device.\ndistributed_state = PartialState()\n\n# Import GPT2 model and load it onto this process's device\nmodel = GPT2LMHeadModel.from_pretrained(\"gpt2\").eval().to(distributed_state.device)\n\n\n# Compile the forward pass with the Torch-TensorRT backend.\n# dynamic=False: input shapes are treated as static for the TRT engine.\nmodel.forward = torch.compile(\n    model.forward,\n    backend=\"torch_tensorrt\",\n    options={\n        \"truncate_long_and_double\": True,\n        \"enabled_precisions\": {torch.float16},\n    },\n    dynamic=False,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Assume there are 2 processes (2 devices): each process receives one of the\n# two prompts from the shared list and runs generation independently.\nwith distributed_state.split_between_processes([input_id1, input_id2]) as prompt:\n    cur_input = torch.clone(prompt[0]).to(distributed_state.device)\n\n    gen_tokens = model.generate(\n        cur_input,\n        do_sample=True,\n        temperature=0.9,\n        max_length=100,\n        # GPT-2 defines no pad token; set it explicitly to silence the\n        # transformers warning emitted on every generate() call.\n        pad_token_id=tokenizer.eos_token_id,\n    )\n    gen_text = tokenizer.batch_decode(gen_tokens)[0]\n\n    # Show each process's completion so the demo produces visible output.\n    print(f\"Device {distributed_state.device} | {gen_text}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3",
"language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.14" } }, "nbformat": 4, "nbformat_minor": 0 }