Shortcuts

Source code for torchrl.weight_update._shared

from __future__ import annotations

from collections.abc import Callable
from typing import Any

import torch
import torch.distributed

from tensordict import TensorDict, TensorDictBase

from torch import multiprocessing as mp, nn

from torchrl._utils import logger as torchrl_logger
from torchrl.collectors._constants import WEIGHT_SYNC_TIMEOUT

from torchrl.weight_update.utils import _resolve_model
from torchrl.weight_update.weight_sync_schemes import (
    TransportBackend,
    WeightStrategy,
    WeightSyncScheme,
)


def _close_mp_queue(queue: mp.Queue) -> None:
    """Close a multiprocessing Queue and wait for its feeder thread to exit."""
    queue.close()
    queue.join_thread()


[docs]class SharedMemTransport: """Shared memory transport for in-place weight updates. This transport uses queue-based buffer distribution for initialization, then updates shared memory tensors directly for subsequent weight updates. Workers automatically see weight updates without explicit communication. Initialization flow: - Shared memory buffers are created and sent to workers via per-worker queues - Workers receive the buffer reference and apply weights to their models - Subsequent updates are pure in-place shared memory (zero-copy) Both CPU and CUDA tensors maintain shared references when sent through mp.Queue. """ def __init__(self): self._params_map = None # a dict[worker_idx, TensorDictBase] map self._weight_queues = ( None # Dict of per-worker queues for distributing shared weights ) self._unique_weights = None @property def unique_weights(self) -> list[TensorDictBase]: """Get the unique weights. Returns: The unique weights. """ if self._unique_weights is None: raise RuntimeError("Unique weights not set. Call register_weights() first.") return self._unique_weights
[docs] def register_weights( self, params_map: dict[int, mp.Queue], init_queues: dict[int, mp.Queue] ) -> None: """Initialize per-worker queues for shared memory buffer distribution.""" from torchrl.collectors.utils import _cast self._weight_queues = init_queues self._params_map = params_map # Create set of the unique weights self._unique_weights = [] for weights in params_map.values(): if id(weights) in [id(w) for w in self._unique_weights]: continue weights = weights.data.apply(_cast, weights) self._unique_weights.append(weights)
[docs] def setup_connection_and_weights_on_sender(self) -> None: """Send shared memory buffer reference to workers via their per-worker queues. Both CPU and CUDA tensors maintain shared references through queues. Each worker reads from its own dedicated queue, to avoid race conditions. """ if self._weight_queues is None: raise RuntimeError("Queues not created yet. Call init_on_sender() first.") for worker_idx, queue in self._weight_queues.items(): weights = self._params_map[worker_idx] queue.put(weights)
[docs] def setup_connection_and_weights_on_receiver( self, *, worker_idx: int | None = None, weights: Any = None, model: Any = None, strategy: Any = None, timeout: float = WEIGHT_SYNC_TIMEOUT, ) -> TensorDictBase: """Receive shared memory buffer reference from sender via their per-worker queues. Each worker reads from its own dedicated queue, to avoid race conditions. Args: worker_idx: The worker index. weights: Ignored (weights come from queue). model: Ignored. strategy: Ignored. timeout: Timeout for reading from queue. Returns: The shared memory weights TensorDict. """ if self._weight_queues is None: raise RuntimeError("Queues not created yet. Call init_on_sender() first.") if worker_idx not in self._weight_queues: raise RuntimeError(f"Worker {worker_idx} not registered in queues.") # Read from dedicated queue for this worker worker_queue = self._weight_queues[worker_idx] received_weights = worker_queue.get(timeout=timeout) return received_weights
[docs] def send_weights(self, weights: Any) -> None: """Update weights in-place in shared memory. Args: weights: New weights to send. Can be a TensorDictBase or dict. Raises: ValueError: If weights type is unsupported. """ # Update shared memory in-place (workers see this automatically) if isinstance(weights, dict): weights = TensorDict(weights) if not isinstance(weights, TensorDictBase): raise ValueError(f"Unsupported weights type: {type(weights)=}") # Unflatten if needed to match shared buffer structure weights_to_update = weights if any("." in key for key in weights.keys()): weights_to_update = weights.unflatten_keys(".") # Detach weights to allow in-place updates (gradients are not needed for weight sync) weights_to_update = weights_to_update.detach() if self._unique_weights is None: raise RuntimeError("Unique weights not set. Call register_weights() first.") for buffer in self._unique_weights: if buffer.requires_grad: raise RuntimeError( "Gradients should not be required for shared memory buffers." ) if weights_to_update.requires_grad: raise RuntimeError("Gradients should not be required for weights.") buffer.update_(weights_to_update, non_blocking=True) if torch.cuda.is_available(): torch.cuda.synchronize()
[docs] def receive_weights( self, timeout: float | None = None, *, weights: Any = None, model: Any = None, strategy: Any = None, ) -> Any | None: """Apply shared memory weights to the model. For shared memory, weights are already available (passed via the weights arg). This method applies them to the model, matching the pattern of other transports. Args: timeout: Ignored (shared memory access is instant). weights: The shared memory buffer containing current weights. model: The model to apply weights to. strategy: Strategy for applying weights. Returns: The applied weights, or None if not applied. """ # Apply weights to model if provided (same pattern as other transports) if model is not None and strategy is not None and weights is not None: torchrl_logger.debug("Applying shared memory weights to model.") strategy.apply_weights(model, weights) return weights return None
[docs] def send_ack(self, message: str = "updated") -> None: """No-op for shared memory - no acknowledgment needed."""
[docs]class SharedMemWeightSyncScheme(WeightSyncScheme): """Weight synchronization using shared memory. This scheme uses shared memory for in-place weight updates. Workers automatically see weight updates without explicit message passing. A background thread on the receiver side listens for "receive" instructions from the sender. When an instruction arrives, the thread applies the current shared memory weights to the model and sends an acknowledgment. Args: strategy: The weight transmission strategy (default: "tensordict"). sync: If True (default), send() blocks until receiver acknowledges. If False, send() returns immediately (use send_async/wait_async). Example: >>> # Basic usage >>> scheme = SharedMemWeightSyncScheme() >>> # Weights are initialized via init_on_sender() """ def __init__( self, strategy: str = "tensordict", sync: bool = True, ): super().__init__(strategy) self.sync = sync # Create a single shared transport for all workers self.shared_transport = SharedMemTransport() # Create per-worker queues to avoid race conditions # Each worker gets its own queue for weight initialization self._weight_init_queues = {} # worker_idx -> Queue # Instruction queues: sender puts "receive" instruction, receiver's background thread reads self._instruction_queues: dict[int, mp.Queue] = {} # worker_idx -> Queue # Acknowledgment queues: receiver puts "updated" ack, sender reads for sync mode self._ack_queues: dict[int, mp.Queue] = {} # worker_idx -> Queue # Receiver's instruction queue reference (set during init_on_receiver) self._receiver_instruction_queue: mp.Queue | None = None self._receiver_ack_queue: mp.Queue | None = None def _init_on_sender_impl( self, *, model_id: str | None = None, context: Any = None, weights: TensorDictBase | None = None, model: nn.Module | None = None, params_map: dict[int, TensorDictBase] | None = None, devices: list[torch.device] | None = None, device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, num_workers: int | None = None, ctx: Any = None, ) -> None: """Initialize on the main process (sender side). We create a map dict[worker_idx, weights_on_device]. Each model will be assigned a device. If two workers share the same device, the entry in the dict will be the same. To do this, we need to know the number of workers, their assigned device, and have access to the parameters. If a context is provided, we read the devices from it. If not, the dict[worker_idx, device] map must be provided explicitly. In some cases, the policy on the worker side will be on multiple devices which may or may not be the same as the devices on the main process. In this case, init_on_sender() needs to receive a mapping function as argument that will take as input the worker_idx and the parameters and return a new set of parameters on the desired devices. Args: model_id: Identifier for the model being synchronized context: Optional context object providing device_to_workers mapping and model access weights: Pre-extracted weights as TensorDict (for policy factory usage) model: Model to extract weights from params_map: Direct mapping of worker_idx to weights on device (most explicit) devices: List of devices for each worker device_map_fn: Custom function to map worker_idx and weights to device-specific weights num_workers: Number of workers (required with device_map_fn) ctx: Multiprocessing context. Defaults to `mp.get_context()`. Examples: Simple usage with collector context (stateful policy): >>> policy = make_stateful_policy() >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") >>> collector = MultiSyncCollector( ... create_env_fn=[lambda: GymEnv("CartPole-v1")], ... policy=policy, ... frames_per_batch=100, ... total_frames=1000, ... weight_sync_schemes={"policy": scheme}, ... ) >>> # scheme.init_on_sender() is called automatically by collector Pre-initialized usage (policy factory): >>> policy_on_main = make_stateful_policy() >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") >>> # Must initialize before collector creation when using policy_factory >>> scheme.init_on_sender( ... model_id="policy", ... weights=TensorDict.from_module(policy_on_main), ... devices=[torch.device("cuda:0"), torch.device("cuda:1")], ... num_workers=2, ... ) >>> collector = MultiSyncCollector( ... create_env_fn=[lambda: GymEnv("CartPole-v1")], ... policy_factory=[make_stateful_policy], ... frames_per_batch=100, ... total_frames=1000, ... weight_sync_schemes={"policy": scheme}, ... ) Direct params_map usage (advanced): >>> weights_cpu = TensorDict.from_module(policy).share_memory_() >>> weights_cuda = weights_cpu.to("cuda").share_memory_() >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") >>> scheme.init_on_sender( ... model_id="policy", ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, ... ) """ # Plan: the goal of this init is to obtain a map dict[worker_idx, weights_on_device] that we can use to init # the weights on the workers. # Scenarios: # - Easiest scenario: the user provides the map directly (params_map). Nothing to do other than creating # the transport and registering the workers etc. # - The user provides a model or its params and a device map. We need to create the map from the params # explicitly. # - The user provides a context (e.g. a Collector) and a model_id. Same as above, except that we need # to collect the model from the context. params_map = self._get_params_map( context=context, model_id=model_id, weights=weights, model=model, params_map=params_map, devices=devices, device_map_fn=device_map_fn, num_workers=num_workers, ) # Create per-worker queues if not already created # Collect all unique worker indices all_workers = list(params_map.keys()) if ctx is None: ctx = mp.get_context() for worker_idx in all_workers: if worker_idx not in self._weight_init_queues: self._weight_init_queues[worker_idx] = ctx.Queue() # Create instruction queues for background receiver if worker_idx not in self._instruction_queues: self._instruction_queues[worker_idx] = ctx.Queue() # Create ack queues for synchronous mode if worker_idx not in self._ack_queues: self._ack_queues[worker_idx] = ctx.Queue() # Set worker info in transport self.shared_transport.register_weights(params_map, self._weight_init_queues) # Store model_id and context on scheme self.model_id = model_id if context is not None: self.context = context def _get_params_map( self, context: Any = None, model_id: str | None = None, weights: TensorDictBase | None = None, model: nn.Module | None = None, params_map: dict[int, TensorDictBase] | None = None, devices: list[torch.device] | None = None, device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, num_workers: int | None = None, ): """Get the params_map for init_on_sender().""" # Import _cast locally to avoid circular imports from torchrl.collectors.utils import _cast if params_map is not None: # Sanity check: params_map must be a dict[int, TensorDictBase] # All other args must be None if ( not isinstance(params_map, dict) or not all(isinstance(v, int) for v in params_map.keys()) or not all(isinstance(v, TensorDictBase) for v in params_map.values()) ): raise ValueError("params_map must be a dict[int, TensorDictBase]") if model_id is not None or weights is not None or model is not None: raise ValueError( "model_id, weights, and model cannot be provided if params_map is provided" ) if context is not None: raise ValueError("context cannot be provided if params_map is provided") if devices is not None: raise ValueError("devices cannot be provided if params_map is provided") if device_map_fn is not None: raise ValueError( "device_map_fn cannot be provided if params_map is provided" ) if num_workers is not None: raise ValueError( "num_workers cannot be provided if params_map is provided" ) return params_map elif context is not None: if devices is not None: raise ValueError("devices cannot be provided if context is provided") # Sanity check: model_id must be provided if context is provided # All other args must be None if model_id is None: raise ValueError("model_id must be provided if context is provided") if model is not None: raise ValueError("model cannot be provided if context is provided") if weights is not None: raise ValueError("weights cannot be provided if context is provided") if device_map_fn is not None: raise ValueError( "device_map_fn cannot be provided if context is provided" ) # Get device map: the devices are stored as policy_device in the collector -- other contexts will be customized later devices = context.policy_device if num_workers is not None and num_workers != len(devices): raise ValueError( "num_workers cannot be provided if context is provided" ) # Get the weights model = _resolve_model(context, model_id) if model is None: if model_id == "policy": # we need to get a copy of the weights from the factory model = context.policy_factory[0]() weights = TensorDict.from_module(model) elif model is not None: if weights is not None: raise ValueError("weights cannot be provided if model is provided") weights = TensorDict.from_module(model) if weights is not None: weights = weights.data.apply(_cast, weights) # To make the map, we need the list of devices, or the map fn if devices is not None: # Get the unique devices devices_set = set(devices) weights_devices = ( {p.device for p in weights.values(True, True)} if weights is not None else set() ) if len(weights_devices) == 1: weights_device = weights_devices.pop() else: weights_device = None # Create device map with proper Parameter handling using _cast # _cast ensures Parameters stay as Parameters (with requires_grad=False) device_map = {} for d in devices_set: if d != weights_device: # Move to device and apply _cast to preserve Parameter/Buffer types weights_on_device = weights.to(d) weights_on_device = weights_on_device.apply(_cast, weights) device_map[d] = weights_on_device else: # Already on correct device, just apply _cast device_map[d] = weights.apply(_cast, weights) # Create the map params_map = { worker_idx: device_map[device] for worker_idx, device in enumerate(devices) } return params_map if device_map_fn is not None: return { worker_idx: device_map_fn(worker_idx, weights) for worker_idx in range(num_workers) } raise ValueError( "Either params_map, model_id + context or model/weights + devices must be provided." ) def _init_on_receiver_impl( self, *, model_id: str | None = None, context: Any = None, model: Any = None, worker_idx: int | None = None, **kwargs, ) -> None: """Initialize on worker process (receiver side). Reads from the worker's dedicated queue to receive shared weights, then registers them in the transport. The receiver then applies these weights to the model. Args: model_id: Identifier for the model being synchronized context: Optional context object providing model and worker_idx model: Model being synchronized worker_idx: Worker index **kwargs: Alternative to context (model, worker_idx, timeout, etc.) """ # Extract parameters from context or kwargs if context is not None: if model_id is None: raise ValueError("model_id is required when context is provided") if hasattr(context, "get_model"): model = context.get_model(model_id) elif model is None: model = _resolve_model(context, model_id) worker_idx = getattr(context, "worker_idx", worker_idx) # Store on scheme directly self.model_id = model_id if context is not None: self.context = context # Register the model if model is not None: self.model = model # Store worker_idx for synchronize_weights self.worker_idx = worker_idx # Store references to instruction and ack queues for this worker # These are created by init_on_sender and passed via pickle if worker_idx is not None: if worker_idx in self._instruction_queues: self._receiver_instruction_queue = self._instruction_queues[worker_idx] if worker_idx in self._ack_queues: self._receiver_ack_queue = self._ack_queues[worker_idx] self.create_transport() def _wait_for_instruction(self, timeout: float | None = None) -> str | None: """Block until an instruction arrives from the sender. Args: timeout: Maximum time to wait for instruction (seconds). None means block indefinitely. Returns: The instruction string (e.g., "receive", "stop"), or None if stop event is set or timeout expires. """ if self._receiver_instruction_queue is None: raise RuntimeError( "Instruction queue not set. init_on_receiver() must be called first." ) try: # Check stop event periodically while waiting while True: if self._stop_event is not None and self._stop_event.is_set(): return None try: # Use short timeout to allow checking stop event instruction = self._receiver_instruction_queue.get(timeout=0.1) return instruction except Exception: # Queue.Empty - continue waiting if timeout is not None: timeout -= 0.1 if timeout <= 0: return None except Exception as e: torchrl_logger.warning(f"Error waiting for instruction: {e}") return None def _send_instruction( self, instruction: str = "receive", worker_ids: int | list[int] | None = None, ) -> None: """Send instruction to receiver(s) to trigger weight reception. Args: instruction: The instruction to send (default: "receive"). worker_ids: Which workers to send to (None = all workers). """ if not self._instruction_queues: raise RuntimeError( "Instruction queues not created. init_on_sender() must be called first." ) if worker_ids is None: target_workers = list(self._instruction_queues.keys()) elif isinstance(worker_ids, int): target_workers = [worker_ids] else: target_workers = list(worker_ids) for worker_idx in target_workers: if worker_idx not in self._instruction_queues: raise ValueError(f"Worker {worker_idx} not registered") self._instruction_queues[worker_idx].put(instruction) def _send_ack(self, message: str = "updated") -> None: """Send acknowledgment back to sender after receiving weights. Args: message: The acknowledgment message (default: "updated"). """ if self._receiver_ack_queue is not None: self._receiver_ack_queue.put(message) def _wait_for_ack( self, worker_ids: int | list[int] | None = None, timeout: float | None = None, ) -> None: """Wait for acknowledgment from receiver(s). Args: worker_ids: Which workers to wait for (None = all workers). timeout: Maximum time to wait (seconds). None means block indefinitely. """ if not self._ack_queues: return # No ack queues, nothing to wait for if worker_ids is None: target_workers = list(self._ack_queues.keys()) elif isinstance(worker_ids, int): target_workers = [worker_ids] else: target_workers = list(worker_ids) for worker_idx in target_workers: if worker_idx not in self._ack_queues: raise ValueError(f"Worker {worker_idx} not registered") try: ack = self._ack_queues[worker_idx].get(timeout=timeout) if ack != "updated": torchrl_logger.warning( f"Unexpected ack from worker {worker_idx}: {ack}" ) except Exception as e: torchrl_logger.warning( f"Timeout waiting for ack from worker {worker_idx}: {e}" )
[docs] def create_transport(self, **kwargs) -> TransportBackend: """Create shared memory transport. Returns the shared transport instance that all workers will use. Since this is shared memory, there's only one transport shared by all workers. Note: This is used internally by init_on_sender/init_on_receiver. """ return self.shared_transport
[docs] def prepare_weights( self, weights: Any, model_id: str, strategy: WeightStrategy, context: Any = None, ) -> Any: """Prepare weights for SharedMemWeightSyncScheme. When weights=None, we extract fresh weights from the model and update the shared memory buffer in-place so workers see the change. Args: weights: Raw weights input model_id: The model identifier strategy: WeightStrategy for extracting/converting weights context: Optional context (e.g., collector) for cache lookup Returns: Shared memory weights ready to send """ # If weights are explicitly provided, use them directly if weights is not None: fresh_weights = super().prepare_weights( weights, model_id, strategy, context ) else: # Extract fresh weights from the model (base class handles this) fresh_weights = super().prepare_weights(None, model_id, strategy, context) if fresh_weights is None: return None # Update the shared memory buffer in-place so workers see the change if self._shared_transport is not None and self.shared_transport.unique_weights: shared_weights = self.shared_transport.unique_weights[0] # In-place update of shared memory buffer with fresh weights shared_weights.data.update_(fresh_weights.data) return shared_weights # If no shared transport, just return the fresh weights return fresh_weights
[docs] def send( self, weights: Any = None, worker_ids: int | list[int] | None = None, ) -> None: """Send weights via shared memory (in-place update). For SharedMemWeightSyncScheme: 1. prepare_weights() updates the shared memory buffer in-place 2. _send_instruction() tells workers to apply the new weights 3. If sync=True, waits for acknowledgments from all workers Args: weights: Weights to send (can be None to extract from model). worker_ids: Which workers to notify (None = all workers). """ if not self.initialized_on_sender: raise RuntimeError("Must be initialized on sender before sending weights") if not self.synchronized_on_sender: raise RuntimeError("Must be synchronized on sender before sending weights") # prepare_weights updates the shared buffer in-place self.prepare_weights( weights=weights, model_id=self._model_id, strategy=self._strategy, context=self.context, ) # Send instruction to workers' background threads to apply the weights self._send_instruction(instruction="receive", worker_ids=worker_ids) # Wait for acknowledgments if in synchronous mode if self.sync: self._wait_for_ack(worker_ids=worker_ids)
@property def weights(self) -> Any | None: """Get the current weights from shared memory. For SharedMemWeightSyncScheme: - On sender side: weights are in transport's _unique_weights - On receiver side: weights are in _receiver_shared_weights (stored during connect()) Returns: The weights TensorDict if available, None otherwise. """ # On receiver side, use the stored shared buffer reference if ( hasattr(self, "_receiver_shared_weights") and self._receiver_shared_weights is not None ): return self._receiver_shared_weights # On sender side, get from the shared transport if self._shared_transport is not None and self.shared_transport.unique_weights: return self.shared_transport.unique_weights[0] # Fall back to parent implementation return super().weights def _setup_connection_and_weights_on_receiver_impl( self, *, worker_idx: int | None = None ) -> None: """Synchronize weights on receiver side for shared memory. Reads the shared memory buffer from the queue and applies it to the model. Then starts a background thread that listens for "receive" instructions from the sender and applies weights when instructed. If a receiver_transport is set (e.g., for MultiProcessWeightSyncScheme), defers to the base class implementation. """ # If receiver_transport is set (e.g., MultiProcess subclass), use base behavior if self._receiver_transport is not None: return super()._setup_connection_and_weights_on_receiver_impl( worker_idx=worker_idx ) # SharedMem-specific: use shared_transport if self._shared_transport is None: raise RuntimeError( "SharedMemWeightSyncScheme requires shared_transport to be set." ) # Use stored worker_idx if not provided if worker_idx is None: worker_idx = self.worker_idx if worker_idx is None: raise RuntimeError( "worker_idx must be provided for _setup_connection_and_weights_on_receiver_impl." ) # Read shared memory buffer from queue weights = self._shared_transport.setup_connection_and_weights_on_receiver( worker_idx=worker_idx ) # Store the shared buffer reference for later receive() calls # This is the actual shared memory buffer that the sender updates self._receiver_shared_weights = weights # Apply weights to model if weights is not None and self.model is not None: self._strategy.apply_weights(self.model, weights, inplace=False) # Start background receiver thread that listens for instructions self._start_background_receiver() def _background_receive_loop(self): """Background thread loop that waits for instructions and applies weights. This loop: 1. Waits for a "receive" instruction from the sender 2. Applies the current shared memory weights to the model 3. Sends an acknowledgment back to the sender 4. Repeats until stop event is set or "stop" instruction received """ while not self._stop_event.is_set(): try: instruction = self._wait_for_instruction() if instruction is None: # Stop event was set or timeout continue if instruction == "receive": # Apply the current shared memory weights to the model # The weights are already updated in shared memory by the sender if ( self._receiver_shared_weights is not None and self.model is not None ): self._strategy.apply_weights( self.model, self._receiver_shared_weights, inplace=True ) # Cascade weight update to sub-collectors if context supports it model_id = self._model_id or "policy" if self.context is not None and hasattr( self.context, "update_policy_weights_" ): self.context.update_policy_weights_( model_id=model_id, policy_or_weights=self._receiver_shared_weights, ) # Send acknowledgment self._send_ack("updated") elif instruction == "stop": break else: torchrl_logger.warning( f"SharedMemWeightSyncScheme: Unknown instruction: {instruction}" ) except Exception as e: if not self._stop_event.is_set(): torchrl_logger.warning( f"SharedMemWeightSyncScheme: Background receiver error: {e}" ) def __getstate__(self): """Prepare the scheme for pickling.""" state = super().__getstate__() # mp.Queue objects can be pickled and shared across processes # Keep them in state so workers have access return state
[docs] def shutdown(self) -> None: """Stop the background receiver thread and clean up.""" # Check if already shutdown if getattr(self, "_is_shutdown", False): return self._is_shutdown = True # Signal all workers to stop instruction_queues = getattr(self, "_instruction_queues", None) if instruction_queues: for _, queue in instruction_queues.items(): queue.put("stop") # Let base class handle background thread cleanup super().shutdown() # Close all multiprocessing queues created by the scheme. queues_to_close = [] for name in ("_weight_init_queues", "_instruction_queues", "_ack_queues"): mapping = getattr(self, name, None) if not mapping: continue queues_to_close.extend(mapping.values()) setattr(self, name, {}) unique = {} for q in queues_to_close: unique[id(q)] = q for q in unique.values(): _close_mp_queue(q)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources