{ "cells": [ { "cell_type": "markdown", "id": "e2b62a7c", "metadata": {}, "source": [ "# Control MXU Floating Point Precision\n", " **Author:** `Yaoshiang Ho`\n", "\n", " **Date created:** 2025/05/15\n", "\n", " **Last modified:** 2025/05/15\n", "\n", " In this tutorial, you will learn how to control the floating point\n", " precision of matrix multiplication (mat mul) operations\n", " when using certain accelerators with PyTorch/XLA, such as TPUs.\n", " You will also learn how to access torch's floating point info\n", " and how to visually inspect the floating point representation of numbers." ] }, { "cell_type": "markdown", "id": "66ff85ec", "metadata": {}, "source": [ "## Introduction\n", "\n", " Google TPUs are built with matrix multiplication optimized on silicon\n", " in a physical module called a Matrix Multiply Unit or MXU.\n", " To maintain speed, researchers identified an inexpensive tradeoff.\n", " [The research](https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus)\n", " showed that neural networks were able to train with less precision\n", " than FP32\n", " \\\"without having any noticeable impact on model accuracy\\\".\n", " The same was not true for range. Due to operations like norms,\n", " FP32's range was important to keep. The solution was bfloat16:\n", " the same range as FP32, with less precision.\n", "\n", " Nvidia V100 and newer GPUs also include specialized matrix multiplication\n", " units called TensorCores. These GPUs use a numerical format called\n", " TF32, which has the same range as FP32 and bfloat16, but an\n", " intermediate precision (10 bits of mantissa)\n", " because TF32 only has 19 total bits.\n", "\n", " Matrix multiplication operations performed on FP32 values will\n", " yield results in bfloat16 for TPUs and TF32 (with 19 bits)\n", " for Nvidia GPUs.\n", "\n", " ![bits layout](../_static/img/bit_layout.svg)" ] }, { "cell_type": "markdown", "id": "95e5e8fd", "metadata": {}, "source": [ "## Higher precision math on lower precision hardware\n", "\n", " Even with the 7 mantissa bits of bfloat16, it is possible to calculate\n", " math in higher precision. This is done by breaking up a number into\n", " its components. To build intuition, imagine an MXU\n", " that supports 2 digits in a base-10 (decimal) number system.\n", " The goal is to multiply numbers with\n", " 4 digits of precision, for example, $9.111$ and $9.222$.\n", " In infinite precision, the product is $84.021642$. Notice that\n", " two numbers with 4 digits of precision generates twice as many\n", " digits of precision in the result. But given the number format\n", " is 4 digits, the result will be rounded to $84.02$.\n", "\n", " The simplest approach is to round the numbers to $9.1$ and $9.2$,\n", " resulting in $83.72$. This is conceptually the \"default\" precision\n", " setting of PyTorch/XLA on TPU.\n", "\n", " The next approach is to break each number up into two parts,\n", " high and low (H and L):\n", " $(9.1 + 0.011) \\times (9.2 + 0.022)$.\n", " This equals $(H \\times H + H \\times L + L \\times H + L \\times L)$.\n", " The first three matrix\n", " multiplications comprise the three-pass approach and roughly\n", " doubles the effective precision. The fourth term, $L \\times L$, is ignored\n", " and looking at the result, $0.000242$, it is easy to see that\n", " this value will not contribute to the final result. 
" Some values of $L \\times L$\n", " could generate a fourth term that moves the value by one bit, but adding\n", " one bit of information half the time provides little value relative\n", " to the cost of running another multiplication.\n", "\n", " ```text\n", "                   +--------+--------+\n", "                   |      9.222      |\n", "                   +--------+--------+\n", "                   | 9.2    | 0.022  |\n", " +--------+--------+--------+--------+\n", " | 9.111  | 9.1    | 83.72  | 0.2002 |\n", " +--------+--------+--------+--------+\n", " |        | 0.011  | 0.1012 |        |  = 84.0214 => 84.02\n", " +--------+--------+--------+--------+\n", " ```\n", "\n", " Extending this approach again would yield roughly triple the precision.\n", " The idea is to break the number\n", " into high, medium, and low (H, M, and L), generating nine possible terms:\n", " $(H + M + L) \\times (H + M + L) = HH + HM + MH + MM + HL + LH + ML + LM + LL$.\n", " The final three are ignored, and the first six comprise the six-pass approach.\n", " It is essentially equivalent to FP32, with some room for variance in the final bit." ] }, { "cell_type": "markdown", "id": "5cd81784", "metadata": {}, "source": [ "## PyTorch/XLA and TPUs\n", "\n", " PyTorch/XLA allows control of the one-pass, three-pass, and six-pass\n", " approaches through the `torch_xla.backends.set_mat_mul_precision()`\n", " function.\n", " The valid values are `default`, `high`, and `highest`. Now, you'll investigate\n", " the differences between these three settings.\n", "\n", " Warning: Although this notebook demonstrates setting the precision multiple\n", " times,\n", " it is recommended to set the precision only once at the beginning of your\n", " script." ] }, { "cell_type": "markdown", "id": "08e2cb59", "metadata": {}, "source": [ "## Preparations\n", "\n", " Make sure you are running this tutorial on a TPU. You can\n", " access a TPU using Google Colab.\n", "\n", " Import the required packages." ] }, { "cell_type": "code", "execution_count": 1, "id": "da0d842a", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:08:59.346015Z", "iopub.status.busy": "2025-05-15T20:08:59.345651Z", "iopub.status.idle": "2025-05-15T20:09:02.266478Z", "shell.execute_reply": "2025-05-15T20:09:02.265986Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.\n" ] } ], "source": [ "import torch\n", "import torch_xla.backends\n", "\n", "torch.set_printoptions(precision=20, sci_mode=False, linewidth=240)" ] }, { "cell_type": "markdown", "id": "05c0809b", "metadata": {}, "source": [ "Epsilon is the difference between 1.0 and the next representable number above it. Retrieve the value from torch." ] }, { "cell_type": "code", "execution_count": 2, "id": "a3fc145c", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:02.268113Z", "iopub.status.busy": "2025-05-15T20:09:02.267838Z", "iopub.status.idle": "2025-05-15T20:09:02.270606Z", "shell.execute_reply": "2025-05-15T20:09:02.270204Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bfloat16 epsilon: 0.0078125\n", "return type of torch.finfo: <class 'float'>\n" ] } ], "source": [ "eps = torch.finfo(torch.bfloat16).eps\n", "print(f\"bfloat16 epsilon: {eps}\")\n", "print(f\"return type of torch.finfo: {type(eps)}\")" ] }, { "cell_type": "markdown", "id": "7dff7be6", "metadata": {}, "source": [ "The epsilon is also defined as $1 / 2^p$, where $p$ is the number of bits in the mantissa."
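, "\n", " For example, FP32 has 23 mantissa bits, so its epsilon is $1 / 2^{23}$.\n", " As a quick check (a minimal sketch using only `torch.finfo`, which was\n", " imported above via `torch`):\n", "\n", " ```python\n", " import torch\n", "\n", " # eps equals 2**-p, where p is the mantissa width:\n", " # 7 bits for bfloat16, 23 bits for float32.\n", " assert torch.finfo(torch.bfloat16).eps == 2**-7\n", " assert torch.finfo(torch.float32).eps == 2**-23\n", " ```\n", "\n", " Verify the bfloat16 value by hand:"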
] }, { "cell_type": "code", "execution_count": 3, "id": "cc2b5702", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:02.271923Z", "iopub.status.busy": "2025-05-15T20:09:02.271691Z", "iopub.status.idle": "2025-05-15T20:09:02.274030Z", "shell.execute_reply": "2025-05-15T20:09:02.273621Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.0078125\n" ] } ], "source": [ "print(1 / (2**7))" ] }, { "cell_type": "markdown", "id": "ef147654", "metadata": {}, "source": [ "Numbers in between may get rounded up to 1.0 + epsilon, or down to 1.0." ] }, { "cell_type": "code", "execution_count": 4, "id": "80595534", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:02.275297Z", "iopub.status.busy": "2025-05-15T20:09:02.275092Z", "iopub.status.idle": "2025-05-15T20:09:02.278882Z", "shell.execute_reply": "2025-05-15T20:09:02.278473Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([1.00000000000000000000, 1.00000000000000000000, 1.00000000000000000000, 1.00781250000000000000, 1.00781250000000000000], dtype=torch.bfloat16)\n" ] } ], "source": [ "print(\n", " torch.tensor(\n", " [1.0, 1.0 + eps / 4.0, 1.0 + eps / 2, 1.0 + eps * 3 / 4, 1.0 + eps],\n", " dtype=torch.bfloat16,\n", " ))" ] }, { "cell_type": "markdown", "id": "a54979f0", "metadata": {}, "source": [ "## Get ready to look directly at bits\n", "\n", " Set up tools to convert binary strings to FP32 numbers, and vice\n", " versa. Create a function to generate a random matrix.\n", "\n", " In general, when\n", " testing an MXU (or TensorCore), pass matrices to encourage XLA to\n", " use the MXU rather than the slower but more precise units for FP32." ] }, { "cell_type": "code", "execution_count": 5, "id": "5e8d9050", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:02.280122Z", "iopub.status.busy": "2025-05-15T20:09:02.279908Z", "iopub.status.idle": "2025-05-15T20:09:02.283995Z", "shell.execute_reply": "2025-05-15T20:09:02.283591Z" } }, "outputs": [], "source": [ "import struct\n", "\n", "\n", "def binary_fraction_to_fp32(bstr: str) -> float:\n", " if bstr[:4] != \"0b1.\":\n", " raise ValueError(f\"Invalid binary string: {bstr}\")\n", " fraction_bits = bstr[4:]\n", " mantissa = 1.0\n", " for i, bit in enumerate(fraction_bits):\n", " mantissa += int(bit) * 2**-(i + 1)\n", " return float(mantissa)\n", "\n", "\n", "def fp32_to_binary_fraction(fp32_float: float) -> str:\n", " x_bytes = struct.pack(\">f\", fp32_float) # Big-endian IEEE 754 float32\n", " as_int = struct.unpack(\">I\", x_bytes)[0] # Interpret bits as uint32\n", " sign = (as_int >> 31) & 0b1\n", " exponent = (as_int >> 23) & 0xFF\n", " mantissa = as_int & 0x7FFFFF # lower 23 bits\n", " return f\"FORMAT:0b SIGN:{sign} EXPONENT:{exponent:08b} MANTISSA:{mantissa:023b} VALUE={fp32_float}\"\n", "\n", "\n", "def get_rand_matrix():\n", " \"\"\"Returns a diagonal matrix of shape 1024, 1024, values between 0.999 and 1.111\"\"\"\n", " eye = torch.eye(1024, dtype=torch.float32, device=\"xla\")\n", " rand_ = torch.rand(\n", " (1024, 1024), dtype=torch.float32, device=\"xla\") * 0.2 + 0.9\n", " result = eye * rand_\n", " assert torch.nonzero(result).size(0) == 1024, torch.nonzero(result).size(0)\n", " return result" ] }, { "cell_type": "markdown", "id": "876974bb", "metadata": {}, "source": [ "## Examining a number\n", "\n", " Generate an FP32 number representing 1 + bf16_eps/2. This\n", " will put one extra bit out of reach of a bfloat16's mantissa." 
] }, { "cell_type": "code", "execution_count": 6, "id": "aea31981", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:02.285188Z", "iopub.status.busy": "2025-05-15T20:09:02.285017Z", "iopub.status.idle": "2025-05-15T20:09:02.287520Z", "shell.execute_reply": "2025-05-15T20:09:02.287126Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "FP32 : 1.00390625\n", "1 + eps/2: 1.00390625\n" ] } ], "source": [ "one_plus_half_eps = binary_fraction_to_fp32(\"0b1.\" + \"0\" * 7 + \"1\" + \"0\" * 15)\n", "print(f\"FP32 : {one_plus_half_eps }\")\n", "print(f\"1 + eps/2: {1.0 + eps / 2}\")" ] }, { "cell_type": "markdown", "id": "4aab8ded", "metadata": {}, "source": [ "Print the bits for FP32 and BF16. Notice that the 8th bit is lost.\n", " This reconfirms that BF16 cannot represent the 8th bit of precision." ] }, { "cell_type": "code", "execution_count": 7, "id": "fc626b20", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:02.288805Z", "iopub.status.busy": "2025-05-15T20:09:02.288596Z", "iopub.status.idle": "2025-05-15T20:09:02.291240Z", "shell.execute_reply": "2025-05-15T20:09:02.290846Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "FP32: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000001000000000000000 VALUE=1.00390625\n", "BF16: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000000000000000000000 VALUE=1.0\n" ] } ], "source": [ "print(f\"FP32: {fp32_to_binary_fraction(one_plus_half_eps)}\")\n", "ones_bf16 = torch.tensor(\n", " one_plus_half_eps, dtype=torch.bfloat16).to(torch.float32).item()\n", "print(f\"BF16: {fp32_to_binary_fraction(ones_bf16)}\")" ] }, { "cell_type": "markdown", "id": "2eff9e6f", "metadata": {}, "source": [ "## MXU\n", "\n", " Place your numbers of interest in a diagonal matrix. By putting them in\n", " a matrix, XLA will execute the math on the MXU. By making the matrices\n", " diagonal, the math will be equivalent to element-wise multiplication.\n", "\n", " Notice that the values are essentially rounded down to 1.0 before\n", " being multiplied, resulting in 1.0 as the output.\n", " This is the loss of precision that occurs in a TPU." ] }, { "cell_type": "code", "execution_count": 8, "id": "ce0df76e", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:02.292448Z", "iopub.status.busy": "2025-05-15T20:09:02.292247Z", "iopub.status.idle": "2025-05-15T20:09:17.401029Z", "shell.execute_reply": "2025-05-15T20:09:17.400357Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000001000000000000000 VALUE=1.00390625\n", "Y: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000001000000000000000 VALUE=1.00390625\n", "Z: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000000000000000000000 VALUE=1.0\n" ] } ], "source": [ "X = get_rand_matrix()\n", "Y = get_rand_matrix()\n", "X[0, 0] = one_plus_half_eps\n", "Y[0, 0] = one_plus_half_eps\n", "Z = torch.matmul(X, Y)\n", "print(f\"X: {fp32_to_binary_fraction(X[0][0].item())}\")\n", "print(f\"Y: {fp32_to_binary_fraction(Y[0][0].item())}\")\n", "print(f\"Z: {fp32_to_binary_fraction(Z[0][0].item())}\")" ] }, { "cell_type": "markdown", "id": "598de4b7", "metadata": {}, "source": [ "## FP32 precision on bfloat16 hardware\n", "\n", " The 3 and 6 pass approaches generate more bits of precision.\n", " Turn on the highest precision mode (six passes) and run\n", " the experiment again. Notice that the TPU has calculated FP32 precision." 
] }, { "cell_type": "code", "execution_count": 9, "id": "ebfa965b", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:17.402608Z", "iopub.status.busy": "2025-05-15T20:09:17.402427Z", "iopub.status.idle": "2025-05-15T20:09:17.772271Z", "shell.execute_reply": "2025-05-15T20:09:17.771603Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:torch_xla.backends:Setting mat mul precision multiple times is not recommended. If you need to do so, please empirically verify that the precision setting is behaving as expected.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Z_ref: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000010000000010000000 VALUE=1.0078277587890625\n", "Z: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000010000000010000000 VALUE=1.0078277587890625\n" ] } ], "source": [ "Z_ref = torch.matmul(\n", " X.to(\"cpu\").to(torch.float32),\n", " Y.to(\"cpu\").to(torch.float32))\n", "print(f\"Z_ref: {fp32_to_binary_fraction(Z_ref[0][0].item())}\")\n", "torch_xla.backends.set_mat_mul_precision(\"highest\")\n", "Z = torch.matmul(X, Y)\n", "print(f\"Z: {fp32_to_binary_fraction(Z[0][0].item())}\")" ] }, { "cell_type": "markdown", "id": "d14d7dfe", "metadata": {}, "source": [ "## Edge-case numbers\n", "\n", " In the previous example, you saw no difference between\n", " the six-pass and FP32 multiplication. Now, you will use an edge\n", " case number to demonstrate a difference in the\n", " final bit between the six-pass approach and full FP32." ] }, { "cell_type": "code", "execution_count": 10, "id": "8011e785", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:17.773780Z", "iopub.status.busy": "2025-05-15T20:09:17.773587Z", "iopub.status.idle": "2025-05-15T20:09:18.408244Z", "shell.execute_reply": "2025-05-15T20:09:18.407613Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:torch_xla.backends:Setting mat mul precision multiple times is not recommended. If you need to do so, please empirically verify that the precision setting is behaving as expected.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Z_ref: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:01110000101000111101100 VALUE=1.440000057220459\n", "Z: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:01110000101000111101101 VALUE=1.4400001764297485\n" ] } ], "source": [ "X = get_rand_matrix()\n", "Y = get_rand_matrix()\n", "X[0, 0] = 1.2\n", "Y[0, 0] = 1.2\n", "Z_ref = torch.matmul(\n", " X.to(\"cpu\").to(torch.float32),\n", " Y.to(\"cpu\").to(torch.float32))\n", "print(f\"Z_ref: {fp32_to_binary_fraction(Z_ref[0][0].item())}\")\n", "torch_xla.backends.set_mat_mul_precision(\"highest\")\n", "Z = torch.matmul(X, Y)\n", "print(f\"Z: {fp32_to_binary_fraction(Z[0][0].item())}\")" ] }, { "cell_type": "markdown", "id": "ab0a2c9a", "metadata": {}, "source": [ "## Conclusion\n", "\n", " In this tutorial, you learned how to control the floating point\n", " precision of your matrix multiplication (mat mul) operations.\n", " You also learned the internal algorithm used to generate\n", " higher precision through the three-pass and six-pass approaches." ] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }