Graph#
- class torch.accelerator.Graph(keep_graph=False, *, pool=None)[source]#
Wrapper around an accelerator graph that supports capture and replay.
A graph captures a sequence of operations and their dependencies, allowing them to be replayed efficiently with reduced overhead. This class can be used as a context manager to automatically capture operations on the current stream.
- Parameters:
keep_graph (bool, optional) – If False, the underlying graph is destroyed and the executable graph is instantiated on the GPU at the end of capture_end. If True, the underlying graph is preserved after capture_end. In this case, the executable graph is not instantiated automatically; it must be explicitly created by calling instantiate, or it will be instantiated on the first call to replay. Defaults to False.
pool (tuple[int, int], optional) – Memory pool identifier for this graph. Multiple graphs can share the same pool by passing the same identifier, which can reduce memory overhead. Defaults to None.
- Return type:
Self
Example:
>>> x = torch.zeros([2000], device=0)
>>> stream = torch.Stream()
>>> graph = torch.accelerator.Graph()
>>> with stream, graph:
...     x += 1
>>> graph.replay()
- capture_begin(capture_error_mode='default')[source]#
Begin graph capture on the current stream.
All operations executed on the current stream of the current device after this call will be recorded into the graph until capture_end is called. By default, capture uses the memory pool provided at construction time.
- Parameters:
capture_error_mode (Literal["default", "global", "thread_local", "relaxed"], optional) – Specifies the behavior of graph capture. The exact semantics are backend-specific. Defaults to "default".
- "default": backend-defined default capture behavior.
- "global": potentially unsafe API calls are prohibited. Errors may occur if capture in the current thread affects other threads.
- "thread_local": potentially unsafe API calls are prohibited. Errors occur only if capture in the current thread affects itself.
- "relaxed": the current thread is allowed to make potentially unsafe API calls, except for calls that inherently conflict with stream capture.
- capture_end()[source]#
End graph capture on the current stream of the current device.
After this call, the graph can be replayed via replay.
- debug_dump(path)[source]#
Dump the captured graph to a file for debugging purposes, if debugging is enabled via enable_debug_mode.
- Parameters:
path (str) – Path to dump the graph to.
- Example::
>>> s = torch.Stream()
>>> g = torch.accelerator.Graph()
>>> g.enable_debug_mode()
>>> with s, g:
...     # ... operations ...
>>> # Dump captured graph to a file "graph_dump.dot"
>>> g.debug_dump("graph_dump.dot")
- instantiate()[source]#
Instantiate the underlying graph. Will be called by capture_end if keep_graph=False, or by replay if keep_graph=True and instantiate has not already been explicitly called.
- pool()[source]#
Return an opaque token representing the id of this graph’s memory pool.
This id can optionally be passed to another graph’s capture_begin, which hints that the other graph may share the same memory pool.
- Example::
>>> g1 = torch.accelerator.Graph()
>>> g1.capture_begin()
>>> # ... operations ...
>>> g1.capture_end()
>>> # Share g1's memory pool with a new graph
>>> pool_id = g1.pool()
>>> g2 = torch.accelerator.Graph(pool=pool_id)