.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "intermediate/memory_format_tutorial.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_intermediate_memory_format_tutorial.py:


Channels Last Memory Format in PyTorch
*******************************************************
**Author**: `Vitaly Fedyunin `_

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * What is the channels last memory format in PyTorch?
       * How can it be used to improve performance on certain operators?

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch v1.5.0
       * A CUDA-capable GPU

Overview - What is channels last?
---------------------------------

The channels last memory format is an alternative way of ordering NCHW tensors in memory while preserving the logical dimension order. Channels last tensors are ordered in such a way that the channels become the densest dimension (that is, images are stored pixel-per-pixel).

For example, the classic (contiguous) storage of an NCHW tensor (in our case, two 4x4 images with 3 color channels) looks like this:

.. figure:: /_static/img/classic_memory_format.png
   :alt: classic_memory_format

The channels last memory format orders the same data differently:

.. figure:: /_static/img/channels_last_memory_format.png
   :alt: channels_last_memory_format

PyTorch supports memory formats by utilizing the existing strides structure. For example, a 10x3x16x16 batch in channels last format has strides equal to ``(768, 1, 48, 3)``.

The channels last memory format is implemented for 4D NCHW tensors only.

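
The strides follow directly from the physical order in which the dimensions are laid out in memory. As a minimal pure-Python sketch (the helpers ``strides_for`` and ``offset`` are hypothetical and not part of PyTorch), the strides of a dense tensor can be computed by walking the dimensions from the innermost (densest) one outwards and multiplying their sizes:

.. code-block:: Python

    def strides_for(shape, physical_order):
        """Per-dimension strides (in elements) of a dense tensor.

        shape:          logical sizes, e.g. (N, C, H, W)
        physical_order: logical dimension indices from outermost to innermost
                        in memory, e.g. (0, 1, 2, 3) for NCHW storage or
                        (0, 2, 3, 1) for channels last (NHWC) storage
        """
        strides = [0] * len(shape)
        step = 1
        # Walk from the innermost (densest) dimension outwards.
        for dim in reversed(physical_order):
            strides[dim] = step
            step *= shape[dim]
        return tuple(strides)


    def offset(index, strides):
        """Memory offset (in elements) of a logical index, given the strides."""
        return sum(i * s for i, s in zip(index, strides))


    shape = (10, 3, 16, 16)  # N, C, H, W
    contiguous = strides_for(shape, (0, 1, 2, 3))
    channels_last = strides_for(shape, (0, 2, 3, 1))
    print(contiguous)     # (768, 256, 16, 1)
    print(channels_last)  # (768, 1, 48, 3)

    # In channels last, the channels of one pixel (h=5, w=7) are adjacent in memory.
    print(offset((0, 0, 5, 7), channels_last), offset((0, 1, 5, 7), channels_last))  # 261 262

The logical shape is the same in both cases; only the strides change, which is why PyTorch can express channels last through its existing strides structure.
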

Memory Format API
-----------------------

Here is how to convert tensors between contiguous and channels last memory formats.

Classic PyTorch contiguous tensor

.. code-block:: Python


    import torch

    N, C, H, W = 10, 3, 32, 32
    x = torch.empty(N, C, H, W)
    print(x.stride())  # Outputs: (3072, 1024, 32, 1)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (3072, 1024, 32, 1)

Conversion operator

.. code-block:: Python


    x = x.to(memory_format=torch.channels_last)
    print(x.shape)  # Outputs: torch.Size([10, 3, 32, 32]); the dimension order is preserved
    print(x.stride())  # Outputs: (3072, 1, 96, 3)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    torch.Size([10, 3, 32, 32])
    (3072, 1, 96, 3)

Back to contiguous

.. code-block:: Python


    x = x.to(memory_format=torch.contiguous_format)
    print(x.stride())  # Outputs: (3072, 1024, 32, 1)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (3072, 1024, 32, 1)

Alternative option

.. code-block:: Python


    x = x.contiguous(memory_format=torch.channels_last)
    print(x.stride())  # Outputs: (3072, 1, 96, 3)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (3072, 1, 96, 3)

Format checks

.. code-block:: Python


    print(x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    True

There are minor differences between the two APIs ``to`` and ``contiguous``. We suggest sticking with ``to`` when explicitly converting the memory format of a tensor.

For general cases, the two APIs behave the same. However, in special cases for a 4D tensor with size ``NCHW`` where either ``C==1`` or ``H==1 && W==1``, only ``to`` generates a proper stride to represent the channels last memory format. This is because in either case the memory format of the tensor is ambiguous; for example, a contiguous tensor with size ``N1HW`` is both contiguous and channels last in memory storage. Such a tensor is therefore already considered ``is_contiguous`` for the requested memory format, so the ``contiguous`` call becomes a no-op and does not update the strides. In contrast, ``to`` restrides the tensor with meaningful strides on the dimensions whose sizes are 1, in order to properly represent the intended memory format.

.. code-block:: Python


    special_x = torch.empty(4, 1, 4, 4)
    print(special_x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
    print(special_x.is_contiguous(memory_format=torch.contiguous_format))  # Outputs: True


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    True
    True

The same applies to the explicit permutation API ``permute``. In special cases where ambiguity can occur, ``permute`` does not guarantee to produce strides that properly carry the intended memory format. We suggest using ``to`` with an explicit memory format to avoid unintended behavior. Note also that in the extreme case where all three non-batch dimensions are equal to ``1`` (``C==1 && H==1 && W==1``), the current implementation cannot mark a tensor as channels last memory format.

Create as channels last

.. code-block:: Python


    x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
    print(x.stride())  # Outputs: (3072, 1, 96, 3)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (3072, 1, 96, 3)

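
To make the ``N1HW`` ambiguity discussed above concrete, here is a simplified pure-Python model of a memory format check. The helpers ``is_dense_in_order`` and ``dense_strides`` are hypothetical; they only approximate PyTorch's C++ logic, which handles additional corner cases. A tensor is treated as being in a given memory format when its strides match a dense layout in that physical order, skipping dimensions of size 1 because their stride never affects which element is addressed:

.. code-block:: Python

    CONTIGUOUS = (0, 1, 2, 3)     # physical order N, C, H, W
    CHANNELS_LAST = (0, 2, 3, 1)  # physical order N, H, W, C


    def is_dense_in_order(sizes, strides, physical_order):
        """Simplified check: do `strides` describe a dense layout in `physical_order`?"""
        expected = 1
        for dim in reversed(physical_order):
            if sizes[dim] == 1:
                continue  # the stride of a size-1 dimension is irrelevant
            if strides[dim] != expected:
                return False
            expected *= sizes[dim]
        return True


    def dense_strides(sizes, physical_order):
        """Strides of a dense layout in `physical_order`, including size-1 dimensions."""
        strides, step = [0] * len(sizes), 1
        for dim in reversed(physical_order):
            strides[dim] = step
            step *= sizes[dim]
        return tuple(strides)


    sizes, strides = (4, 1, 4, 4), (16, 16, 4, 1)  # an N1HW contiguous tensor
    print(is_dense_in_order(sizes, strides, CONTIGUOUS))     # True
    print(is_dense_in_order(sizes, strides, CHANNELS_LAST))  # True: ambiguous
    print(dense_strides(sizes, CHANNELS_LAST))               # (16, 1, 4, 1)

    # With C > 1 the two layouts are distinguishable again.
    print(is_dense_in_order((4, 3, 4, 4), (48, 16, 4, 1), CHANNELS_LAST))  # False

Because both checks pass for the ``N1HW`` strides, a call that only converts "when needed" has nothing to do, which mirrors why ``contiguous`` is a no-op here. ``dense_strides`` instead assigns a stride to the size-1 channel dimension as well, which is the kind of restriding the text above attributes to ``to``. The same model also reports ``NC11`` tensors, such as sizes ``(4, 3, 1, 1)`` with strides ``(3, 1, 1, 1)``, as both contiguous and channels last.
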

``clone`` preserves memory format

.. code-block:: Python


    y = x.clone()
    print(y.stride())  # Outputs: (3072, 1, 96, 3)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (3072, 1, 96, 3)

``to``, ``cuda``, ``float``, ... preserve memory format

.. code-block:: Python


    if torch.cuda.is_available():
        y = x.cuda()
        print(y.stride())  # Outputs: (3072, 1, 96, 3)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (3072, 1, 96, 3)

``empty_like`` and other ``*_like`` operators preserve memory format

.. code-block:: Python


    y = torch.empty_like(x)
    print(y.stride())  # Outputs: (3072, 1, 96, 3)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (3072, 1, 96, 3)

Pointwise operators preserve memory format

.. code-block:: Python


    z = x + y
    print(z.stride())  # Outputs: (3072, 1, 96, 3)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (3072, 1, 96, 3)

``Conv`` and ``Batchnorm`` modules using the ``cudnn`` backend support channels last (this only works for cuDNN >= 7.6). Convolution modules, unlike binary pointwise operators, have channels last as the dominating memory format. If all inputs are in contiguous memory format, the operator produces output in contiguous memory format. Otherwise, the output is in channels last memory format.

.. code-block:: Python


    # cuDNN can be present in a CUDA build even on a machine without a GPU,
    # so also check that a CUDA device is actually available.
    if (
        torch.cuda.is_available()
        and torch.backends.cudnn.is_available()
        and torch.backends.cudnn.version() >= 7603
    ):
        model = torch.nn.Conv2d(8, 4, 3).cuda().half()
        model = model.to(memory_format=torch.channels_last)  # Module parameters need to be channels last

        input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True)
        input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16)

        out = model(input)
        print(out.is_contiguous(memory_format=torch.channels_last))  # Outputs: True


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    True

When an input tensor reaches an operator without channels last support, a permutation should automatically be applied in the kernel to restore contiguity of the input tensor. This introduces overhead and stops the channels last memory format propagation. Nevertheless, it guarantees correct output.

Performance Gains
--------------------------------------------------------------------

Channels last memory format optimizations are available on both GPU and CPU. On GPU, the most significant performance gains are observed on NVIDIA hardware with Tensor Core support running in reduced precision (``torch.float16``). We were able to achieve over 22% performance gains with channels last compared to the contiguous format, with both runs using AMP (Automatic Mixed Precision) training scripts. Our scripts use the AMP supplied by NVIDIA: https://github.com/NVIDIA/apex.

``python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 ./data``

.. code-block:: Python


    # opt_level = O2
    # keep_batchnorm_fp32 = None
    # loss_scale = None
    # CUDNN VERSION: 7603
    # => creating model 'resnet50'
    # Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
    # Defaults for this optimization level are:
    # enabled : True
    # opt_level : O2
    # cast_model_type : torch.float16
    # patch_torch_functions : False
    # keep_batchnorm_fp32 : True
    # master_weights : True
    # loss_scale : dynamic
    # Processing user overrides (additional kwargs that are not None)...
    # After processing overrides, optimization options are:
    # enabled : True
    # opt_level : O2
    # cast_model_type : torch.float16
    # patch_torch_functions : False
    # keep_batchnorm_fp32 : True
    # master_weights : True
    # loss_scale : dynamic
    # Epoch: [0][10/125] Time 0.866 (0.866) Speed 230.949 (230.949) Loss 0.6735125184 (0.6735) Prec@1 61.000 (61.000) Prec@5 100.000 (100.000)
    # Epoch: [0][20/125] Time 0.259 (0.562) Speed 773.481 (355.693) Loss 0.6968704462 (0.6852) Prec@1 55.000 (58.000) Prec@5 100.000 (100.000)
    # Epoch: [0][30/125] Time 0.258 (0.461) Speed 775.089 (433.965) Loss 0.7877287269 (0.7194) Prec@1 51.500 (55.833) Prec@5 100.000 (100.000)
    # Epoch: [0][40/125] Time 0.259 (0.410) Speed 771.710 (487.281) Loss 0.8285319805 (0.7467) Prec@1 48.500 (54.000) Prec@5 100.000 (100.000)
    # Epoch: [0][50/125] Time 0.260 (0.380) Speed 770.090 (525.908) Loss 0.7370464802 (0.7447) Prec@1 56.500 (54.500) Prec@5 100.000 (100.000)
    # Epoch: [0][60/125] Time 0.258 (0.360) Speed 775.623 (555.728) Loss 0.7592862844 (0.7472) Prec@1 51.000 (53.917) Prec@5 100.000 (100.000)
    # Epoch: [0][70/125] Time 0.258 (0.345) Speed 774.746 (579.115) Loss 1.9698858261 (0.9218) Prec@1 49.500 (53.286) Prec@5 100.000 (100.000)
    # Epoch: [0][80/125] Time 0.260 (0.335) Speed 770.324 (597.659) Loss 2.2505953312 (1.0879) Prec@1 50.500 (52.938) Prec@5 100.000 (100.000)


Passing ``--channels-last true`` runs the model in channels last format, with an observed 22% performance gain.

``python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data``

.. code-block:: Python


    # opt_level = O2
    # keep_batchnorm_fp32 = None
    # loss_scale = None
    #
    # CUDNN VERSION: 7603
    #
    # => creating model 'resnet50'
    # Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
    #
    # Defaults for this optimization level are:
    # enabled : True
    # opt_level : O2
    # cast_model_type : torch.float16
    # patch_torch_functions : False
    # keep_batchnorm_fp32 : True
    # master_weights : True
    # loss_scale : dynamic
    # Processing user overrides (additional kwargs that are not None)...
    # After processing overrides, optimization options are:
    # enabled : True
    # opt_level : O2
    # cast_model_type : torch.float16
    # patch_torch_functions : False
    # keep_batchnorm_fp32 : True
    # master_weights : True
    # loss_scale : dynamic
    #
    # Epoch: [0][10/125] Time 0.767 (0.767) Speed 260.785 (260.785) Loss 0.7579724789 (0.7580) Prec@1 53.500 (53.500) Prec@5 100.000 (100.000)
    # Epoch: [0][20/125] Time 0.198 (0.482) Speed 1012.135 (414.716) Loss 0.7007197738 (0.7293) Prec@1 49.000 (51.250) Prec@5 100.000 (100.000)
    # Epoch: [0][30/125] Time 0.198 (0.387) Speed 1010.977 (516.198) Loss 0.7113101482 (0.7233) Prec@1 55.500 (52.667) Prec@5 100.000 (100.000)
    # Epoch: [0][40/125] Time 0.197 (0.340) Speed 1013.023 (588.333) Loss 0.8943189979 (0.7661) Prec@1 54.000 (53.000) Prec@5 100.000 (100.000)
    # Epoch: [0][50/125] Time 0.198 (0.312) Speed 1010.541 (641.977) Loss 1.7113249302 (0.9551) Prec@1 51.000 (52.600) Prec@5 100.000 (100.000)
    # Epoch: [0][60/125] Time 0.198 (0.293) Speed 1011.163 (683.574) Loss 5.8537774086 (1.7716) Prec@1 50.500 (52.250) Prec@5 100.000 (100.000)
    # Epoch: [0][70/125] Time 0.198 (0.279) Speed 1011.453 (716.767) Loss 5.7595844269 (2.3413) Prec@1 46.500 (51.429) Prec@5 100.000 (100.000)
    # Epoch: [0][80/125] Time 0.198 (0.269) Speed 1011.827 (743.883) Loss 2.8196096420 (2.4011) Prec@1 47.500 (50.938) Prec@5 100.000 (100.000)

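
As a worked check of the quoted gain, the sketch below averages the per-iteration ``Time`` and ``Speed`` values from the two logs above, skipping the first row of each because it includes warm-up. It only does arithmetic on the logged numbers; it is not a benchmark.

.. code-block:: Python

    # Per-iteration values copied from the logs above, rows 20/125 to 80/125.
    contiguous_time = [0.259, 0.258, 0.259, 0.260, 0.258, 0.258, 0.260]
    channels_last_time = [0.198, 0.198, 0.197, 0.198, 0.198, 0.198, 0.198]
    contiguous_speed = [773.481, 775.089, 771.710, 770.090, 775.623, 774.746, 770.324]
    channels_last_speed = [1012.135, 1010.977, 1013.023, 1010.541, 1011.163, 1011.453, 1011.827]


    def mean(values):
        return sum(values) / len(values)


    # Fraction of time saved per iteration, relative to the contiguous run.
    time_reduction = 1 - mean(channels_last_time) / mean(contiguous_time)
    # Extra images per second, relative to the contiguous run.
    throughput_gain = mean(channels_last_speed) / mean(contiguous_speed) - 1

    print(f"time per iteration: -{time_reduction:.1%}")  # -23.6%
    print(f"throughput:         +{throughput_gain:.1%}")  # +30.9%

Both numbers describe the same speedup measured against different baselines: iterations take roughly 23.6% less time, which is the same as roughly 30.9% more images per second. The time reduction is the figure consistent with the "over 22%" stated above.
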

The following models fully support channels last and show 8%-35% performance gains on Volta devices: ``alexnet``, ``mnasnet0_5``, ``mnasnet0_75``, ``mnasnet1_0``, ``mnasnet1_3``, ``mobilenet_v2``, ``resnet101``, ``resnet152``, ``resnet18``, ``resnet34``, ``resnet50``, ``resnext50_32x4d``, ``shufflenet_v2_x0_5``, ``shufflenet_v2_x1_0``, ``shufflenet_v2_x1_5``, ``shufflenet_v2_x2_0``, ``squeezenet1_0``, ``squeezenet1_1``, ``vgg11``, ``vgg11_bn``, ``vgg13``, ``vgg13_bn``, ``vgg16``, ``vgg16_bn``, ``vgg19``, ``vgg19_bn``, ``wide_resnet101_2``, ``wide_resnet50_2``

The following models fully support channels last and show 26%-76% performance gains on Intel(R) Xeon(R) Ice Lake (or newer) CPUs: ``alexnet``, ``densenet121``, ``densenet161``, ``densenet169``, ``googlenet``, ``inception_v3``, ``mnasnet0_5``, ``mnasnet1_0``, ``resnet101``, ``resnet152``, ``resnet18``, ``resnet34``, ``resnet50``, ``resnext101_32x8d``, ``resnext50_32x4d``, ``shufflenet_v2_x0_5``, ``shufflenet_v2_x1_0``, ``squeezenet1_0``, ``squeezenet1_1``, ``vgg11``, ``vgg11_bn``, ``vgg13``, ``vgg13_bn``, ``vgg16``, ``vgg16_bn``, ``vgg19``, ``vgg19_bn``, ``wide_resnet101_2``, ``wide_resnet50_2``

Converting existing models
--------------------------

Channels last support is not limited to the models listed above: any model can be converted to channels last and will propagate the format through the graph as soon as the input (or certain weights) is formatted correctly.

.. code-block:: Python


    # Needs to be done once, after model initialization (or load)
    model = model.to(memory_format=torch.channels_last)  # Replace with your model

    # Needs to be done for every input
    input = input.to(memory_format=torch.channels_last)  # Replace with your input
    output = model(input)

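
The snippet above assumes that ``model`` and ``input`` already exist (in this tutorial they are only created inside the cuDNN-guarded convolution example). The following self-contained sketch runs on CPU with a small, hypothetical network; the layer sizes are arbitrary and only the conversion pattern matters. ``Module.to(memory_format=...)`` applies the memory format to 4D parameters such as the convolution weight and leaves the other tensors as they are.

.. code-block:: Python

    import torch
    import torch.nn as nn

    # A hypothetical toy network standing in for "your model".
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(16, 10),
    )
    model.eval()

    # Once, after the model is built or loaded: 4D parameters become channels last.
    model = model.to(memory_format=torch.channels_last)
    print(model[0].weight.is_contiguous(memory_format=torch.channels_last))  # True

    # For every batch: convert the input as well.
    images = torch.randn(8, 3, 32, 32).to(memory_format=torch.channels_last)
    with torch.no_grad():
        logits = model(images)
    print(logits.shape)  # torch.Size([8, 10])
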

However, not all operators are fully converted to support channels last (they usually return contiguous output instead). In the example posted above, layers that do not support channels last will stop the memory format propagation. In spite of that, because the model has been converted to channels last format, each convolution layer, whose 4-dimensional weight is in channels last memory format, will restore the channels last memory format and benefit from faster kernels. However, operators that do not support channels last do introduce overhead through permutation. Optionally, if you want to improve the performance of the converted model, you can investigate and identify the operators in your model that do not support channels last.

That means you need to either verify the list of operators used against the list of supported operators (https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support), or introduce memory format checks into eager execution mode and run your model.

After running the code below, every patched operator raises an exception if it receives a channels last input but returns a 4D output that is not channels last.

.. code-block:: Python


    def contains_cl(args):
        # True if any tensor (possibly nested in lists/tuples) is channels last
        # and not also plain contiguous, i.e. unambiguously channels last.
        for t in args:
            if isinstance(t, torch.Tensor):
                if t.is_contiguous(memory_format=torch.channels_last) and not t.is_contiguous():
                    return True
            elif isinstance(t, (list, tuple)):
                if contains_cl(list(t)):
                    return True
        return False


    def print_inputs(args, indent=""):
        for t in args:
            if isinstance(t, torch.Tensor):
                print(indent, t.stride(), t.shape, t.device, t.dtype)
            elif isinstance(t, (list, tuple)):
                print(indent, type(t))
                print_inputs(list(t), indent=indent + "    ")
            else:
                print(indent, t)


    def check_wrapper(fn):
        name = fn.__name__

        def check_cl(*args, **kwargs):
            was_cl = contains_cl(args)
            try:
                result = fn(*args, **kwargs)
            except Exception:
                print("`{}` inputs are:".format(name))
                print_inputs(args)
                print("-------------------")
                raise
            failed = False
            if was_cl:
                if isinstance(result, torch.Tensor):
                    if result.dim() == 4 and not result.is_contiguous(memory_format=torch.channels_last):
                        print(
                            "`{}` got channels_last input, but output is not channels_last:".format(name),
                            result.shape,
                            result.stride(),
                            result.device,
                            result.dtype,
                        )
                        failed = True
            if failed:
                print("`{}` inputs are:".format(name))
                print_inputs(args)
                raise Exception("Operator `{}` lost channels_last property".format(name))
            return result

        return check_cl


    old_attrs = {}


    def attribute(m):
        # Wrap every public callable attribute of `m`, remembering the original.
        old_attrs[m] = {}
        for i in dir(m):
            e = getattr(m, i)
            exclude_functions = ["is_cuda", "has_names", "numel", "stride", "Tensor", "is_contiguous", "__class__"]
            if i not in exclude_functions and not i.startswith("_") and "__call__" in dir(e):
                try:
                    old_attrs[m][i] = e
                    setattr(m, i, check_wrapper(e))
                except Exception as err:
                    print(i)
                    print(err)


    attribute(torch.Tensor)
    attribute(torch.nn.functional)
    attribute(torch)

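
The patching approach above checks every operator. As a lighter-weight alternative (our suggestion, not part of the original tooling), module forward hooks can report, at module granularity, which layers emit 4D outputs that are not channels last. The helper ``report_channels_last`` and the toy model are hypothetical. This sketch does not modify ``torch``, so it is best run in a separate session from the patching code, and it cannot see functional calls made inside a module's ``forward``.

.. code-block:: Python

    import torch
    import torch.nn as nn


    def report_channels_last(model):
        """Attach forward hooks that print, for every leaf module producing a
        4D tensor, whether that output is channels last. Returns the handles."""
        handles = []

        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and output.dim() == 4:
                is_cl = output.is_contiguous(memory_format=torch.channels_last)
                print(f"{module.__class__.__name__:<12} channels_last={is_cl}")

        for module in model.modules():
            if not list(module.children()):  # leaf modules only
                handles.append(module.register_forward_hook(hook))
        return handles


    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.MaxPool2d(2))
    model = model.to(memory_format=torch.channels_last)
    x = torch.randn(2, 3, 16, 16).to(memory_format=torch.channels_last)

    handles = report_channels_last(model)
    with torch.no_grad():
        out = model(x)
    for h in handles:  # remove the hooks once the report is printed
        h.remove()

The printed flags depend on the PyTorch version and backend, so no expected output is shown; a layer reporting ``channels_last=False`` after a channels last input is a candidate for the permutation overhead described above.
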

If you find an operator that doesn't support channels last tensors and you want to contribute, feel free to use the following developer guide: https://github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators.

The code below restores the original attributes of ``torch``, ``torch.Tensor``, and ``torch.nn.functional``.

.. code-block:: Python


    for m, attrs in old_attrs.items():
        for k, v in attrs.items():
            setattr(m, k, v)


Work to do
----------

There are still many things to do, such as:

- Resolving the ambiguity of ``N1HW`` and ``NC11`` tensors;
- Testing Distributed Training support;
- Improving operator coverage.

If you have feedback and/or suggestions for improvement, please let us know by creating `an issue `_.

Conclusion
----------

This tutorial introduced the "channels last" memory format and demonstrated how to use it for performance gains. For a practical example of accelerating vision models using channels last, see the post `here `_.