Graph
- class torch.accelerator.Graph(keep_graph=False, *, pool=None, capture_error_mode='default')
Wrapper around an accelerator graph that supports capture and replay.
A graph captures a sequence of operations and their dependencies, allowing them to be replayed efficiently with reduced overhead. This class can be used as a context manager to automatically capture operations on the current stream.
- Parameters:
keep_graph (bool, optional) – If False, the underlying graph is destroyed and the executable graph is instantiated on the GPU at the end of capture_end. If True, the underlying graph is preserved after capture_end; in this case the executable graph is not instantiated automatically and must be created explicitly by calling instantiate, or it will be instantiated on the first call to replay. Defaults to False.
pool (tuple[int, int], optional) – Memory pool identifier for this graph. Multiple graphs can share the same pool by passing the same identifier, which can reduce memory overhead. Defaults to None.
capture_error_mode (Literal["default", "global", "thread_local", "relaxed"], optional) – Specifies the behavior of graph capture; the exact semantics are backend-specific.
  - "default": backend-defined default capture behavior.
  - "global": potentially unsafe API calls are prohibited. Errors may occur if capture in the current thread affects other threads.
  - "thread_local": potentially unsafe API calls are prohibited. Errors occur only if capture in the current thread affects the current thread itself.
  - "relaxed": the current thread is allowed to make potentially unsafe API calls, except for calls that inherently conflict with stream capture.
  Defaults to "default".
- Return type:
Self
Example:
>>> x = torch.zeros([2000], device=0)
>>> stream = torch.Stream()
>>> graph = torch.accelerator.Graph()
>>> with stream, graph:
...     x += 1
>>> graph.replay()
- capture_begin()
Begin graph capture on the current stream.
All operations on the current stream after this call are recorded into the graph until capture_end is called, using the memory pool and capture error mode provided at construction time.
- capture_end()
End graph capture on the current stream of the current device.
After this call, the graph can be replayed via replay.
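As a sketch of how the capture_begin/capture_end pair fits together, the helper below captures a callable into a graph and replays it once. The `run_graphed` helper, the warm-up loop, and the eager fallback are illustrative assumptions, not part of the API; the sketch guards on `torch.accelerator.Graph` and accelerator availability so it degrades gracefully on CPU-only machines:

```python
import torch

def run_graphed(fn, warmup=1):
    """Capture ``fn`` into an accelerator graph and replay it once.

    Falls back to plain eager execution when no accelerator (or no
    torch.accelerator.Graph support) is available. ``fn`` should do
    in-place work on pre-allocated tensors, since graph replay re-runs
    the recorded ops on the same memory.
    """
    accel = getattr(torch, "accelerator", None)
    if accel is None or not hasattr(accel, "Graph") or not accel.is_available():
        fn()  # eager fallback, e.g. on CPU-only machines
        return None
    s = torch.Stream()
    g = accel.Graph()
    with s:
        for _ in range(warmup):
            fn()              # warm-up runs outside capture
        g.capture_begin()     # record subsequent ops on the current stream
        fn()
        g.capture_end()       # stop recording; executable graph is built
    g.replay()                # re-run the captured work with low launch overhead
    return g

calls = []
run_graphed(lambda: calls.append(1))
```

Note that work launched between capture_begin and capture_end is recorded rather than executed as ordinary eager work, which is why capture-friendly code avoids fresh allocations inside the captured region.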
- debug_dump(path)
Dump the captured graph to a file for debugging purposes, provided debug mode has been enabled via enable_debug_mode.
- Parameters:
path (str) – Path to dump the graph to.
- Example:
>>> s = torch.Stream()
>>> g = torch.accelerator.Graph()
>>> g.enable_debug_mode()
>>> with s, g:
...     # ... operations ...
>>> # Dump captured graph to a file "graph_dump.dot"
>>> g.debug_dump("graph_dump.dot")
- instantiate()
Instantiate the underlying graph. Called automatically by capture_end if keep_graph=False, or by replay if keep_graph=True and instantiate has not already been called explicitly.
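The keep_graph=True path can be sketched as follows. This is a hedged example assuming `torch.accelerator.Graph` and an available accelerator; the guard via `getattr` is illustrative so the snippet is safe to run anywhere:

```python
import torch

# With keep_graph=True, capture_end does NOT build the executable graph;
# instantiate() builds it explicitly before the first replay (otherwise
# replay() instantiates it lazily on first use).
accel = getattr(torch, "accelerator", None)
if accel is not None and hasattr(accel, "Graph") and accel.is_available():
    x = torch.zeros(1024, device=accel.current_accelerator())
    s = torch.Stream()
    g = accel.Graph(keep_graph=True)
    with s, g:          # context manager wraps capture_begin/capture_end
        x += 1
    g.instantiate()     # explicit instantiation of the executable graph
    g.replay()
```

Calling instantiate explicitly is useful when you want instantiation cost paid at a controlled point (e.g. during setup) rather than inside the first replay on a latency-sensitive path.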
- pool()
Return an opaque token representing the id of this graph’s memory pool.
This id can optionally be passed to another graph's capture_begin, which hints that the other graph may share the same memory pool.
- Example:
>>> g1 = torch.accelerator.Graph()
>>> g1.capture_begin()
>>> # ... operations ...
>>> g1.capture_end()
>>> # Share g1's memory pool with a new graph
>>> pool_id = g1.pool()
>>> g2 = torch.accelerator.Graph(pool=pool_id)