.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "advanced/dynamic_quantization_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_advanced_dynamic_quantization_tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_advanced_dynamic_quantization_tutorial.py:


(beta) Dynamic Quantization on an LSTM Word Language Model
==================================================================

**Author**: `James Reed `_

**Edited by**: `Seth Weidman `_

Introduction
------------

Quantization involves converting the weights and activations of your model from float
to int, which can result in smaller model size and faster inference with only a small
hit to accuracy.

In this tutorial, we will apply the easiest form of quantization -
`dynamic quantization `_ -
to an LSTM-based next-word prediction model, closely following the
`word language model `_
from the PyTorch examples.

.. GENERATED FROM PYTHON SOURCE LINES 22-32

.. code-block:: default


    # imports
    import os
    from io import open
    import time

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


.. GENERATED FROM PYTHON SOURCE LINES 33-39

1. Define the model
-------------------

Here we define the LSTM model architecture, following the
`model `_
from the word language model example.

.. GENERATED FROM PYTHON SOURCE LINES 39-73

.. code-block:: default


    class LSTMModel(nn.Module):
        """Container module with an encoder, a recurrent module, and a decoder."""

        def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
            super(LSTMModel, self).__init__()
            self.drop = nn.Dropout(dropout)
            self.encoder = nn.Embedding(ntoken, ninp)
            self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
            self.decoder = nn.Linear(nhid, ntoken)

            self.init_weights()

            self.nhid = nhid
            self.nlayers = nlayers

        def init_weights(self):
            initrange = 0.1
            self.encoder.weight.data.uniform_(-initrange, initrange)
            self.decoder.bias.data.zero_()
            self.decoder.weight.data.uniform_(-initrange, initrange)

        def forward(self, input, hidden):
            emb = self.drop(self.encoder(input))
            output, hidden = self.rnn(emb, hidden)
            output = self.drop(output)
            decoded = self.decoder(output)
            return decoded, hidden

        def init_hidden(self, bsz):
            weight = next(self.parameters())
            return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                    weight.new_zeros(self.nlayers, bsz, self.nhid))

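The model takes token indices of shape ``(seq_len, batch)`` plus a hidden-state tuple, and
returns per-token vocabulary scores plus the updated hidden state. The snippet below is a
minimal shape check added here for illustration; it is not part of the original example, and
the toy vocabulary size and dimensions are assumptions chosen only to make the shapes concrete.

.. code-block:: python

    # Hypothetical shape check (illustrative assumption, not from the original tutorial).
    toy_model = LSTMModel(ntoken=100, ninp=32, nhid=16, nlayers=2)
    toy_model.eval()

    seq_len, batch_size = 5, 1
    tokens = torch.randint(100, (seq_len, batch_size), dtype=torch.long)  # (seq_len, batch)
    hidden = toy_model.init_hidden(batch_size)  # (h_0, c_0), each of shape (nlayers, batch, nhid)

    with torch.no_grad():
        logits, hidden = toy_model(tokens, hidden)

    print(logits.shape)  # torch.Size([5, 1, 100]) -- one score per vocabulary word at each position
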
.. GENERATED FROM PYTHON SOURCE LINES 74-82

2. Load in the text data
------------------------

Next, we load the
`Wikitext-2 dataset `_ into a ``Corpus``,
again following the
`preprocessing `_
from the word language model example.

.. GENERATED FROM PYTHON SOURCE LINES 82-132

.. code-block:: default


    class Dictionary(object):
        def __init__(self):
            self.word2idx = {}
            self.idx2word = []

        def add_word(self, word):
            if word not in self.word2idx:
                self.idx2word.append(word)
                self.word2idx[word] = len(self.idx2word) - 1
            return self.word2idx[word]

        def __len__(self):
            return len(self.idx2word)


    class Corpus(object):
        def __init__(self, path):
            self.dictionary = Dictionary()
            self.train = self.tokenize(os.path.join(path, 'train.txt'))
            self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
            self.test = self.tokenize(os.path.join(path, 'test.txt'))

        def tokenize(self, path):
            """Tokenizes a text file."""
            assert os.path.exists(path)
            # Add words to the dictionary
            with open(path, 'r', encoding="utf8") as f:
                for line in f:
                    words = line.split() + ['<eos>']
                    for word in words:
                        self.dictionary.add_word(word)

            # Tokenize file content
            with open(path, 'r', encoding="utf8") as f:
                idss = []
                for line in f:
                    words = line.split() + ['<eos>']
                    ids = []
                    for word in words:
                        ids.append(self.dictionary.word2idx[word])
                    idss.append(torch.tensor(ids).type(torch.int64))
                ids = torch.cat(idss)

            return ids

    model_data_filepath = 'data/'

    corpus = Corpus(model_data_filepath + 'wikitext-2')

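Before moving on, it can help to confirm what the ``Corpus`` actually contains: a
``Dictionary`` mapping words to integer ids, and one flat ``int64`` tensor of token ids per
split. The check below is a sketch added for illustration (it is not in the original
tutorial) and assumes the Wikitext-2 files were found at the path above.

.. code-block:: python

    # Inspect the preprocessed corpus (assumes the Wikitext-2 files exist under ``data/wikitext-2``).
    print(len(corpus.dictionary))                  # vocabulary size; 33278 for Wikitext-2, matching
                                                   # the Embedding(33278, 512) printed in the next section
    print(corpus.train.shape, corpus.train.dtype)  # a single 1-D int64 tensor of token ids
    print(corpus.dictionary.idx2word[:5])          # the first few words added to the vocabulary
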
.. GENERATED FROM PYTHON SOURCE LINES 133-149

3. Load the pretrained model
-----------------------------

This is a tutorial on dynamic quantization, a quantization technique
that is applied after a model has been trained. Therefore, we'll simply
load some pretrained weights into this model architecture; these
weights were obtained by training for five epochs using the default
settings in the word language model example.

Before running this tutorial, download the required pre-trained model:

.. code-block:: bash

    wget https://s3.amazonaws.com/pytorch-tutorial-assets/word_language_model_quantize.pth

Place the downloaded file in the data directory or update the ``model_data_filepath``
accordingly.

.. GENERATED FROM PYTHON SOURCE LINES 149-170

.. code-block:: default


    ntokens = len(corpus.dictionary)

    model = LSTMModel(
        ntoken = ntokens,
        ninp = 512,
        nhid = 256,
        nlayers = 5,
    )

    model.load_state_dict(
        torch.load(
            model_data_filepath + 'word_language_model_quantize.pth',
            map_location=torch.device('cpu'),
            weights_only=True
            )
        )

    model.eval()
    print(model)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    LSTMModel(
      (drop): Dropout(p=0.5, inplace=False)
      (encoder): Embedding(33278, 512)
      (rnn): LSTM(512, 256, num_layers=5, dropout=0.5)
      (decoder): Linear(in_features=256, out_features=33278, bias=True)
    )

.. GENERATED FROM PYTHON SOURCE LINES 171-174

Now let's generate some text to ensure that the pretrained model is working
properly; as before, we follow the text-generation script
(`here `_) from the word language model example.

.. GENERATED FROM PYTHON SOURCE LINES 174-199

.. code-block:: default


    input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
    hidden = model.init_hidden(1)
    temperature = 1.0
    num_words = 1000

    with open(model_data_filepath + 'out.txt', 'w') as outf:
        with torch.no_grad():  # no tracking history
            for i in range(num_words):
                output, hidden = model(input_, hidden)
                word_weights = output.squeeze().div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]
                input_.fill_(word_idx)

                word = corpus.dictionary.idx2word[word_idx]

                outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))

                if i % 100 == 0:
                    print('| Generated {}/{} words'.format(i, 1000))

    with open(model_data_filepath + 'out.txt', 'r') as outf:
        all_output = outf.read()
        print(all_output)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    | Generated 0/1000 words
    | Generated 100/1000 words
    | Generated 200/1000 words
    | Generated 300/1000 words
    | Generated 400/1000 words
    | Generated 500/1000 words
    | Generated 600/1000 words
    | Generated 700/1000 words
    | Generated 800/1000 words
    | Generated 900/1000 words
    b'and' b'providing' b'the' b'cement' b',' b'parodied' b'.' b'In' b'cliffs' b',' b'Song' b'health' b'Commissioned' b'Alabama' b'@-@' b'cake' b'and' b'fraud' b"'" b'active'
    b'amputated' b'pew' b',' b'regardless' b'of' b'excimer' b',' b'over' b'a' b'small' b'down' b'with' b'the' b'curriculum' b'"' b'small' b'ambassadors' b'"' b'.' b'Manor'
    b',' b'but' b'Major' b'streaks' b'attempted' b'to' b'thoroughly' b'respect' b'for' b'the' b'search' b'as' b'Night' b'and' b'cultural' b'control' b'of' b'establishment' b'credit' b'to'
    ...
    [output truncated; the script samples 1,000 words in total]

b'returns' b'earlier' b'upon' b'instrumental' b'to' b'standing' b'.' b'Odaenathus' b'said' b'the' b'bird' b'hit' b'' b'' b'@-@' b'head' b',' b'mathematics' b'biological' b',' b'with' b'a' b'sign' b',' b'stipulated' b'"' b'It' b'is' b'a' b'part' b'of' b'litter' b'fraud' b',' b'and' b'a' b'indispensable' b'message' b'kill' b'display' b'at' b'Zhou' b'but' b'gene' b'.' b'5' b'people' b'suggests' b'that' b'Hornung' b"'s" b'real' b'allusion' b'and' b'convey' b';' b'it' b"'s" b'certain' b'persons' b'we' b'are' b'' b'and' b'challenge' b'itself' b'but' b'which' b'Am' b'' b'it' b'was' b'.' b'Nevertheless' b',' b'it' b'still' b'See' b'something' b'as' b'a' b'solid' b'fundamental' b'funds' b'to' b'fly' b'thy' b'combat' b'.' b'But' b'think' b'that' b'of' b'these' b'composers' b',' b'on' b'other' b'are' b'Reese' b'on' b'resupply' b',' b'does' b'a' b'kind' b'of' b'' b',' b'he' b'insisted' b'.' b'This' b'movie' b'gift' b'into' b'a' b'relationship' b',' b'declaring' b'a' b'' b'further' b'@-@' b'son' b'pathway' b'who' b'came' b'with' b'rich' b'results' b'where' b'he' b'automatons' b'on' b'me' b'by' b'+' b'3' b',' b'Halo' b',' b'known' b'as' b'' b'' b'and' b'Josh' b'' b'.' b'The' b'time' b'right' b'of' b'churchwardens' b'when' b'not' b'the' b'outcome' b'has' b',' b'including' b'their' b'life' b',' b"'" b',' b'at' b'the' b'end' b'of' b'the' b'19th' b'century' b'.' b'If' .. GENERATED FROM PYTHON SOURCE LINES 200-205 It's no GPT-2, but it looks like the model has started to learn the structure of language! We're almost ready to demonstrate dynamic quantization. We just need to define a few more helper functions: .. GENERATED FROM PYTHON SOURCE LINES 205-250 .. code-block:: default bptt = 25 criterion = nn.CrossEntropyLoss() eval_batch_size = 1 # create test data set def batchify(data, bsz): # Work out how cleanly we can divide the dataset into ``bsz`` parts. nbatch = data.size(0) // bsz # Trim off any extra elements that wouldn't cleanly fit (remainders). data = data.narrow(0, 0, nbatch * bsz) # Evenly divide the data across the ``bsz`` batches. return data.view(bsz, -1).t().contiguous() test_data = batchify(corpus.test, eval_batch_size) # Evaluation functions def get_batch(source, i): seq_len = min(bptt, len(source) - 1 - i) data = source[i:i+seq_len] target = source[i+1:i+1+seq_len].reshape(-1) return data, target def repackage_hidden(h): """Wraps hidden states in new Tensors, to detach them from their history.""" if isinstance(h, torch.Tensor): return h.detach() else: return tuple(repackage_hidden(v) for v in h) def evaluate(model_, data_source): # Turn on evaluation mode which disables dropout. model_.eval() total_loss = 0. hidden = model_.init_hidden(eval_batch_size) with torch.no_grad(): for i in range(0, data_source.size(0) - 1, bptt): data, targets = get_batch(data_source, i) output, hidden = model_(data, hidden) hidden = repackage_hidden(hidden) output_flat = output.view(-1, ntokens) total_loss += len(data) * criterion(output_flat, targets).item() return total_loss / (len(data_source) - 1) .. GENERATED FROM PYTHON SOURCE LINES 251-260 4. Test dynamic quantization ---------------------------- Finally, we can call ``torch.quantization.quantize_dynamic`` on the model! Specifically, - We specify that we want the ``nn.LSTM`` and ``nn.Linear`` modules in our model to be quantized - We specify that we want weights to be converted to ``int8`` values .. GENERATED FROM PYTHON SOURCE LINES 260-268 .. 
.. GENERATED FROM PYTHON SOURCE LINES 251-260

4. Test dynamic quantization
----------------------------

Finally, we can call ``torch.quantization.quantize_dynamic`` on the model!
Specifically,

- We specify that we want the ``nn.LSTM`` and ``nn.Linear`` modules in our
  model to be quantized
- We specify that we want weights to be converted to ``int8`` values

.. GENERATED FROM PYTHON SOURCE LINES 260-268

.. code-block:: default


    import torch.quantization

    quantized_model = torch.quantization.quantize_dynamic(
        model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
    )
    print(quantized_model)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    LSTMModel(
      (drop): Dropout(p=0.5, inplace=False)
      (encoder): Embedding(33278, 512)
      (rnn): DynamicQuantizedLSTM(512, 256, num_layers=5, dropout=0.5)
      (decoder): DynamicQuantizedLinear(in_features=256, out_features=33278, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    )

.. GENERATED FROM PYTHON SOURCE LINES 269-271

The model looks the same; how has this benefited us? First, we see a
significant reduction in model size:

.. GENERATED FROM PYTHON SOURCE LINES 271-280

.. code-block:: default


    def print_size_of_model(model):
        torch.save(model.state_dict(), "temp.p")
        print('Size (MB):', os.path.getsize("temp.p")/1e6)
        os.remove('temp.p')

    print_size_of_model(model)
    print_size_of_model(quantized_model)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Size (MB): 113.944455
    Size (MB): 79.738939

.. GENERATED FROM PYTHON SOURCE LINES 281-285

Second, we see faster inference time, with no difference in evaluation loss.

Note: we set the number of threads to one for a single-threaded comparison, since
quantized models run single-threaded.

.. GENERATED FROM PYTHON SOURCE LINES 285-297

.. code-block:: default


    torch.set_num_threads(1)

    def time_model_evaluation(model, test_data):
        s = time.time()
        loss = evaluate(model, test_data)
        elapsed = time.time() - s
        print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

    time_model_evaluation(model, test_data)
    time_model_evaluation(quantized_model, test_data)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    loss: 5.167
    elapsed time (seconds): 199.8
    loss: 5.168
    elapsed time (seconds): 113.1

.. GENERATED FROM PYTHON SOURCE LINES 298-309

Running this locally on a MacBook Pro, without quantization, inference takes about
200 seconds, and with quantization it takes just over 110 seconds.

Conclusion
----------

Dynamic quantization can be an easy way to reduce model size while only
having a limited effect on accuracy.

Thanks for reading! As always, we welcome any feedback, so please create an issue
`here `_ if you have any.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 5 minutes  22.092 seconds)


.. _sphx_glr_download_advanced_dynamic_quantization_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: dynamic_quantization_tutorial.py <dynamic_quantization_tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: dynamic_quantization_tutorial.ipynb <dynamic_quantization_tutorial.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_