:github_url: https://github.com/pytorch/pytorch

.. _program_listing_file_torch_csrc_autograd_variable.h:

Program Listing for File variable.h
===================================

|exhale_lsh| :ref:`Return to documentation for file <file_torch_csrc_autograd_variable.h>` (``torch/csrc/autograd/variable.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include <torch/csrc/autograd/cpp_hook.h>
   #include <torch/csrc/autograd/edge.h>
   #include <torch/csrc/autograd/forward_grad.h>
   #include <torch/csrc/autograd/function_hook.h>
   #include <torch/csrc/utils/python_stub.h>
   #include <torch/csrc/Export.h>

   #include <ATen/NamedTensorUtils.h>
   #include <ATen/core/Tensor.h>
   #include <ATen/core/VariableHooksInterface.h>
   #include <c10/util/Exception.h>

   #include <cstdint>
   #include <memory>
   #include <mutex>
   #include <string>
   #include <utility>
   #include <vector>

   namespace torch::autograd {

   using Variable = at::Tensor;

   } // namespace torch::autograd

   // The following are all internal APIs and should not be shown in libtorch docs.
   // Therefore, we wrap the following code with `#ifndef DOXYGEN_SHOULD_SKIP_THIS
   // ... #endif`

   #ifndef DOXYGEN_SHOULD_SKIP_THIS

   namespace torch::autograd {

   inline bool isDifferentiableType(at::ScalarType t) {
     return isFloatingType(t) || isComplexType(t);
   }

   struct Node;
   struct AutogradMeta;
   struct DifferentiableViewMeta;

   // Private-ish functions for manipulating variables; we don't want to put them
   // on Tensor proper
   namespace impl {

   // WARNING: This may return a nullptr. If you require AutogradMeta to return
   // a materialized structure, use materialize_autograd_meta instead.
   TORCH_API AutogradMeta* get_autograd_meta(const at::TensorBase&);

   // WARNING: This will return a nullptr if the Tensor is not a view.
   TORCH_API DifferentiableViewMeta* get_view_autograd_meta(
       const at::TensorBase&);

   // Returns the current autograd meta, materializing it if it was previously
   // none. This counts as a *mutating* operation, so do not call it on
   // "read-only" operators; in particular, this is NOT thread safe
   TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase&);

   TORCH_API void set_grad_accumulator(
       const Variable&,
       std::weak_ptr<Node> grad_accumulator);

   TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);

   TORCH_API std::shared_ptr<Node> grad_accumulator(const Variable&);

   TORCH_API Edge gradient_edge(const Variable&);

   TORCH_API void set_gradient_edge(const Variable&, Edge edge);

   // Autograd Graph Interaction
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   TORCH_API void rebase_history(const Variable&, Edge gradient_edge);

   TORCH_API Node* grad_fn_unsafe(const Variable&);

   TORCH_API void bump_version(const Variable&);
   TORCH_API void set_version_counter(
       const Variable&,
       const c10::VariableVersion& version_counter);

   TORCH_API const c10::VariableVersion& version_counter(const Variable&);

   TORCH_API void set_name(const Variable&, const std::string& name);

   TORCH_API void add_hook(
       const at::TensorBase&,
       std::unique_ptr<FunctionPreHook> hook);
   TORCH_API std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable&);
   TORCH_API void clear_hooks(const at::TensorBase&);

   TORCH_API void set_post_acc_grad_hooks(
       const at::TensorBase&,
       std::unique_ptr<PostAccumulateGradHook> dict);
   TORCH_API std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
       const Variable&);

   TORCH_API void create_cpp_hook(
       const at::TensorBase&,
       bool is_retains_grad_hooks = false);

   } // namespace impl

   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   // AutogradMeta
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
     std::string name_;

     Variable grad_;
     std::shared_ptr<Node> grad_fn_;
     std::weak_ptr<Node> grad_accumulator_;

     // This field is used to store all the forward AD gradients
     // associated with this AutogradMeta (and the Tensor it corresponds to)
     // There is a semantic 1:1 correspondence between AutogradMeta and
     // ForwardGrad but:
     //   - This field is lazily populated.
     //   - This field is a shared_ptr but it must never be
     //     shared by multiple Tensors. See Note [ Using ForwardGrad ]
     // Any transition from not_initialized to initialized
     // must be protected by mutex_
     mutable std::shared_ptr<ForwardGrad> fw_grad_;

     // The hooks_ field is actually reused by both python and cpp logic
     // For both cases, we have a data structure, cpp_hooks_list_ (cpp)
     // or dict (python) which is the canonical copy.
     // Then, for both cases, we always register a single hook to
     // hooks_ which wraps all the hooks in the list/dict.
     // And, again in both cases, if the grad_fn exists on that tensor
     // we will additionally register a single hook to the grad_fn.
     //
     // Note that the cpp and python use cases aren't actually aware of
     // each other, so using both is not defined behavior.
     std::vector<std::unique_ptr<FunctionPreHook>> hooks_;
     std::shared_ptr<hooks_list> cpp_hooks_list_;

     // The post_acc_grad_hooks_ field stores only Python hooks
     // (PyFunctionTensorPostAccGradHooks) that are called after the
     // .grad field has been accumulated into. This is less complicated
     // than the hooks_ field, which encapsulates a lot more.
     std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;

     // Only meaningful on leaf variables (must be false otherwise)
     bool requires_grad_{false};

     // Only meaningful on non-leaf variables (must be false otherwise)
     bool retains_grad_{false};

     bool is_view_{false};

     // The "output number" of this variable; e.g., if this variable
     // was the second output of a function, then output_nr == 1.
     // We use this to make sure we can setup the backwards trace
     // correctly when this variable is passed to another function.
     uint32_t output_nr_;

     // Mutex to ensure that concurrent read operations that modify internal
     // state are still thread-safe. Used by grad_fn(), grad_accumulator(),
     // fw_grad() and set_fw_grad()
     // This is mutable because we need to be able to acquire this from const
     // version of this class for the functions above
     mutable std::mutex mutex_;

     void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) final {
       TORCH_CHECK(
           !requires_grad ||
               isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())),
           "Only Tensors of floating point and complex dtype can require gradients");
       requires_grad_ = requires_grad;
     }

     bool requires_grad() const override {
       return requires_grad_ || grad_fn_;
     }

     Variable& mutable_grad() override {
       return grad_;
     }

     const Variable& grad() const override {
       return grad_;
     }

     const Variable& fw_grad(uint64_t level, const at::TensorBase& self)
         const override;

     void set_fw_grad(
         const at::TensorBase& new_grad,
         const at::TensorBase& self,
         uint64_t level,
         bool is_inplace_op) override;

     AutogradMeta(
         at::TensorImpl* self_impl = nullptr,
         bool requires_grad = false,
         Edge gradient_edge = Edge())
         : grad_fn_(std::move(gradient_edge.function)),
           output_nr_(gradient_edge.input_nr) {
       // set_requires_grad also checks error conditions.
       if (requires_grad) {
         TORCH_INTERNAL_ASSERT(self_impl);
         set_requires_grad(requires_grad, self_impl);
       }
       TORCH_CHECK(
           !grad_fn_ || !requires_grad_,
           "requires_grad should be false if grad_fn is set");
     }

     ~AutogradMeta() override {
       // If AutogradMeta is being destroyed, it means that there is no other
       // reference to its corresponding Tensor. It implies that no other thread
       // can be using this object and so there is no need to lock mutex_ here to
       // guard the check if fw_grad_ is populated.
       if (fw_grad_) {
         // See note [ Using ForwardGrad ]
         fw_grad_->clear();
       }
     }
   };

   struct TORCH_API ViewFunc {
     virtual ~ViewFunc() = default;
     virtual std::vector<c10::SymInt> get_symints() const {
       return {};
     }
     virtual size_t num_symints() const {
       return 0;
     }
     virtual std::vector<at::Tensor> get_tensors() const {
       return {};
     }
     virtual size_t num_tensors() const {
       return 0;
     }
     virtual at::Tensor operator()(const at::Tensor&) const = 0;
     virtual std::unique_ptr<ViewFunc> clone_and_set(
         std::optional<std::vector<c10::SymInt>> = std::nullopt,
         std::optional<std::vector<at::Tensor>> = std::nullopt) const = 0;

    protected:
     virtual void set_symints(std::vector<c10::SymInt>) {}
     virtual void set_tensors(std::vector<at::Tensor>) {}
   };

   struct ChainedViewFunc : public ViewFunc {
     ChainedViewFunc(
         std::unique_ptr<ViewFunc> first,
         std::unique_ptr<ViewFunc> second)
         : first(std::move(first)), second(std::move(second)) {}
     ~ChainedViewFunc() override = default;
     std::vector<c10::SymInt> get_symints() const override;
     size_t num_symints() const override {
       return first->num_symints() + second->num_symints();
     }
     std::vector<at::Tensor> get_tensors() const override;
     size_t num_tensors() const override {
       return first->num_tensors() + second->num_tensors();
     }
     at::Tensor operator()(const at::Tensor&) const override;
     std::unique_ptr<ViewFunc> clone_and_set(
         std::optional<std::vector<c10::SymInt>> = std::nullopt,
         std::optional<std::vector<at::Tensor>> = std::nullopt) const override;

    private:
     std::unique_ptr<ViewFunc> first;
     std::unique_ptr<ViewFunc> second;
   };

   struct ErroringViewFunc : public ViewFunc {
     ErroringViewFunc(std::string error_msg) : error_msg(std::move(error_msg)) {}
     ~ErroringViewFunc() override = default;
     at::Tensor operator()(const at::Tensor&) const override {
       TORCH_CHECK(false, error_msg);
     }
     std::unique_ptr<ViewFunc> clone_and_set(
         std::optional<std::vector<c10::SymInt>> = std::nullopt,
         std::optional<std::vector<at::Tensor>> = std::nullopt) const override {
       return std::make_unique<ErroringViewFunc>(error_msg);
     }

    private:
     std::string error_msg;
   };

   struct TORCH_API ViewInfo {
     Variable base_;

     std::unique_ptr<ViewFunc> view_fn_;

     std::function<Variable(const Variable&)> rev_view_fn_;

     bool has_view_fn() const {
       // assume either BOTH or NEITHER of view_fn_ and rev_view_fn_ exist
       return view_fn_ != nullptr;
     }

     const ViewFunc& view_fn() const {
       TORCH_CHECK(
           has_view_fn(), "Can only access the view function if it exists.");
       return *view_fn_;
     }

     std::function<Variable(const Variable&)> rev_view_fn() const {
       TORCH_CHECK(
           has_view_fn(),
           "Can only access the reverse view function if it exists.");
       return rev_view_fn_;
     }

     ViewInfo chain(
         const Variable& base,
         const Variable& tensor,
         std::unique_ptr<ViewFunc> view_func = nullptr,
         std::function<Variable(const Variable&)> rev_view_func = nullptr) const;

     ViewInfo(
         Variable base,
         std::unique_ptr<ViewFunc> view_fn,
         std::function<Variable(const Variable&)> rev_view_fn)
         : base_(std::move(base)),
           view_fn_(std::move(view_fn)),
           rev_view_fn_(std::move(rev_view_fn)) {
       TORCH_CHECK(base_.defined(), "base is undefined");
     }
   };

   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   // DifferentiableViewMeta
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   enum class CreationMeta : uint8_t {
     DEFAULT,
     IN_CUSTOM_FUNCTION,
     MULTI_OUTPUT_NODE,
     NO_GRAD_MODE,
     INFERENCE_MODE
   };

   inline CreationMeta propagate_creation_meta(
       CreationMeta prev_view_creation_meta,
       CreationMeta new_view_creation_meta) {
     return (new_view_creation_meta == CreationMeta::DEFAULT)
         ? prev_view_creation_meta
         : (prev_view_creation_meta == CreationMeta::INFERENCE_MODE
                ? prev_view_creation_meta
                : new_view_creation_meta);
   }

   TORCH_API void handle_view_on_rebase(
       DifferentiableViewMeta* diff_view_meta,
       bool indirect = false);

   struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
    private:
     std::optional<ViewInfo> backward_info_;
     std::optional<ViewInfo> forward_info_;

     // Optimization to reduce the number of ViewInfo we create.
     // In the (very common) case where backward_info_ == forward_info_, we only
     // populate backward_info_ (that should be used as both the forward and
     // backward view information) and set shared_view_info_ = true. Invariants:
     //   - If shared_view_info_ is false, there are no special constraints on
     //     backward_info_ and forward_info_
     //   - If shared_view_info_ is true, we must have:
     //     - backward_info_.has_value() == true
     //     - forward_info_.has_value() == false
     bool shared_view_info_;

     uint32_t attr_version_;
     CreationMeta creation_meta_;

    public:
     bool requires_grad() const override {
       return requires_grad_ || grad_fn_ ||
           (has_bw_view() && get_backward_view().base_.requires_grad());
     }

     bool shared_view_info() const {
       return shared_view_info_;
     }

     bool has_bw_view() const {
       return backward_info_.has_value();
     }

     const ViewInfo& get_backward_view() const {
       TORCH_CHECK(
           has_bw_view(), "backward view info can only exist for backward views.");
       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
       return backward_info_.value();
     }

     uint32_t get_attr_version() const {
       TORCH_CHECK(
           has_bw_view(), "attr_version can only exist for backward views.");
       return attr_version_;
     }

     void set_attr_version(uint32_t new_attr_version) {
       TORCH_CHECK(
           has_bw_view(), "attr_version can only exist for backward views.");
       attr_version_ = new_attr_version;
     }

     CreationMeta get_creation_meta() const {
       TORCH_CHECK(
           has_bw_view(), "creation_meta can only exist for backward views.");
       return creation_meta_;
     }

     void set_creation_meta(CreationMeta new_creation_meta) {
       TORCH_CHECK(
           has_bw_view(), "creation_meta can only exist for backward views.");
       creation_meta_ = new_creation_meta;
     }

     bool has_fw_view() const {
       return shared_view_info_ || forward_info_.has_value();
     }

     const ViewInfo& get_forward_view() const {
       TORCH_CHECK(
           has_fw_view(), "forward view info can only exist for forward views.");
       TORCH_CHECK(
           !shared_view_info_ || has_bw_view(),
           "forward view info can only exist for forward views.");
       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
       return shared_view_info_ ? backward_info_.value() : forward_info_.value();
     }

     DifferentiableViewMeta(
         at::TensorImpl* self_impl,
         std::optional<ViewInfo> backward_info,
         std::optional<ViewInfo> forward_info,
         bool shared_view_info,
         CreationMeta creation_meta = CreationMeta::DEFAULT);
   };

   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   // Variable Implementation
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   // Factory Functions
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   // See NOTE [ Autograd View Variables ] for details.
   // Differentiable view. Track history with DifferentiableViewMeta.
   inline Variable make_variable_differentiable_view(
       const at::Tensor& data,
       std::optional<ViewInfo> backward_info,
       std::optional<ViewInfo> forward_info,
       bool shared_view_info,
       CreationMeta creation_meta,
       bool allow_tensor_metadata_change = true) {
     if (data.defined()) {
       TORCH_CHECK(
           data.getIntrusivePtr()->autograd_meta() == nullptr,
           "Attempted to make a tensor into a differentiable view, but the "
           "tensor already had autograd metadata associated with it. If you are "
           "using a __torch_dispatch__ mode, the most common cause for this "
           "problem is that you used torch.overrides.enable_reentrant_dispatch() "
           "improperly; tensors created within the extent of reentrant dispatch "
           "MUST NOT be directly returned from __torch_dispatch__; instead, they "
           "must be wrapped into fresh tensors that serve as the output. If you "
If you " "are not using wrappers, you probably don't need reentrant dispatch. " "If this doesn't seem applicable, please file a bug to PyTorch."); at::TensorImpl* data_impl = data.unsafeGetTensorImpl(); data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); data_impl->set_autograd_meta(std::make_unique( data_impl, std::move(backward_info), std::move(forward_info), shared_view_info, creation_meta)); return data; } return Variable(); } // See NOTE [ Autograd View Variables ] for details. // Non-differentiable view. Just share version counter. inline Variable make_variable_non_differentiable_view( const Variable& base, const at::Tensor& data, bool allow_tensor_metadata_change = true) { if (data.defined()) { // Currently all of non-differentiable view ops(detach/_indices/_values) // share the same TensorImpl as their base Tensor. Thus a new TensorImpl // allocation here is required. auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( /*version_counter=*/impl::version_counter(base), /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); data_impl_copy->set_autograd_meta(nullptr); return Variable(data_impl_copy); } return Variable(); } inline Variable make_variable( at::Tensor data, bool requires_grad = false, bool allow_tensor_metadata_change = true) { if (data.defined()) { if (data.getIntrusivePtr().use_count() == 1 && data.getIntrusivePtr()->unique_version()) { auto data_impl = data.unsafeReleaseIntrusivePtr(); data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); if (requires_grad) { data_impl->set_autograd_meta( std::make_unique(data_impl.get(), requires_grad)); } else { data_impl->set_autograd_meta(nullptr); } return Variable(std::move(data_impl)); } else { auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( /*version_counter=*/0, /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); if (requires_grad) { data_impl_copy->set_autograd_meta(std::make_unique( data_impl_copy.get(), requires_grad)); } else { data_impl_copy->set_autograd_meta(nullptr); } return Variable(std::move(data_impl_copy)); } } return Variable(); } inline Variable make_variable( const at::Tensor& data, Edge gradient_edge, bool allow_tensor_metadata_change = true) { if (data.defined()) { auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( /*version_counter=*/0, /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); data_impl_copy->set_autograd_meta(std::make_unique( data_impl_copy.get(), false, std::move(gradient_edge))); return Variable(data_impl_copy); } return Variable(); } struct VariableHooks final : at::impl::VariableHooksInterface { at::TensorBase tensor_data(const at::TensorBase&) const override; at::TensorBase variable_data(const at::TensorBase&) const override; const std::shared_ptr& grad_fn( const at::TensorBase&) const override; unsigned _register_hook( const at::TensorBase&, std::function hook) const override; void remove_hook(const at::TensorBase&, unsigned pos) const override; bool is_view(const at::TensorBase&) const override; const at::TensorBase& base(const at::TensorBase&) const override; const std::string& name(const at::TensorBase&) const override; bool is_leaf(const at::TensorBase&) const override; int64_t output_nr(const at::TensorBase&) const override; void set_data(const at::TensorBase& self, const at::TensorBase& new_data) const override; at::TensorBase data(const at::TensorBase& self) const override; int64_t _version(const at::TensorBase& self) const override; void 
     void retain_grad(const at::TensorBase& self) const override;
     bool retains_grad(const at::TensorBase& self) const override;
     void _backward(
         const at::Tensor& self,
         at::TensorList inputs,
         const std::optional<at::Tensor>& gradient,
         std::optional<bool> keep_graph,
         bool create_graph) const override;
     void requires_grad_(const at::TensorBase& self, bool _requires_grad)
         const override;
     void basic_autograd_not_implemented_fallback(
         const c10::OperatorHandle& op,
         c10::DispatchKeySet dispatch_keys,
         torch::jit::Stack* stack) const override;
   };

   namespace utils {

   TORCH_API bool has_same_meta(const Variable& base, const Variable& other);

   } // namespace utils
   } // namespace torch::autograd

   #endif /* DOXYGEN_SHOULD_SKIP_THIS */
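
The declarations above are internal autograd APIs and are not covered by the libtorch stability guarantees. The snippet below is a minimal usage sketch, not part of ``variable.h``: it assumes a standard libtorch build that links against ``torch`` and that these internal headers remain available. It only illustrates that ``Variable`` is an alias for ``at::Tensor`` and that ``impl::gradient_edge`` exposes the ``(grad_fn, input_nr)`` edge backward would follow.

.. code-block:: cpp

   // Minimal sketch (assumes libtorch; internal API, subject to change).
   #include <torch/torch.h>
   #include <torch/csrc/autograd/variable.h>

   #include <iostream>

   int main() {
     // Variable is just an alias for at::Tensor, so any tensor created through
     // the public API participates in autograd via its AutogradMeta.
     torch::autograd::Variable x = torch::ones({2, 2}, torch::requires_grad());
     auto y = x * x; // non-leaf tensor: has a grad_fn

     // gradient_edge() returns the edge backward would follow from `y`:
     // (grad_fn, input_nr) for non-leaf tensors, (grad_accumulator, 0) for leaves.
     torch::autograd::Edge edge = torch::autograd::impl::gradient_edge(y);
     std::cout << std::boolalpha
               << "y has grad_fn: " << (edge.function != nullptr)
               << ", input_nr: " << edge.input_nr << '\n';

     // For the leaf `x`, requires_grad comes straight from its AutogradMeta.
     std::cout << "x is leaf: " << x.is_leaf()
               << ", requires_grad: " << x.requires_grad() << '\n';
     return 0;
   }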