:github_url: https://github.com/pytorch/pytorch

.. _program_listing_file_c10_cuda_CUDAGuard.h:

Program Listing for File CUDAGuard.h
====================================

|exhale_lsh| :ref:`Return to documentation for file <file_c10_cuda_CUDAGuard.h>` (``c10/cuda/CUDAGuard.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   // NOTE(review): the header names below and the `impl::CUDAGuardImpl`
   // template argument on every `guard_` member had been stripped from this
   // listing; they are restored here from the upstream c10/cuda/CUDAGuard.h and
   // should be confirmed against the repository. The remaining template
   // arguments (`std::optional<...>`, `ArrayRef<...>`, `std::vector<...>`) are
   // inferred from how each value is used in this file.
   #include <c10/core/DeviceType.h>
   #include <c10/core/impl/InlineDeviceGuard.h>
   #include <c10/core/impl/InlineStreamGuard.h>
   #include <c10/cuda/CUDAMacros.h>
   #include <c10/cuda/impl/CUDAGuardImpl.h>

   namespace c10::cuda {

   // This code is kind of boilerplatey. See Note [Whither the DeviceGuard
   // boilerplate]

   /// CUDA device guard. Every member function forwards to the wrapped
   /// c10::impl::InlineDeviceGuard; this struct adds no state of its own.
   /// It cannot be default-constructed, copied, or moved.
   struct CUDAGuard {
     /// Deleted: a guard must be constructed with a device or device index.
     explicit CUDAGuard() = delete;

     /// Constructs the wrapped guard from a device index.
     explicit CUDAGuard(DeviceIndex device_index) : guard_(device_index) {}

     /// Constructs the wrapped guard from a Device.
     explicit CUDAGuard(Device device) : guard_(device) {}

     // Copy is not allowed
     CUDAGuard(const CUDAGuard&) = delete;
     CUDAGuard& operator=(const CUDAGuard&) = delete;

     // Move is not allowed (there is no uninitialized state)
     CUDAGuard(CUDAGuard&& other) = delete;
     CUDAGuard& operator=(CUDAGuard&& other) = delete;
     ~CUDAGuard() = default;

     /// Forwards to guard_.set_device(device).
     void set_device(Device device) {
       guard_.set_device(device);
     }

     /// Forwards to guard_.reset_device(device).
     /// NOTE(review): how this differs from set_device() is defined by
     /// InlineDeviceGuard and is not visible in this file.
     void reset_device(Device device) {
       guard_.reset_device(device);
     }

     /// Forwards to guard_.set_index(device_index).
     void set_index(DeviceIndex device_index) {
       guard_.set_index(device_index);
     }

     /// Returns guard_.original_device().
     Device original_device() const {
       return guard_.original_device();
     }

     /// Returns guard_.current_device().
     Device current_device() const {
       return guard_.current_device();
     }

    private:
     /// The wrapped guard that does the actual work.
     c10::impl::InlineDeviceGuard<impl::CUDAGuardImpl> guard_;
   };

   /// Variant of CUDAGuard that can be default-constructed and later cleared
   /// with reset(). Device queries therefore return std::optional<Device>.
   /// Every member function forwards to the wrapped
   /// c10::impl::InlineOptionalDeviceGuard.
   struct OptionalCUDAGuard {
     /// Default-constructs the wrapped guard.
     explicit OptionalCUDAGuard() = default;

     /// Constructs the wrapped guard from an optional Device.
     explicit OptionalCUDAGuard(std::optional<Device> device_opt)
         : guard_(device_opt) {}

     /// Constructs the wrapped guard from an optional device index.
     explicit OptionalCUDAGuard(std::optional<DeviceIndex> device_index_opt)
         : guard_(device_index_opt) {}

     // Copy is not allowed
     OptionalCUDAGuard(const OptionalCUDAGuard&) = delete;
     OptionalCUDAGuard& operator=(const OptionalCUDAGuard&) = delete;

     // See Note [Move construction for RAII guards is tricky]
     OptionalCUDAGuard(OptionalCUDAGuard&& other) = delete;

     // See Note [Move assignment for RAII guards is tricky]
     OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete;
     ~OptionalCUDAGuard() = default;

     /// Forwards to guard_.set_device(device).
     void set_device(Device device) {
       guard_.set_device(device);
     }

     /// Forwards to guard_.reset_device(device).
     void reset_device(Device device) {
       guard_.reset_device(device);
     }

     /// Forwards to guard_.set_index(device_index).
     void set_index(DeviceIndex device_index) {
       guard_.set_index(device_index);
     }

     /// Returns guard_.original_device(); may be empty.
     std::optional<Device> original_device() const {
       return guard_.original_device();
     }

     /// Returns guard_.current_device(); may be empty.
     std::optional<Device> current_device() const {
       return guard_.current_device();
     }

     /// Forwards to guard_.reset().
     void reset() {
       guard_.reset();
     }

    private:
     /// The wrapped guard that does the actual work.
     c10::impl::InlineOptionalDeviceGuard<impl::CUDAGuardImpl> guard_;
   };

   /// CUDA stream guard. Forwards to the wrapped c10::impl::InlineStreamGuard
   /// and re-wraps the generic Stream values it returns as CUDAStream.
   /// It cannot be default-constructed, copied, or moved.
   struct CUDAStreamGuard {
     /// Deleted: a guard must be constructed with a stream.
     explicit CUDAStreamGuard() = delete;

     /// Constructs the wrapped guard from a Stream.
     explicit CUDAStreamGuard(Stream stream) : guard_(stream) {}
     ~CUDAStreamGuard() = default;

     CUDAStreamGuard(const CUDAStreamGuard&) = delete;
     CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete;

     CUDAStreamGuard(CUDAStreamGuard&& other) = delete;
     CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete;

     /// Forwards to guard_.reset_stream(stream).
     void reset_stream(Stream stream) {
       guard_.reset_stream(stream);
     }

     /// Returns guard_.original_stream() wrapped as a CUDAStream via the
     /// CUDAStream::UNCHECKED constructor tag.
     /// NOTE(review): presumably UNCHECKED skips validation of the stream;
     /// confirm in CUDAStream.h.
     CUDAStream original_stream() const {
       return CUDAStream(CUDAStream::UNCHECKED, guard_.original_stream());
     }

     /// Returns guard_.current_stream() wrapped as a CUDAStream via the
     /// CUDAStream::UNCHECKED constructor tag.
     CUDAStream current_stream() const {
       return CUDAStream(CUDAStream::UNCHECKED, guard_.current_stream());
     }

     /// Returns guard_.current_device().
     Device current_device() const {
       return guard_.current_device();
     }

     /// Returns guard_.original_device().
     Device original_device() const {
       return guard_.original_device();
     }

    private:
     /// The wrapped guard that does the actual work.
     c10::impl::InlineStreamGuard<impl::CUDAGuardImpl> guard_;
   };

   /// Variant of CUDAStreamGuard that can be default-constructed and later
   /// cleared with reset(). Stream queries therefore return
   /// std::optional<CUDAStream>. Forwards to the wrapped
   /// c10::impl::InlineOptionalStreamGuard.
   struct OptionalCUDAStreamGuard {
     /// Default-constructs the wrapped guard.
     explicit OptionalCUDAStreamGuard() = default;

     /// Constructs the wrapped guard from a Stream.
     explicit OptionalCUDAStreamGuard(Stream stream) : guard_(stream) {}

     /// Constructs the wrapped guard from an optional Stream.
     explicit OptionalCUDAStreamGuard(std::optional<Stream> stream_opt)
         : guard_(stream_opt) {}

     OptionalCUDAStreamGuard(const OptionalCUDAStreamGuard&) = delete;
     OptionalCUDAStreamGuard& operator=(const OptionalCUDAStreamGuard&) = delete;

     // See Note [Move construction for RAII guards is tricky]
     OptionalCUDAStreamGuard(OptionalCUDAStreamGuard&& other) = delete;

     // See Note [Move assignment for RAII guards is tricky]
     OptionalCUDAStreamGuard& operator=(OptionalCUDAStreamGuard&& other) = delete;
     ~OptionalCUDAStreamGuard() = default;

     /// Forwards to guard_.reset_stream(stream).
     void reset_stream(Stream stream) {
       guard_.reset_stream(stream);
     }

     /// Returns guard_.original_stream() re-wrapped as a CUDAStream, or
     /// std::nullopt when the wrapped guard reports no value.
     std::optional<CUDAStream> original_stream() const {
       auto r = guard_.original_stream();
       if (r.has_value()) {
         return CUDAStream(CUDAStream::UNCHECKED, r.value());
       } else {
         return std::nullopt;
       }
     }

     /// Returns guard_.current_stream() re-wrapped as a CUDAStream, or
     /// std::nullopt when the wrapped guard reports no value.
     std::optional<CUDAStream> current_stream() const {
       auto r = guard_.current_stream();
       if (r.has_value()) {
         return CUDAStream(CUDAStream::UNCHECKED, r.value());
       } else {
         return std::nullopt;
       }
     }

     /// Forwards to guard_.reset().
     void reset() {
       guard_.reset();
     }

    private:
     /// The wrapped guard that does the actual work.
     c10::impl::InlineOptionalStreamGuard<impl::CUDAGuardImpl> guard_;
   };

   /// Guard over several CUDA streams at once. It exposes no operations beyond
   /// construction and destruction; the wrapped
   /// c10::impl::InlineMultiStreamGuard does the work.
   struct CUDAMultiStreamGuard {
     /// Converts `streams` to generic Stream objects and constructs the wrapped
     /// guard from them.
     explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams)
         : guard_(unwrapStreams(streams)) {}

     CUDAMultiStreamGuard(const CUDAMultiStreamGuard&) = delete;
     CUDAMultiStreamGuard& operator=(const CUDAMultiStreamGuard&) = delete;

     // See Note [Move construction for RAII guards is tricky]
     CUDAMultiStreamGuard(CUDAMultiStreamGuard&& other) = delete;

     // See Note [Move assignment for RAII guards is tricky]
     CUDAMultiStreamGuard& operator=(CUDAMultiStreamGuard&& other) = delete;
     ~CUDAMultiStreamGuard() = default;

    private:
     /// The wrapped guard that does the actual work.
     c10::impl::InlineMultiStreamGuard<impl::CUDAGuardImpl> guard_;

     /// Copies every element of `cudaStreams` into a std::vector<Stream>,
     /// preserving order. Relies on CUDAStream being convertible to Stream.
     static std::vector<Stream> unwrapStreams(ArrayRef<CUDAStream> cudaStreams) {
       std::vector<Stream> streams;
       // One allocation up front; the final size is known.
       streams.reserve(cudaStreams.size());
       for (const CUDAStream& cudaStream : cudaStreams) {
         streams.push_back(cudaStream);
       }
       return streams;
     }
   };

   } // namespace c10::cuda