
TorchScript

TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.

We provide tools to incrementally transition a model from a pure Python program to a TorchScript program that can be run independently from Python, such as in a standalone C++ program. This makes it possible to train models in PyTorch using familiar tools in Python and then export the model via TorchScript to a production environment where Python programs may be disadvantageous for performance and multi-threading reasons.

For a gentle introduction to TorchScript, see the Introduction to TorchScript tutorial.

For an end-to-end example of converting a PyTorch model to TorchScript and running it in C++, see the Loading a PyTorch Model in C++ tutorial.

Creating TorchScript Code

torch.jit.script(obj)

Scripting a function or nn.Module will inspect the source code, compile it as TorchScript code using the TorchScript compiler, and return a ScriptModule or ScriptFunction. TorchScript itself is a subset of the Python language, so not all features in Python work, but we provide enough functionality to compute on tensors and do control-dependent operations. For a complete guide, see the TorchScript Language Reference.

torch.jit.script can be used as a function for modules and functions, and as a decorator @torch.jit.script for TorchScript Classes and functions.

Parameters

obj (callable, class, or nn.Module) – The nn.Module, function, or class type to compile.

Returns

If obj is nn.Module, script returns a ScriptModule object. The returned ScriptModule will have the same set of sub-modules and parameters as the original nn.Module. If obj is a standalone function, a ScriptFunction will be returned.
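The description above mentions that @torch.jit.script also compiles TorchScript classes, which the examples below do not show. The following is a minimal sketch, assuming a recent PyTorch release; the class Pair and the function use_pair are hypothetical names chosen for illustration. The class attributes are inferred from the annotated constructor arguments.

```python
import torch

@torch.jit.script
class Pair:
    def __init__(self, first: torch.Tensor, second: torch.Tensor):
        # Attribute types are inferred from the annotated arguments
        self.first = first
        self.second = second

    def total(self) -> torch.Tensor:
        return self.first + self.second


@torch.jit.script
def use_pair(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # A TorchScript class can be constructed and used inside compiled code
    p = Pair(a, b)
    return p.total()


result = use_pair(torch.ones(2), torch.ones(2))
```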

Scripting a function

The @torch.jit.script decorator will construct a ScriptFunction by compiling the body of the function.

Example (scripting a function):

import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(type(foo))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(foo.code)

# Call the function using the TorchScript interpreter
foo(torch.ones(2, 2), torch.ones(2, 2))

Scripting an nn.Module

Scripting an nn.Module by default will compile the forward method and recursively compile any methods, submodules, and functions called by forward. If an nn.Module only uses features supported in TorchScript, no changes to the original module code should be necessary. script will construct a ScriptModule that has copies of the attributes, parameters, and methods of the original module.

Example (scripting a simple module with a Parameter):

import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))

        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mv(input)

        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3))

Example (scripting a module with traced submodules):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # torch.jit.trace produces ScriptModules for conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())

To compile a method other than forward (and recursively compile anything it calls), add the @torch.jit.export decorator to the method. To opt out of compilation, use @torch.jit.ignore or @torch.jit.unused.

Example (an exported and ignored method in a module):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()

    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10

    @torch.jit.ignore
    def python_only_fn(self, input):
        # This function won't be compiled, so any
        # Python APIs can be used
        import pdb
        pdb.set_trace()

    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99

scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))

torch.jit.trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5)

Trace a function and return an executable or ScriptFunction that will be optimized using just-in-time compilation. Tracing is ideal for code that operates only on Tensors and lists, dictionaries, and tuples of Tensors.

Using torch.jit.trace and torch.jit.trace_module, you can turn an existing module or Python function into a TorchScript ScriptFunction or ScriptModule. You must provide example inputs, and we run the function, recording the operations performed on all the tensors.

  • Tracing a standalone function produces a ScriptFunction.

  • Tracing an nn.Module, or the forward method of an nn.Module, produces a ScriptModule.

The resulting ScriptModule also contains all of the parameters of the original module.
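As a quick check of the point above, the sketch below traces a torch.nn.Linear (an arbitrary module chosen for illustration) and confirms that the traced ScriptModule exposes the same named parameters. It also runs the trace on a larger batch, which works here because the recorded operations do not depend on the batch size.

```python
import torch

lin = torch.nn.Linear(4, 2)
# Trace with a batch containing a single example
traced = torch.jit.trace(lin, torch.rand(1, 4))

# The traced ScriptModule carries the parameters of the original module
param_names = [name for name, _ in traced.named_parameters()]
print(param_names)

# The recorded ops do not depend on the batch size, so a larger batch also works
out = traced(torch.rand(3, 4))
print(out.shape)
```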

Warning

Tracing only correctly records functions and modules which are not data dependent (e.g., do not have conditionals on data in tensors) and do not have any untracked external dependencies (e.g., perform input/output or access global variables). Tracing only records operations done when the given function is run on the given tensors. Therefore, the returned ScriptModule will always run the same traced graph on any input. This has some important implications when your module is expected to run different sets of operations, depending on the input and/or the module state. For example,

  • Tracing will not record any control-flow like if-statements or loops. When this control-flow is constant across your module, this is fine and it often inlines the control-flow decisions. But sometimes the control-flow is actually part of the model itself. For instance, a recurrent network is a loop over the (possibly dynamic) length of an input sequence.

  • In the returned ScriptModule, operations that have different behaviors in training and eval modes will always behave as if the module were in the mode it was in during tracing, no matter which mode the ScriptModule is in.

In cases like these, tracing would not be appropriate and scripting is a better choice. If you trace such models, you may silently get incorrect results on subsequent invocations of the model. The tracer will try to emit warnings when doing something that may cause an incorrect trace to be produced.
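The sketch below illustrates the control-flow pitfall with a hypothetical function pick_larger whose branch depends on tensor values. The trace freezes the branch taken for the example inputs, while torch.jit.script keeps the if-statement. The tracer warning is silenced only to keep the output short; in real code it is worth reading.

```python
import warnings

import torch


def pick_larger(x, y):
    # Data-dependent branch: which tensor is returned depends on the values
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


big = torch.full((2,), 5.0)
small = torch.zeros(2)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # the tracer warns about converting a tensor to a bool
    traced = torch.jit.trace(pick_larger, (big, small))

scripted = torch.jit.script(pick_larger)

# Swap the arguments so that the other branch should be taken
print(traced(small, big))    # still returns the first argument: the branch was frozen
print(scripted(small, big))  # returns `big`: the branch is evaluated at run time
```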

Parameters
  • func (callable or torch.nn.Module) – A Python function or torch.nn.Module that will be run with example_inputs. Arguments and return values of func must be tensors or (possibly nested) tuples that contain tensors. When a module is passed to torch.jit.trace, only the forward method is run and traced (see torch.jit.trace_module to trace other methods).

  • example_inputs (tuple) – A tuple of example inputs that will be passed to the function while tracing. The resulting trace can be run with inputs of different types and shapes, assuming the traced operations support those types and shapes. example_inputs may also be a single Tensor, in which case it is automatically wrapped in a tuple.

Keyword Arguments
  • check_trace (bool, optional) – Check if the same inputs run through traced code produce the same outputs. Default: True. You might want to disable this if, for example, your network contains non-deterministic ops or if you are sure that the network is correct despite a checker failure.

  • check_inputs (list of tuples, optional) – A list of tuples of input arguments that should be used to check the trace against what is expected. Each tuple is equivalent to a set of input arguments that would be specified in example_inputs. For best results, pass in a set of checking inputs representative of the space of shapes and types of inputs you expect the network to see. If not specified, the original example_inputs are used for checking.

  • check_tolerance (float, optional) – Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion.
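The keyword arguments above can be combined as in the following sketch. The function scale_and_shift is a hypothetical example that uses only elementwise operations, so re-tracing with the differently shaped check_inputs produces the same graph and the check passes. Note that example_inputs is passed here as a single Tensor, which is wrapped in a tuple automatically.

```python
import torch


def scale_and_shift(x):
    # Only elementwise ops, so the trace does not depend on the input shape
    return 2 * x + 1


traced = torch.jit.trace(
    scale_and_shift,
    torch.rand(3),  # a single Tensor is wrapped into a 1-tuple
    check_inputs=[(torch.rand(2, 2),), (torch.rand(5),)],  # re-traced and compared
    check_tolerance=1e-5,
)

out = traced(torch.ones(2, 2))
print(out)
```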

Returns

If func is an nn.Module or the forward method of an nn.Module, trace returns a ScriptModule object with a single forward method containing the traced code. The returned ScriptModule will have the same set of sub-modules and parameters as the original nn.Module. If func is a standalone function, trace returns a ScriptFunction.

Example (tracing a function):

import torch

def foo(x, y):
    return 2 * x + y

# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment

Example (tracing an existing module):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-5)

Trace a module and return an executable ScriptModule that will be optimized using just-in-time compilation. When a module is passed to torch.jit.trace, only the forward method is run and traced. With trace_module, you can specify a dictionary mapping method names to the example inputs to trace them with (see the inputs argument below).

See torch.jit.trace for more information on tracing.

Parameters
  • mod (torch.nn.Module) – A torch.nn.Module containing methods whose names are specified in inputs. The given methods will be compiled as a part of a single ScriptModule.

  • inputs (dict) – A dict containing sample inputs indexed by method names in mod. The inputs will be passed to the methods whose names correspond to the inputs’ keys while tracing, for example { 'forward' : example_forward_input, 'method2': example_method2_input}.

Keyword Arguments
  • check_trace (bool, optional) – Check if the same inputs run through traced code produce the same outputs. Default: True. You might want to disable this if, for example, your network contains non-deterministic ops or if you are sure that the network is correct despite a checker failure.

  • check_inputs (list of dicts, optional) – A list of dicts of input arguments that should be used to check the trace against what is expected. Each dict is equivalent to a set of input arguments that would be specified in inputs. For best results, pass in a set of checking inputs representative of the space of shapes and types of inputs you expect the network to see. If not specified, the original inputs are used for checking.

  • check_tolerance (float, optional) – Floating-point comparison tolerance to use in the checker procedure. This can be used to relax the checker strictness in the event that results diverge numerically for a known reason, such as operator fusion.

Returns

A ScriptModule object containing the traced methods. The returned ScriptModule will have the same set of sub-modules and parameters as mod.

Example (tracing a module with multiple methods):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
module = torch.jit.trace_module(n, inputs)

class torch.jit.ScriptModule

ScriptModules wrap a C++ torch::jit::Module. ScriptModules contain methods, attributes, parameters, and constants. These can be accessed the same way as on a normal nn.Module.

property code

Returns a pretty-printed representation (as valid Python syntax) of the internal graph for the forward method. See Inspecting Code for details.

property graph

Returns a string representation of the internal graph for the forward method. See Interpreting Graphs for details.

property inlined_graph

Returns a string representation of the internal graph for the forward method. This graph will be preprocessed to inline all function and method calls. See Interpreting Graphs for details.

save(f, _extra_files=ExtraFilesMap{})

See torch.jit.save for details.
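To compare the code, graph and inlined_graph properties side by side, the sketch below scripts two small hypothetical modules (Inner and Outer). The call to the submodule appears as a method call in graph, while inlined_graph replaces it with the submodule's own operations.

```python
import torch


class Inner(torch.nn.Module):
    def forward(self, x):
        return x * 2


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        return self.inner(x) + 1


m = torch.jit.script(Outer())
print(m.code)           # Python-like source for forward
print(m.graph)          # IR for forward; the submodule call is not expanded
print(m.inlined_graph)  # IR with the submodule's operations inlined
```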

class torch.jit.ScriptFunction

Functionally equivalent to a ScriptModule, but represents a single function and does not have any attributes or Parameters.

torch.jit.save(m, f, _extra_files=ExtraFilesMap{})

Save an offline version of this module for use in a separate process. The saved module serializes all of the methods, submodules, parameters, and attributes of this module. It can be loaded into the C++ API using torch::jit::load(filename) or into the Python API with torch.jit.load.

To be able to save a module, it must not make any calls to native Python functions. This means that all submodules must be subclasses of ScriptModule as well.

Danger

All modules, no matter their device, are always loaded onto the CPU during loading. This is different from torch.load()’s semantics and may change in the future.

Parameters
  • m – A ScriptModule to save.

  • f – A file-like object (has to implement write and flush) or a string containing a file name.

  • _extra_files – Map from filename to contents which will be stored as part of ‘f’.

Warning

If you are using Python 2, torch.jit.save does NOT support StringIO.StringIO as a valid file-like object. This is because the write method should return the number of bytes written; StringIO.write() does not do this.

Please use something like io.BytesIO instead.

Example:

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

m = torch.jit.script(MyModule())

# Save to file
torch.jit.save(m, 'scriptmodule.pt')
# This line is equivalent to the previous
m.save("scriptmodule.pt")

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)

# Save with extra files
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)

torch.jit.load(f, map_location=None, _extra_files=ExtraFilesMap{})

Load a ScriptModule or ScriptFunction previously saved with torch.jit.save.

All previously saved modules, no matter their device, are first loaded onto CPU, and then are moved to the devices they were saved from. If this fails (e.g. because the run time system doesn’t have certain devices), an exception is raised.

Parameters
  • f – a file-like object (has to implement read, readline, tell, and seek), or a string containing a file name

  • map_location (string or torch.device) – A simplified version of map_location in torch.load, used to dynamically remap storages to an alternative set of devices.

  • _extra_files (dictionary of filename to content) – The extra filenames given in the map would be loaded and their content would be stored in the provided map.

Returns

A ScriptModule object.

Example:

import torch
import io

torch.jit.load('scriptmodule.pt')

# Load ScriptModule from io.BytesIO object
with open('scriptmodule.pt', 'rb') as f:
    buffer = io.BytesIO(f.read())

# Load all tensors to the original device
torch.jit.load(buffer)

# Load all tensors onto CPU, using a device
buffer.seek(0)
torch.jit.load(buffer, map_location=torch.device('cpu'))

# Load all tensors onto CPU, using a string
buffer.seek(0)
torch.jit.load(buffer, map_location='cpu')

# Load with extra files.
extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
print(extra_files['foo.txt'])
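Putting torch.jit.save and torch.jit.load together, the following sketch round-trips a small scripted module through an in-memory io.BytesIO buffer and checks that the restored module computes the same result. The module name AddTen is hypothetical.

```python
import io

import torch


class AddTen(torch.nn.Module):
    def forward(self, x):
        return x + 10


original = torch.jit.script(AddTen())

buffer = io.BytesIO()
torch.jit.save(original, buffer)
buffer.seek(0)  # rewind before reading the serialized module back
restored = torch.jit.load(buffer, map_location="cpu")

x = torch.arange(3.0)
print(restored(x))
```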

torch.jit.ignore(drop=False, **kwargs)

This decorator indicates to the compiler that a function or method should be ignored and left as a Python function. This allows you to leave code in your model that is not yet TorchScript compatible. If called from TorchScript, ignored functions will dispatch the call to the Python interpreter. Models with ignored functions cannot be exported; use @torch.jit.unused instead.

Example (using @torch.jit.ignore on a method):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    @torch.jit.ignore
    def debugger(self, x):
        import pdb
        pdb.set_trace()

    def forward(self, x):
        x += 10
        # The compiler would normally try to compile `debugger`,
        # but since it is `@ignore`d, it will be left as a call
        # to Python
        self.debugger(x)
        return x

m = torch.jit.script(MyModule())

# Error! The call to `debugger` cannot be saved since it calls into Python
m.save("m.pt")

Example (using @torch.jit.ignore(drop=True) on a method):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    @torch.jit.ignore(drop=True)
    def training_method(self, x):
        import pdb
        pdb.set_trace()

    def forward(self, x):
        if self.training:
            self.training_method(x)
        return x

m = torch.jit.script(MyModule())

# This is OK since `training_method` is not saved; the call is replaced
# with a `raise`.
m.save("m.pt")
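@torch.jit.ignore can also be applied to a free function, not only to a method. In this sketch (call_log, record and double are hypothetical names), the ignored function runs in the Python interpreter when called from compiled code, so it may use arbitrary Python such as appending to a list. As described above, a scripted function that calls it could not be saved.

```python
import torch

call_log = []  # plain Python list, only touched from the ignored function


@torch.jit.ignore
def record(x: torch.Tensor) -> None:
    # Not compiled: this body runs in the Python interpreter
    call_log.append(tuple(x.shape))


@torch.jit.script
def double(x: torch.Tensor) -> torch.Tensor:
    record(x)  # dispatched back to Python at run time
    return x * 2


out = double(torch.ones(2, 3))
```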

torch.jit.unused(fn)

This decorator indicates to the compiler that a function or method should be ignored and replaced with the raising of an exception. This allows you to leave code in your model that is not yet TorchScript compatible and still export your model.

Example (using @torch.jit.unused on a method):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self, use_memory_efficient):
        super(MyModule, self).__init__()
        self.use_memory_efficient = use_memory_efficient

    @torch.jit.unused
    def memory_efficient(self, x):
        import pdb
        pdb.set_trace()
        return x + 10

    def forward(self, x):
        # Use not-yet-scriptable memory efficient mode
        if self.use_memory_efficient:
            return self.memory_efficient(x)
        else:
            return x + 10

m = torch.jit.script(MyModule(use_memory_efficient=False))
m.save("m.pt")

m = torch.jit.script(MyModule(use_memory_efficient=True))
# exception raised
m(torch.rand(100))

Mixing Tracing and Scripting

In many cases either tracing or scripting is an easier approach for converting a model to TorchScript. Tracing and scripting can be composed to suit the particular requirements of a part of a model.

Scripted functions can call traced functions. This is particularly useful when you need to use control-flow around a simple feed-forward model. For instance, the beam search of a sequence-to-sequence model will typically be written in script but can call an encoder module generated using tracing.

Example (calling a traced function in script):

import torch

def foo(x, y):
    return 2 * x + y

traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

@torch.jit.script
def bar(x):
    return traced_foo(x, x)

Traced functions can call script functions. This is useful when a small part of a model requires some control-flow even though most of the model is just a feed-forward network. Control-flow inside of a script function called by a traced function is preserved correctly.

Example (calling a script function in a traced function):

import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r


def bar(x, y, z):
    return foo(x, y) + z

traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

This composition also works for nn.Modules, where tracing can be used to generate a submodule that is then called from the methods of a script module.

Example (using a traced module):

import torch
import torchvision

class MyScriptModule(torch.nn.Module):
    def __init__(self):
        super(MyScriptModule, self).__init__()
        self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                        .resize_(1, 3, 1, 1))
        self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                      torch.rand(1, 3, 224, 224))

    def forward(self, input):
        return self.resnet(input - self.means)

my_script_module = torch.jit.script(MyScriptModule())

TorchScript Language

TorchScript is a statically typed subset of Python, so many Python features apply directly to TorchScript. See the full TorchScript Language Reference for details.

Built-in Functions and Modules

TorchScript supports the use of most PyTorch functions and many Python built-ins. See TorchScript Builtins for a full reference of supported functions.

PyTorch Functions and Modules

TorchScript supports a subset of the tensor and neural network functions that PyTorch provides. Most methods on Tensor as well as functions in the torch namespace, all functions in torch.nn.functional and most modules from torch.nn are supported in TorchScript.

See TorchScript Unsupported Pytorch Constructs for a list of unsupported PyTorch functions and modules.
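As a small illustration of the support described above, the sketch below compiles a hypothetical function that combines a torch.nn.functional call (softmax) with the Python built-ins len and range over a list of tensors. The empty list is typed with torch.jit.annotate because TorchScript needs to know its element type.

```python
from typing import List

import torch
import torch.nn.functional as F


@torch.jit.script
def softmax_all(xs: List[torch.Tensor]) -> List[torch.Tensor]:
    # Empty lists need an explicit element type in TorchScript
    out = torch.jit.annotate(List[torch.Tensor], [])
    for i in range(len(xs)):
        out.append(F.softmax(xs[i], dim=0))
    return out


res = softmax_all([torch.zeros(4), torch.ones(2)])
```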

Python Functions and Modules

Many of Python’s built-in functions are supported in TorchScript. The math module is also supported (see math-module for details), but no other Python modules (built-in or third party) are supported.
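For example, a scripted function can call functions from the math module on Python floats, as in this minimal sketch (hypotenuse is a hypothetical name; math.sqrt is assumed to be among the supported math functions).

```python
import math

import torch


@torch.jit.script
def hypotenuse(a: float, b: float) -> float:
    # math.sqrt on a float is compiled by TorchScript
    return math.sqrt(a * a + b * b)


result = hypotenuse(3.0, 4.0)
```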

Python Language Reference Comparison

For a full listing of supported Python features, see Python Language Reference Coverage.

Debugging

Disable JIT for Debugging

PYTORCH_JIT

Setting the environment variable PYTORCH_JIT=0 will disable all script and tracing annotations. If there is a hard-to-debug error in one of your TorchScript models, you can use this flag to force everything to run using native Python. Since TorchScript (scripting and tracing) is disabled with this flag, you can use tools like pdb to debug the model code.

Given an example:

import torch

@torch.jit.script
def scripted_fn(x : torch.Tensor):
    for i in range(12):
        x = x + x
    return x

def fn(x):
    x = torch.neg(x)
    import pdb; pdb.set_trace()
    return scripted_fn(x)

traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))

Debugging this script with pdb works except for when we invoke the @torch.jit.script function. We can globally disable JIT, so that we can call the @torch.jit.script function as a normal Python function and not compile it. If the above script is called disable_jit_example.py, we can invoke it like so:

$ PYTORCH_JIT=0 python disable_jit_example.py

and we will be able to step into the @torch.jit.script function as a normal Python function. To disable the TorchScript compiler for a specific function, see @torch.jit.ignore.

Inspecting Code

TorchScript provides a code pretty-printer for all ScriptModule instances. This pretty-printer gives an interpretation of the script method’s code as valid Python syntax. For example:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.code)

A ScriptModule with a single forward method will have an attribute code, which you can use to inspect the ScriptModule’s code. If the ScriptModule has more than one method, you will need to access .code on the method itself and not the module. We can inspect the code of a method named foo on a ScriptModule by accessing .foo.code. The example above produces this output:

def foo(len: int) -> Tensor:
    rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
    rv0 = rv
    for i in range(len):
        if torch.lt(i, 10):
            rv1 = torch.sub(rv0, 1., 1)
        else:
            rv1 = torch.add(rv0, 1., 1)
        rv0 = rv1
    return rv0

This is TorchScript’s compilation of the code for foo (for a module, .code shows the compiled forward method). You can use this to ensure TorchScript (tracing or scripting) has captured your model code correctly.
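To illustrate the method lookup rule for .code, the sketch below scripts a hypothetical module with an extra method exported via @torch.jit.export, then reads .code from the module (which shows forward) and from the exported method itself.

```python
import torch


class TwoMethods(torch.nn.Module):
    def forward(self, x):
        return x + 1

    @torch.jit.export
    def scale(self, x):
        return x * 3


m = torch.jit.script(TwoMethods())
print(m.code)        # code for forward
print(m.scale.code)  # code for the exported `scale` method
```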

Interpreting Graphs

TorchScript also has a representation at a lower level than the code pretty-printer, in the form of IR graphs.

TorchScript uses a static single assignment (SSA) intermediate representation (IR) to represent computation. The instructions in this format consist of ATen (the C++ backend of PyTorch) operators and other primitive operators, including control flow operators for loops and conditionals. As an example:

@torch.jit.script
def foo(len):
    # type: (int) -> torch.Tensor
    rv = torch.zeros(3, 4)
    for i in range(len):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv

print(foo.graph)

graph follows the same rules described in the Inspecting Code section with regard to forward method lookup.

The example script above produces the graph:

graph(%len.1 : int):
  %24 : int = prim::Constant[value=1]()
  %17 : bool = prim::Constant[value=1]() # test.py:10:5
  %12 : bool? = prim::Constant()
  %10 : Device? = prim::Constant()
  %6 : int? = prim::Constant()
  %1 : int = prim::Constant[value=3]() # test.py:9:22
  %2 : int = prim::Constant[value=4]() # test.py:9:25
  %20 : int = prim::Constant[value=10]() # test.py:11:16
  %23 : float = prim::Constant[value=1]() # test.py:12:23
  %4 : int[] = prim::ListConstruct(%1, %2)
  %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
  %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5
    block0(%i.1 : int, %rv.14 : Tensor):
      %21 : bool = aten::lt(%i.1, %20) # test.py:11:12
      %rv.13 : Tensor = prim::If(%21) # test.py:11:9
        block0():
          %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18
          -> (%rv.3)
        block1():
          %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18
          -> (%rv.6)
      -> (%17, %rv.13)
  return (%rv)

Take the instruction %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10 for example.

  • %rv.1 : Tensor means we assign the output to a (unique) value named rv.1, that value is of Tensor type and that we do not know its concrete shape.

  • aten::zeros is the operator (equivalent to torch.zeros) and the input list (%4, %6, %6, %10, %12) specifies which values in scope should be passed as inputs. The schema for built-in functions like aten::zeros can be found in the TorchScript Builtins reference.

  • # test.py:9:10 is the location in the original source file that generated this instruction. In this case, it is a file named test.py, on line 9, and at character 10.

Notice that operators can also have associated blocks, namely the prim::Loop and prim::If operators. In the graph print-out, these operators are formatted to reflect their equivalent source code forms to facilitate easy debugging.

Graphs can be inspected as shown to confirm that the computation described by a ScriptModule is correct, in both automated and manual fashion, as described below.
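For automated checks, a graph can also be walked programmatically. The sketch below recompiles the foo example from above (with a type annotation instead of a type comment) and collects the kinds of the top-level nodes via graph.nodes() and node.kind(); nodes nested inside prim::Loop and prim::If live in those operators' blocks and are not listed here.

```python
import torch


@torch.jit.script
def foo(n: int) -> torch.Tensor:
    rv = torch.zeros(3, 4)
    for i in range(n):
        if i < 10:
            rv = rv - 1.0
        else:
            rv = rv + 1.0
    return rv


# Kinds of the top-level nodes, e.g. "aten::zeros" and "prim::Loop"
kinds = [node.kind() for node in foo.graph.nodes()]
print(kinds)
```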

Tracer

Tracing Edge Cases

There are some edge cases in which the trace of a given Python function/module will not be representative of the underlying code. These cases can include:

  • Tracing of control flow that is dependent on inputs (e.g. tensor shapes)

  • Tracing of in-place operations of tensor views (e.g. indexing on the left-hand side of an assignment)

Note that these cases may in fact be traceable in the future.

Automatic Trace Checking

One way to automatically catch many errors in traces is by using check_inputs on the torch.jit.trace() API. check_inputs takes a list of tuples of inputs that will be used to re-trace the computation and verify the results. For example:

import torch

def loop_in_traced_fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)

This gives us the following diagnostic information:

ERROR: Graphs differed across invocations!
Graph diff:

            graph(%x : Tensor) {
            %1 : int = prim::Constant[value=0]()
            %2 : int = prim::Constant[value=0]()
            %result.1 : Tensor = aten::select(%x, %1, %2)
            %4 : int = prim::Constant[value=0]()
            %5 : int = prim::Constant[value=0]()
            %6 : Tensor = aten::select(%x, %4, %5)
            %result.2 : Tensor = aten::mul(%result.1, %6)
            %8 : int = prim::Constant[value=0]()
            %9 : int = prim::Constant[value=1]()
            %10 : Tensor = aten::select(%x, %8, %9)
        -   %result : Tensor = aten::mul(%result.2, %10)
        +   %result.3 : Tensor = aten::mul(%result.2, %10)
        ?          ++
            %12 : int = prim::Constant[value=0]()
            %13 : int = prim::Constant[value=2]()
            %14 : Tensor = aten::select(%x, %12, %13)
        +   %result : Tensor = aten::mul(%result.3, %14)
        +   %16 : int = prim::Constant[value=0]()
        +   %17 : int = prim::Constant[value=3]()
        +   %18 : Tensor = aten::select(%x, %16, %17)
        -   %15 : Tensor = aten::mul(%result, %14)
        ?     ^                                 ^
        +   %19 : Tensor = aten::mul(%result, %18)
        ?     ^                                 ^
        -   return (%15);
        ?             ^
        +   return (%19);
        ?             ^
            }

This message indicates to us that the computation differed between when we first traced it and when we traced it with the check_inputs. Indeed, the loop within the body of loop_in_traced_fn depends on the shape of the input x, and thus when we try another x with a different shape, the trace differs.

In this case, data-dependent control flow like this can be captured using torch.jit.script() instead:

import torch

def fn(x):
    result = x[0]
    for i in range(x.size(0)):
        result = result * x[i]
    return result

inputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]

scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())

for input_tuple in [inputs] + check_inputs:
    torch.testing.assert_allclose(fn(*input_tuple), scripted_fn(*input_tuple))

Which produces:

graph(%x : Tensor) {
    %5 : bool = prim::Constant[value=1]()
    %1 : int = prim::Constant[value=0]()
    %result.1 : Tensor = aten::select(%x, %1, %1)
    %4 : int = aten::size(%x, %1)
    %result : Tensor = prim::Loop(%4, %5, %result.1)
    block0(%i : int, %7 : Tensor) {
        %10 : Tensor = aten::select(%x, %1, %i)
        %result.2 : Tensor = aten::mul(%7, %10)
        -> (%5, %result.2)
    }
    return (%result);
}

Tracer Warnings

The tracer produces warnings for several problematic patterns in traced computation. As an example, take a trace of a function that contains an in-place assignment on a slice (a view) of a Tensor:

import torch

def fill_row_zero(x):
    x[0] = torch.rand(*x.shape[1:2])
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

This produces several warnings and a graph that simply returns the input:

fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.
    x[0] = torch.rand(*x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)
    traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {
    return (%0);
}

We can fix this by modifying the code to not use the in-place update, but rather build up the result tensor out-of-place with torch.cat:

def fill_row_zero(x):
    # Build a new tensor: a random first row followed by rows 1 onward of x
    x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:]), dim=0)
    return x

traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

Built-in Functions and Modules

See TorchScript Builtins for a full reference of supported functions.

Frequently Asked Questions

Q: I would like to train a model on GPU and do inference on CPU. What are the best practices?

Trace and save a separate model for each device, converting the model from GPU to CPU before tracing the CPU version, like so:

import torch

# Trace and save the GPU version first: nn.Module.cpu() moves the module's
# parameters in place, so gpu_model is no longer on the GPU afterwards.
traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pth")

cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pth")

# ... later, when using the model:

if use_gpu:
    model = torch.jit.load("gpu.pth")
else:
    model = torch.jit.load("cpu.pth")

model(input)

This is recommended because the tracer may witness tensor creation on a specific device, so casting an already-loaded model may have unexpected effects. Casting the model before tracing and saving it ensures that the tracer has the correct device information.
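A minimal sketch of the device issue, using a hypothetical function add_ones: when a tensor constructor receives device=x.device during tracing, x.device is evaluated in Python, so the trace records the device it saw as a fixed value instead of following the input. Printing traced.graph should show that device as a constant.

```python
import torch

def add_ones(x):
    # x.device is read in Python while tracing, so the trace stores the
    # device it saw ("cpu" here) rather than tracking the input's device
    return x + torch.ones(x.shape, device=x.device)

traced = torch.jit.trace(add_ones, (torch.zeros(2, 3),))
print(traced.graph)  # expect the device to appear as a constant node

# Works for CPU inputs, which match the device recorded at trace time
print(traced(torch.zeros(2, 3)))
```

Because the device is frozen into the trace, calling this traced function with a CUDA input would most likely fail with a device-mismatch error, which is why tracing separately for each target device is the safer pattern.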

Q: How do I store attributes on a ScriptModule?

Say we have a model like:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.x = 2

    def forward(self):
        return self.x

m = torch.jit.script(Model())

The compiler needs to know about x, and its type, in order to compile forward. Here the type int can be inferred from the value 2, but in general there are 4 ways to inform the compiler of attributes on a ScriptModule:

1. nn.Parameter - Values wrapped in nn.Parameter will work as they do on nn.Modules

2. register_buffer - Tensors registered with register_buffer will work as they do on nn.Modules. This is equivalent to an attribute (see 4) of type Tensor.

3. Constants - Annotating a class member as Final (or adding it to a list called __constants__ at the class definition level) will mark the contained names as constants. Constants are saved directly in the code of the model. See Python-defined Constants for details.

4. Attributes - Values that are a supported type can be added as mutable attributes. Most types can be inferred but some may need to be specified; see Module Attributes for details.
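The sketch below puts all four options into one hypothetical module, AttrDemo. It assumes Python 3.8 or later so that Final can be imported from typing; the numbered comments refer to the list above.

```python
from typing import Dict, Final  # typing.Final requires Python 3.8+

import torch
import torch.nn as nn

class AttrDemo(nn.Module):
    # 3. A constant: its value is baked into the compiled code
    scale: Final[int]
    # 4. A mutable attribute; an empty dict's type cannot be inferred
    counts: Dict[str, int]

    def __init__(self):
        super().__init__()
        # 1. A parameter, handled exactly as on a regular nn.Module
        self.weight = nn.Parameter(torch.ones(3))
        # 2. A buffer, i.e. an attribute of type Tensor
        self.register_buffer("offset", torch.zeros(3))
        self.scale = 2
        self.counts = {}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight * self.scale + self.offset

m = torch.jit.script(AttrDemo())
print(m(torch.ones(3)))
```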

Q: I would like to trace a module’s method, but I keep getting this error:

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient

This error usually means that the method you are tracing uses a module’s parameters and you are passing the module’s method instead of the module instance (e.g. my_module_instance.forward vs my_module_instance).

  • Invoking trace with a module’s method captures module parameters (which may require gradients) as constants.

  • On the other hand, invoking trace with a module’s instance (e.g. my_module) creates a new module and correctly copies parameters into the new module, so they can accumulate gradients if required.

To trace a specific method on a module, see torch.jit.trace_module.
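A minimal sketch of torch.jit.trace_module with a hypothetical module, Net, that has a second method, scaled, besides forward. The module instance is passed together with a dict mapping method names to example inputs, so the parameters stay parameters instead of being inserted as constants.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

    def scaled(self, x):
        # A second method to trace; it also uses the module's parameters
        return self.linear(x) * 2

net = Net()
example = torch.rand(3, 4)

# Pass the module instance (not net.scaled) plus a dict that maps each
# method name to its example inputs
traced = torch.jit.trace_module(net, {"forward": example, "scaled": example})
print(traced.scaled(example))
```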

Appendix

Migrating to PyTorch 1.2 Recursive Scripting API

This section details the changes to TorchScript in PyTorch 1.2. If you are new to TorchScript you can skip this section. There are two main changes to the TorchScript API with PyTorch 1.2.

1. torch.jit.script will now attempt to recursively compile functions, methods, and classes that it encounters. Once you call torch.jit.script, compilation is “opt-out”, rather than “opt-in”.

2. torch.jit.script(nn_module_instance) is now the preferred way to create ScriptModules, instead of inheriting from torch.jit.ScriptModule.

These changes combine to provide a simpler, easier-to-use API for converting your nn.Modules into ScriptModules, ready to be optimized and executed in a non-Python environment.

The new usage looks like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

my_model = Model()
my_scripted_model = torch.jit.script(my_model)
  • The module’s forward is compiled by default. Methods called from forward are lazily compiled in the order they are used in forward.

  • To compile a method other than forward that is not called from forward, add @torch.jit.export.

  • To stop the compiler from compiling a method, add @torch.jit.ignore or @torch.jit.unused. @ignore leaves the method as a call to Python, and @unused replaces it with an exception. Methods marked @ignore cannot be exported; methods marked @unused can.

  • Most attribute types can be inferred, so torch.jit.Attribute is not necessary. For empty container types, annotate their types using PEP 526-style class annotations.

  • Constants can be marked with a Final class annotation instead of adding the name of the member to __constants__.

  • Python 3 type hints can be used in place of torch.jit.annotate.
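A minimal sketch contrasting @torch.jit.ignore and @torch.jit.unused on a hypothetical module, Demo. Both methods carry type annotations so the compiler knows their signatures when it compiles the calls in forward. Per the bullets above, a module whose compiled code calls an @ignore method can be run from Python but not exported.

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    @torch.jit.ignore
    def python_only(self, x: torch.Tensor) -> torch.Tensor:
        # Not compiled: left as a call back into Python
        return x + 1

    @torch.jit.unused
    def debug_only(self, x: torch.Tensor) -> torch.Tensor:
        # Not compiled: the body is replaced by raising an exception
        return x - 1

    def forward(self, x: torch.Tensor, use_debug: bool = False) -> torch.Tensor:
        if use_debug:
            return self.debug_only(x)
        return self.python_only(x)

m = torch.jit.script(Demo())
print(m(torch.zeros(2)))  # runs python_only through Python

try:
    m(torch.zeros(2), True)  # reaches the @unused method
except Exception as e:
    print("raised:", type(e).__name__)
```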

As a result of these changes, the following items are considered deprecated and should not appear in new code:
  • The @torch.jit.script_method decorator

  • Classes that inherit from torch.jit.ScriptModule

  • The torch.jit.Attribute wrapper class

  • The __constants__ array

  • The torch.jit.annotate function

Modules

Warning

The @torch.jit.ignore annotation’s behavior changed in PyTorch 1.2. Before PyTorch 1.2, the @ignore decorator was used to make a function or method callable from code that is exported. To get this functionality back, use @torch.jit.unused(). @torch.jit.ignore is now equivalent to @torch.jit.ignore(drop=False). See @torch.jit.ignore and @torch.jit.unused for details.

When passed to the torch.jit.script function, a torch.nn.Module’s data is copied to a ScriptModule and the TorchScript compiler compiles the module. The module’s forward is compiled by default, as are any methods decorated with @torch.jit.export. Methods called from forward are lazily compiled in the order they are used in forward.

torch.jit.export(fn)[source]

This decorator indicates that a method on an nn.Module is used as an entry point into a ScriptModule and should be compiled.

forward is implicitly assumed to be an entry point, so it does not need this decorator. Functions and methods called from forward are compiled as they are seen by the compiler, so they do not need this decorator either.

Example (using @torch.jit.export on a method):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def implicitly_compiled_method(self, x):
        return x + 99

    # `forward` is implicitly decorated with `@torch.jit.export`,
    # so adding it here would have no effect
    def forward(self, x):
        return x + 10

    @torch.jit.export
    def another_forward(self, x):
        # When the compiler sees this call, it will compile
        # `implicitly_compiled_method`
        return self.implicitly_compiled_method(x)

    def unused_method(self, x):
        return x - 20

# `m` will contain compiled methods:
#     `forward`
#     `another_forward`
#     `implicitly_compiled_method`
# `unused_method` will not be compiled since it was not called from
# any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())
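To see what being an entry point means in practice, here is a sketch that scripts a trimmed copy of MyModule, round-trips it through an in-memory buffer with torch.jit.save and torch.jit.load, and calls both forward and the exported another_forward on the loaded module.

```python
import io

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def implicitly_compiled_method(self, x):
        return x + 99

    def forward(self, x):
        return x + 10

    @torch.jit.export
    def another_forward(self, x):
        return self.implicitly_compiled_method(x)

m = torch.jit.script(MyModule())

# Round-trip through an in-memory buffer to mimic saving to disk
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

x = torch.zeros(1)
print(loaded(x))                  # forward: x + 10
print(loaded.another_forward(x))  # exported entry point: x + 99
```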

Functions

Functions don’t change much; they can be decorated with @torch.jit.ignore or @torch.jit.unused if needed.

# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():
    return 2

# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
    return 2

# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():
    import pdb; pdb.set_trace()
    return 4

# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():
    return 2

TorchScript Classes

Warning

TorchScript class support is experimental. Currently it is best suited for simple record-like types (think a NamedTuple with methods attached).

Everything in a user-defined TorchScript Class is exported by default; functions can be decorated with @torch.jit.ignore if needed.
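A minimal sketch of a record-like TorchScript class, the hypothetical Pair, compiled with @torch.jit.script and used from a scripted function. Its attribute types are inferred from the annotated __init__ arguments.

```python
import torch

@torch.jit.script
class Pair:
    def __init__(self, first: torch.Tensor, second: torch.Tensor):
        # Attribute types come from the annotated arguments
        self.first = first
        self.second = second

    def total(self) -> torch.Tensor:
        return self.first + self.second

@torch.jit.script
def add_pair(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # TorchScript code can construct the class and call its methods
    return Pair(a, b).total()

print(add_pair(torch.ones(2), torch.ones(2)))
```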

Attributes

The TorchScript compiler needs to know the types of module attributes. Most types can be inferred from the value of the member. Empty lists and dicts cannot have their types inferred and must have their types annotated with PEP 526-style class annotations. If a type cannot be inferred and is not explicitly annotated, it will not be added as an attribute to the resulting ScriptModule.

Old API:

from typing import Dict
import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super(MyModule, self).__init__()
        self.my_dict = torch.jit.Attribute({}, Dict[str, int])
        self.my_int = torch.jit.Attribute(20, int)

m = MyModule()

New API:

from typing import Dict

import torch

class MyModule(torch.nn.Module):
    my_dict: Dict[str, int]

    def __init__(self):
        super(MyModule, self).__init__()
        # This type cannot be inferred and must be specified
        self.my_dict = {}

        # The attribute type here is inferred to be `int`
        self.my_int = 20

    def forward(self):
        pass

m = torch.jit.script(MyModule())
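The sketch below, using a hypothetical Counter module, combines both cases: an annotated empty dict and an int whose type is inferred. Because attributes are mutable, forward can update the dict between calls.

```python
from typing import Dict

import torch

class Counter(torch.nn.Module):
    # An empty dict needs a class-level annotation
    counts: Dict[str, int]

    def __init__(self):
        super().__init__()
        self.counts = {}
        # Inferred as `int` from the value
        self.step = 1

    def forward(self, key: str) -> int:
        # Attributes are mutable, so forward can update the dict in place
        self.counts[key] = self.counts.get(key, 0) + self.step
        return self.counts[key]

m = torch.jit.script(Counter())
print(m("a"), m("a"), m("b"))
```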

Constants

The Final type constructor can be used to mark members as constant. If members are not marked constant, they will be copied to the resulting ScriptModule as an attribute. Using Final opens opportunities for optimization if the value is known to be fixed and gives additional type safety.

Old API:

class MyModule(torch.jit.ScriptModule):
    __constants__ = ['my_constant']

    def __init__(self):
        super(MyModule, self).__init__()
        self.my_constant = 2

    def forward(self):
        pass
m = MyModule()

New API:

import torch

try:
    from typing_extensions import Final
except ImportError:
    # If you don't have `typing_extensions` installed, you can use a
    # polyfill from `torch.jit`.
    from torch.jit import Final

class MyModule(torch.nn.Module):

    my_constant: Final[int]

    def __init__(self):
        super(MyModule, self).__init__()
        self.my_constant = 2

    def forward(self):
        pass

m = torch.jit.script(MyModule())

Variables

Containers are assumed to have type Tensor and be non-optional (see Default Types for more information). Previously, torch.jit.annotate was used to tell the TorchScript compiler what the type should be. Python 3 style type hints are now supported.

import torch
from typing import Dict, Optional

@torch.jit.script
def make_dict(flag: bool):
    x: Dict[str, int] = {}
    x['hi'] = 2
    b: Optional[int] = None
    if flag:
        b = 2
    return x, b
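For comparison, a sketch of the deprecated torch.jit.annotate form next to the equivalent Python 3 annotation; both scripted functions (hypothetical names old_style and new_style) should return the same dict.

```python
from typing import Dict

import torch

@torch.jit.script
def old_style() -> Dict[str, int]:
    # Deprecated: state the empty container's type with torch.jit.annotate
    x = torch.jit.annotate(Dict[str, int], {})
    x['hi'] = 2
    return x

@torch.jit.script
def new_style() -> Dict[str, int]:
    # Preferred: a Python 3 variable annotation
    x: Dict[str, int] = {}
    x['hi'] = 2
    return x

print(old_style(), new_style())
```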
