:github_url: https://github.com/pytorch/pytorch .. _program_listing_file_torch_csrc_autograd_function.h: Program Listing for File function.h =================================== |exhale_lsh| :ref:`Return to documentation for file ` (``torch/csrc/autograd/function.h``) .. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS .. code-block:: cpp #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch::autograd { struct Edge; struct FunctionPostHook; struct FunctionPreHook; using tensor_list = std::vector; using variable_list = std::vector; using edge_list = std::vector; using saved_variable_list = std::vector; using ivalue_list = std::vector; using functional_apply_t = std::function< variable_list(const variable_list&, const std::vector&)>; using IndexRange = std::pair; using torch::dynamo::autograd::CompiledNodeArgs; using torch::dynamo::autograd::PackedArgs; using torch::dynamo::autograd::SwapSavedVariables; // Custom deleter to prevent stack overflows. TORCH_API void deleteNode(Node* function); // Guard that sets and restores the evaluating node class NodeGuard { public: explicit NodeGuard(std::shared_ptr node); ~NodeGuard(); private: std::shared_ptr last_evaluating_node_; }; // Return the Node currently being evaluated (if any) // This is only set during the backward pass while a Node is being // executed. TORCH_API std::shared_ptr get_current_node(); //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Node //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // A `Node` is an abstract class that represents an operation taking zero // or more input `Variable`s and producing zero or more output `Variable`s. All // functions in PyTorch's autograd machinery derive from this class and // override its `apply` method. Instances of such subclasses will then be // invokable via the call operator. // // Nodes in the Autograd Graph //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // When viewing the autograd system as a graph, `Node`s are the vertices or // nodes, connected to each other via (directed) `Edge`s, which themselves are // represented via (`Node`, input_nr) pairs. `Variable`s are the outputs to // and inputs of `Node`s, and travel between these edges during execution // of the graph. When two or more `Edge`s (from different sources) point at the // same input to a `Node`, the values produced along all of these edges are // implicitly summed prior to being forwarded to the target `Node`. // // Hierarchy //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Subclasses usually represent differentiable functions as well as their // gradient operators. Note, however, that due to the very general definition // of a `Node` taking *zero* or more inputs and producing *zero* or more // outputs, uses of `Node`s are flexible and extend beyond purely // mathematical operations. For example, the `AccumulateGrad` function is a // *sink*: it takes one input, but produces no outputs, instead accumulating // the input as a side effect. At the other extreme, the `GraphRoot` function // receives no inputs from other functions, but produces multiple outputs. // // Interface //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // The most important method on `Node` is the call operator, which takes in // a list of variables and produces a list of variables. The precise size of // these lists can be determined with `num_inputs()` and `num_outputs()`. // `Node`s are stitched together via their `next_edge` interface, which let // you manipulate the set of outgoing edges of a `Node`. You can add an // edge with `add_next_edge()`, retrieve an edge with `next_edge(index)` and // iterate over them via the `next_edges()` method. Other methods exist for // integration with the JIT and other parts of PyTorch. Every `Node` has a // *sequence number* that increases monotonically in the order of `Node` // construction. It can be retrieved via the `sequence_nr()` method. Note that // this sequence number is *thread local*. This means that when `Node`s // `A`, `B` and `C` are created consecutively in the same thread, their // sequence numbers will be ordered `A` < `B` < `C`. If, however, `A` and `B` // are created in one thread and `C` is created in a new thread, there are *no // guarantees* w.r.t. the ordering of `C` relative to `A` or `B`. // See NOTE [ Sequence Number] for more details on the usages of sequence // number. //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ struct TORCH_API Node : std::enable_shared_from_this { public: explicit Node(uint64_t sequence_nr, edge_list&& next_edges = edge_list()) : sequence_nr_(sequence_nr), next_edges_(std::move(next_edges)) { for (const Edge& edge : next_edges_) { update_topological_nr(edge); } if (AnomalyMode::is_enabled()) { metadata()->store_stack(); // If anomaly mode is enabled and graph is constructed, then assign the // currently evaluating node as the parent of this node. // A parent is a Node where this Node is created. // We are tracking the parents to track multiple backward operations. assign_parent(); } // Store the thread_id of the forward operator. // See NOTE [ Sequence Numbers ] thread_id_ = at::RecordFunction::currentThreadId(); } explicit Node(edge_list&& next_edges = edge_list()) : Node( /*sequence_nr=*/at::sequence_number::get_and_increment(), std::move(next_edges)) {} Node(const Node& other) = delete; Node(Node&& other) = delete; Node& operator=(const Node& other) = delete; Node& operator=(Node&& other) = delete; virtual ~Node() = default; std::shared_ptr getptr() { return shared_from_this(); } variable_list operator()(variable_list&& inputs) { // In the first iteration of named tensors, autograd ignores names and // operates on unnamed tensors. In the long term, autograd should // probably operate with names. at::NoNamesGuard no_names_guard; #ifdef USE_ROCM // Keep track of backward pass for rocblas. at::ROCmBackwardPassGuard in_backward; #endif auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION); if (C10_UNLIKELY(step_callbacks.has_value())) { at::RecordFunction guard(std::move(*step_callbacks)); // Using sequence number and thread id to correlate with // the forward pass function guard.setForwardThreadId(thread_id_); if (guard.needsInputs()) { std::vector inputs_vec(inputs.begin(), inputs.end()); guard.before( name(), c10::ArrayRef( inputs_vec.data(), inputs_vec.size()), static_cast(sequence_nr())); } else { guard.before(name(), static_cast(sequence_nr())); } return apply(std::move(inputs)); } else { return apply(std::move(inputs)); } } // Graph Connectivity API //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Inputs. NOTE: inputs of the grad_fn correspond to Tensor outputs of the // forward function. // Marker for expected undefined input struct undefined_input {}; uint32_t add_input_metadata( const at::TensorOptions& options, c10::SymIntArrayRef shape, bool is_tensor_subclass, bool is_nested) noexcept { uint32_t input_nr = input_metadata_.size(); auto meta_shape = MetadataShape{std::in_place_type, shape}; input_metadata_.emplace_back( options, meta_shape, is_tensor_subclass, is_nested); return input_nr; } uint32_t add_input_metadata(const at::Tensor& t) noexcept { uint32_t input_nr = input_metadata_.size(); input_metadata_.emplace_back(t); return input_nr; } uint32_t add_input_metadata(undefined_input u) noexcept { uint32_t input_nr = input_metadata_.size(); input_metadata_.emplace_back(); return input_nr; } uint32_t num_inputs() const noexcept { return input_metadata_.size(); } const InputMetadata& input_metadata(size_t index) const { return input_metadata_[index]; } // Danger: not thread safe, caller must protect with lock InputMetadata& mutable_input_metadata(size_t index) { return input_metadata_[index]; } std::optional stream() { auto opt_device_type = at::getAccelerator(); if (!opt_device_type.has_value()) { return std::nullopt; } for (const auto& metadata : input_metadata_) { if (metadata.device().type() == opt_device_type.value()) return metadata.stream(); } return std::nullopt; } // Used by the engine to determine what device thread to run on at::Device device() { // Since we pick the first non-CPU tensor, this won't work with // mixed device-type operations (e.g., an op that is both CUDA // and XLA). This is *incredibly* unlikely, so we don't worry // about it. for (const auto& metadata : input_metadata_) { auto device = metadata.device(); if (device.type() != at::kCPU) { return device; } } // Only report to the CPU thread if there really were no tensors // from other devices. return at::kCPU; } void clear_input_metadata() { input_metadata_.clear(); } // Outputs ("Next Edges") void update_topological_nr(const Edge& edge) { TORCH_INTERNAL_ASSERT( !has_parent_, "Cannot update a node's topological_nr after it already has a parent." " If we allow this, we can no longer guarantee that a parent's" " topo_nr is always greater than those of all its children") Node* node = edge.function.get(); if (node) { auto topo_nr = node->topological_nr(); if (topological_nr_ <= topo_nr) { topological_nr_ = topo_nr + 1; } } } void set_next_edge(size_t index, Edge edge) { update_topological_nr(edge); next_edges_[index] = std::move(edge); } void add_next_edge(Edge edge) { update_topological_nr(edge); next_edges_.emplace_back(std::move(edge)); } void set_next_edges(edge_list&& next_edges) { next_edges_ = std::move(next_edges); for (const auto& next_edge : next_edges_) { update_topological_nr(next_edge); } } const Edge& next_edge(size_t index) const noexcept { return next_edges_[index]; } const edge_list& next_edges() const noexcept { return next_edges_; } edge_list& next_edges() noexcept { return next_edges_; } uint32_t num_outputs() const noexcept { return next_edges_.size(); } // Miscellaneous Methods //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ uint64_t sequence_nr() const noexcept { return sequence_nr_; } void set_sequence_nr(uint64_t sequence_nr) { sequence_nr_ = sequence_nr; } // NOTE [ Topological Number ] // // topological_nr is used to prune branches in the DAG during autograd // discovery as maintaining topological_nr helps us check in O(1) if there // does NOT exist a directed path between two nodes. // // The topological order number of this `Node` representing the length of the // longest possible path from this Node to any leaf node. If you are leaf // node, aka AccumulateGrad, this will be zero. This value has the property // that For every pair of nodes X, Y in G, existence of a directed path from X // to Y implies topo_nr(X) > topo_nr(Y). The converse is not true, however, so // we cannot prove existence of a path from X to Y, only non-existence. // // One assumption we make when using topo_nr is that once a node // has been used, i.e., has a parent node, its own topo_nr does not change // we have added some checks with the `has_parent_` field to enforce this. // // What NOT to do: // // 1) 2 -> 1 -> 0 In this diagram we label nodes with their // topo_nr. // 2 -> 1 -> 0 We have two simple graphs that can each // arise from // `t.exp().exp()`, for example. // 2) 2 -> 1 -> 0 // / // 2 -> 1 -> 0 We add 2 as a next edge to 1 even though 1 // already // has a parent. // 3) 2 -> 1 -> 0 // / // 2 -> 3 -> 0 2 < 3, yet there exists a path from 2 to 3! // uint64_t topological_nr() const noexcept { has_parent_ = true; return topological_nr_; } // assigning a node as a parent to this node void assign_parent(); uint64_t thread_id() const noexcept { return thread_id_; } virtual std::string name() const; bool should_compute_output(size_t output_edge_index) const { TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range"); return next_edges_[output_edge_index].is_valid(); } bool should_compute_output(std::initializer_list idxs) const { return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) { for (const auto i : c10::irange(range.first, range.second)) { if (should_compute_output(i)) return true; } return false; }); } bool task_should_compute_output(size_t output_edge_index) const { TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range"); const auto& next = next_edges_[output_edge_index]; if (next.is_valid()) { const auto exec_info = get_current_graph_task_exec_info(); if (exec_info && !exec_info->empty()) { auto it = exec_info->find(next.function.get()); if (it == exec_info->end() || !it->second.should_execute()) { return false; // this edge is not needed for the current graph_task } } return true; } return false; } bool task_should_compute_output( std::initializer_list idxs) const { return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) { for (const auto i : c10::irange(range.first, range.second)) { if (task_should_compute_output(i)) return true; } return false; }); } PyObject* pyobj() const noexcept { return pyobj_; } void set_pyobj(PyObject* pyobj) noexcept { pyobj_ = pyobj; } AnomalyMetadata* metadata() noexcept; // Hook API //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ uintptr_t add_post_hook(std::unique_ptr&& post_hook) { post_hooks_.emplace_back(std::move(post_hook)); // Use the raw pointer as the unique key to identify this hook. This key // can then be used in del_post_hook(key) to remove this hook. return reinterpret_cast(post_hooks_.back().get()); } const std::vector>& post_hooks() const noexcept { return post_hooks_; } // delete a post hook matching the key bool del_post_hook(const uintptr_t& key) { for (auto it = post_hooks_.begin(); it != post_hooks_.end(); ++it) { if (key == reinterpret_cast(it->get())) { post_hooks_.erase(it); return true; } } return false; } std::vector>& post_hooks() noexcept { return post_hooks_; } void add_pre_hook(std::unique_ptr&& pre_hook) { pre_hooks_.emplace_back(std::move(pre_hook)); } void add_tensor_pre_hook(std::unique_ptr&& pre_hook) { tensor_pre_hooks_.emplace_back(std::move(pre_hook)); } void add_retains_grad_hook( std::unique_ptr&& pre_hook, size_t output_idx) { retains_grad_hooks_[output_idx] = std::move(pre_hook); } std::unique_ptr pop_retains_grad_hook(size_t output_idx) { auto ret = std::move(retains_grad_hooks_[output_idx]); retains_grad_hooks_.erase(output_idx); return ret; } const std::vector>& pre_hooks() const noexcept { return pre_hooks_; } std::vector>& pre_hooks() noexcept { return pre_hooks_; } virtual std::vector>& tensor_pre_hooks() noexcept { return tensor_pre_hooks_; } virtual std::unique_ptr& tensor_post_acc_grad_hooks() const noexcept { static std::unique_ptr empty = nullptr; return empty; } std::unordered_map>& retains_grad_hooks() noexcept { return retains_grad_hooks_; } // Customization Points for Subclasses //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ virtual void release_variables() {} virtual void will_release_variables() {} virtual bool is_traceable() { return false; } virtual bool passes_state_transparently() { return false; } // see [Note: Compiled Autograd] // Used by compiled autograd to // 1) Extract tensors/symint args // 2) Collect node information for specialization and caching // Implementations in subclasses should call args.collect() with all node // attrs. These functions are only called durring backward. virtual void compiled_args(CompiledNodeArgs& args) const { throw std::runtime_error( std::string("compiled_args not implemented: ") + name()); } // Used by compiled autograd to call apply() with different saved tensors // Implementations should call saved.before() on all attrs, then apply(), then // saved.after() on all attrs in the same order. virtual variable_list apply_with_saved( const variable_list& inputs, SwapSavedVariables& saved) { throw std::runtime_error( std::string("apply_with_saved not implemented: ") + name()); } // If this node is the AOTBackward node produced by torch.compile. // Compiled Autograd special-cases on this information. virtual bool is_aot_backward() const { return false; } protected: virtual variable_list apply(variable_list&& inputs) = 0; variable_list traced_apply(variable_list inputs); // Sequence number used to correlate backward nodes with forward ops in the // profiler and provide determinism in the engine. // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) uint64_t sequence_nr_; // See NOTE [ Topological Number ] uint64_t topological_nr_ = 0; // Tracks whether this node has been added as the next_edge of another node // via set_next_edge(s), which always calls topological_nr() of all its // children See NOTE [ Topological Number ] for why we need this. mutable bool has_parent_ = false; // Id of the thread that created the instance uint64_t thread_id_ = 0; // Note [Thread Safety on Autograd Node] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Autograd Engine let the owning thread which calls Engine::execute to drive // the GraphTask execution, there might be cases that part of the GraphTask is // shared across different `backward()` or `grad()` calls, i.e. fork new // threads in the middle of the forward and call `backward()` separately from // different threads. We need to protect the thread safety on NodeTask to // prevent data racing on shared variables read/write. // // NB: This is only needed for Autograd Nodes that runs on CPU, technically // "CUDA", "XLA" nodes don't need locking because device threads are always // single threaded. // // Here we add a thread mutex to help protect the Node's thread safety, so // that different threads cannot race the shared data when executing the same // NodeTask from multiple CPU threads. It IS the user/developer responsibility // to take advantage of this mutex to protect the thread safety of their // autograd Node. The general strategy of thread safety on autograd Node: // // 1. User should lock the mutex during Node::release_variables() if the Node // needs // to release the variables on the fly, this serve the purpose that when we // release saved_variables from one thread, no other threads can release // the saved variables concurrently. call the Node::apply(), // 2. User should lock the mutex during Node::apply(), this is to ensure Node // that // writing to the shared variable are not racing across threads (i.e. // AccumulateGrad and custom C++ Autograd Node if writing to shared // variables ) // 3. item 2 and item 3 should work together so that when we release saved // variables // from one thread, no other threads can call Node::apply(), this ensures // the variable references from other threads aren't dangling. // 4. if the Node don't release any variables and no shared data read/write in // the Node // i.e. purely functional, user don't need to lock the mutex // // This way we could protect the thread safety on Autograd Node, but we could // still not protect the thread safety on Node pre/post C++ hooks (python // hooks are automatically thread safe), we rely on the user to write thread // safe C++ hooks if they want the hook to be correctly applied in // multithreading environment. std::mutex mutex_; edge_list next_edges_; PyObject* pyobj_ = nullptr; // weak reference std::unique_ptr anomaly_metadata_ = nullptr; // NOTE [Hooks ordering] // We have 3 separate fields for pre hooks registered to the autograd nodes // because the conditions under which they execute are different, and we // want more fine-grained control over the order in which different types // of hooks are executed. // - pre_hooks are only executed when the node itself is executed // - tensor_pre_hook is executed as long as the engine traverses over it // even if that node won't be executed. // - retains_grad_hook are like tensor_pre_hooks except they are always // ordered after all other tensor pre hooks std::vector> pre_hooks_; std::vector> tensor_pre_hooks_; std::unordered_map> retains_grad_hooks_; std::vector> post_hooks_; at::SmallVector input_metadata_; }; struct TraceableFunction : public Node { using Node::Node; bool is_traceable() final { return true; } }; //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Associated Free Nodes //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ namespace detail { // Implementation of `collect_next_edges` (see below). struct MakeNextFunctionList : IterArgs { edge_list next_edges; using IterArgs::operator(); void operator()(const Variable& variable) { if (variable.defined()) { next_edges.emplace_back(impl::gradient_edge(variable)); } else { next_edges.emplace_back(); } } void operator()(const Variable* variable) { operator()(*variable); } void operator()(const std::optional& variable) { if (variable.has_value()) { operator()(*variable); } else { next_edges.emplace_back(); } } }; } // namespace detail inline void create_gradient_edge( Variable& variable, std::shared_ptr function) { // Copy before move. const auto input_nr = function->add_input_metadata(variable); impl::set_gradient_edge(variable, {std::move(function), input_nr}); } inline bool any_variable_requires_grad(const variable_list& variables) { return std::any_of( variables.begin(), variables.end(), [](const Variable& variable) { return variable.defined() && variable.requires_grad(); }); } template edge_list collect_next_edges(Variables&&... variables) { detail::MakeNextFunctionList make; make.apply(std::forward(variables)...); return std::move(make.next_edges); } struct TypeAndSize { TypeAndSize() = default; /* implicit */ TypeAndSize(const at::Tensor& t) : sym_sizes(t.sym_sizes().vec()), options(t.options()) {} at::Tensor zeros(); std::vector sym_sizes; at::TensorOptions options; }; } // namespace torch::autograd