# Source code for torchrl.collectors._base
from __future__ import annotations
import abc
import contextlib
import functools
import typing
import warnings
from collections import OrderedDict
from collections.abc import Callable, Iterator
from copy import deepcopy
from typing import Any, overload
import torch
from tensordict import TensorDict, TensorDictBase
from tensordict.base import NO_DEFAULT
from tensordict.nn import TensorDictModule, TensorDictModuleBase
from torch import nn as nn
from torch.utils.data import IterableDataset
from torchrl._utils import logger as torchrl_logger
from torchrl.collectors.utils import _map_weight
from torchrl.collectors.weight_update import WeightUpdaterBase
from torchrl.weight_update.utils import _resolve_attr
from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme
class BaseCollector(IterableDataset, metaclass=abc.ABCMeta):
    """Base class for data collectors."""

    _task = None
    # Lazily created by :meth:`next` from ``iter(self)``.
    _iterator = None
    total_frames: int
    requested_frames_per_batch: int
    frames_per_batch: int
    trust_policy: bool
    compiled_policy: bool
    cudagraphed_policy: bool
    # Legacy weight-update mechanism. When set, :meth:`update_policy_weights_`
    # routes through it instead of ``_weight_sync_schemes``.
    _weight_updater: WeightUpdaterBase | None = None
    # Sender-side weight sync schemes, keyed by model id (e.g. ``"policy"``).
    _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None
    verbose: bool = False

    @property
    def weight_updater(self) -> WeightUpdaterBase | None:
        """The legacy weight updater attached to this collector, or ``None``."""
        return self._weight_updater

    @weight_updater.setter
    def weight_updater(self, value: WeightUpdaterBase | None):
        """Attach a weight updater to this collector, or detach it with ``None``.

        ``value`` may also be a callable that is not itself a ``WeightUpdaterBase``
        (for instance an updater class); it is then called without arguments
        to build the updater.

        Raises:
            RuntimeError: If the updater fails to register this collector.
        """
        if value is not None:
            # Updater instances are callable themselves, so check the type before
            # falling back to treating ``value`` as a factory (default constructor).
            if not isinstance(value, WeightUpdaterBase) and callable(value):
                value = value()
            value.register_collector(self)
            if value.collector is not self:
                raise RuntimeError("Failed to register collector.")
        self._weight_updater = value

    @property
    def worker_idx(self) -> int | None:
        """Get the worker index for this collector.

        Returns:
            The worker index (0-indexed), or ``None`` if it was explicitly set to ``None``.

        Raises:
            RuntimeError: If worker_idx has not been set.
        """
        if not hasattr(self, "_worker_idx"):
            raise RuntimeError(
                "worker_idx has not been set. This collector may not have been "
                "initialized as a worker in a distributed setup."
            )
        return self._worker_idx

    @worker_idx.setter
    def worker_idx(self, value: int | None) -> None:
        """Set the worker index for this collector.

        Args:
            value: The worker index (0-indexed) or None.
        """
        self._worker_idx = value

    def cascade_execute(self, attr_path: str, *args, **kwargs) -> Any:
        """Execute a method on a nested attribute of this collector.

        This method allows remote callers to invoke methods on nested attributes
        of the collector without needing to know the full structure. It's particularly
        useful for calling methods on weight sync schemes from the sender side.

        Args:
            attr_path: Full path to the callable, e.g.,
                "_receiver_schemes['model_id']._set_dist_connection_info"
            *args: Positional arguments to pass to the method.
            **kwargs: Keyword arguments to pass to the method.

        Returns:
            The return value of the method call. If the resolved attribute is
            not callable, the attribute itself is returned.

        Raises:
            ValueError: If arguments are passed for a non-callable attribute.

        Examples:
            >>> collector.cascade_execute(
            ...     "_receiver_schemes['policy']._set_dist_connection_info",
            ...     connection_info_ref,
            ...     worker_idx=0
            ... )
        """
        attr = _resolve_attr(self, attr_path)
        if callable(attr):
            return attr(*args, **kwargs)
        if args or kwargs:
            raise ValueError(
                f"Arguments and keyword arguments are not supported for non-callable attributes. Got {args} and {kwargs} for {attr_path}"
            )
        return attr

    def _get_policy_and_device(
        self,
        policy: Callable[[Any], Any] | None = None,
        policy_device: Any = NO_DEFAULT,
        env_maker: Any | None = None,
        env_maker_kwargs: dict[str, Any] | None = None,
    ) -> tuple[TensorDictModule, None | Callable[[], dict]]:
        """Util method to get a policy and its device given the collector __init__ inputs.

        We want to copy the policy and then move the data there, not call policy.to(device).

        Args:
            policy (TensorDictModule, optional): a policy to be used
            policy_device (torch.device, optional): the device where the policy should be placed.
                Defaults to self.policy_device
            env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair.
                Not used by this implementation.
            env_maker_kwargs (a dict, optional): the env_maker function kwargs.
                Not used by this implementation.

        Returns:
            A ``(policy, get_weights_fn)`` tuple. When no cast is needed (falsy
            ``policy_device``, no parameters/buffers, or all of them already on
            ``policy_device``) this is ``(policy, None)``. Otherwise the first
            element is a deep copy of ``policy`` whose parameters/buffers are
            mapped to ``policy_device``, and ``get_weights_fn`` returns the
            (``.data``) weights of the original policy.

        Raises:
            RuntimeError: If the weight keys of the copy and of the original mismatch.
        """
        if policy_device is NO_DEFAULT:
            policy_device = self.policy_device
        if not policy_device:
            return policy, None

        if isinstance(policy, nn.Module):
            param_and_buf = TensorDict.from_module(policy, as_module=True)
        else:
            # Non-module policies have no parameters/buffers; an empty TensorDict
            # makes the loop below fall through so that we reach the warning.
            param_and_buf = TensorDict()

        i = -1
        for p in param_and_buf.values(True, True):
            i += 1
            if p.device != policy_device:
                # At least one tensor is elsewhere: we need casting.
                break
        else:
            # Nothing needs to move (or there is nothing at all).
            if i == -1 and not self.trust_policy:
                # No tensor found: we trust that the policy device is adequate,
                # but let the user know that no check could be performed.
                warnings.warn(
                    "A policy device was provided but no parameter/buffer could be found in "
                    "the policy. Casting to policy_device is therefore impossible. "
                    "The collector will trust that the devices match. To suppress this "
                    "warning, set `trust_policy=True` when building the collector."
                )
            return policy, None

        def get_original_weights(policy=policy):
            td = TensorDict.from_module(policy)
            return td.data

        # Create a stateless policy, then populate this copy with params on device.
        # We need to use ".data" otherwise buffers may disappear from the
        # `get_original_weights` function.
        with param_and_buf.data.to("meta").to_module(policy):
            # The original policy temporarily holds meta tensors, so the copy
            # does not duplicate the actual weight data.
            policy_new_device = deepcopy(policy)

        param_and_buf_new_device = param_and_buf.apply(
            functools.partial(_map_weight, policy_device=policy_device),
            filter_empty=False,
        )
        param_and_buf_new_device.to_module(policy_new_device)
        # Sanity check
        if set(TensorDict.from_module(policy_new_device).keys(True, True)) != set(
            get_original_weights().keys(True, True)
        ):
            raise RuntimeError("Failed to map weights. The weight sets mismatch.")
        return policy_new_device, get_original_weights

    def start(self):
        """Starts the collector for asynchronous data collection.

        This method initiates the background collection of data, allowing for decoupling of data collection and training.

        The collected data is typically stored in a replay buffer passed during the collector's initialization.

        .. note:: After calling this method, it's essential to shut down the collector using :meth:`~.async_shutdown`
            when you're done with it to free up resources.

        .. warning:: Asynchronous data collection can significantly impact training performance due to its decoupled nature.
            Ensure you understand the implications for your specific algorithm before using this mode.

        Raises:
            NotImplementedError: If not implemented by a subclass.
        """
        raise NotImplementedError(
            f"Collector start() is not implemented for {type(self).__name__}."
        )

    @contextlib.contextmanager
    def pause(self):
        """Context manager that pauses the collector if it is running free.

        Raises:
            NotImplementedError: If not implemented by a subclass.
        """
        raise NotImplementedError(
            f"Collector pause() is not implemented for {type(self).__name__}."
        )

    def async_shutdown(
        self, timeout: float | None = None, close_env: bool = True
    ) -> None:
        """Shuts down the collector when started asynchronously with the `start` method.

        Args:
            timeout (float, optional): The maximum time to wait for the collector to shutdown.
            close_env (bool, optional): If True, the collector will close the contained environment.
                Defaults to `True`.

        .. seealso:: :meth:`~.start`
        """
        return self.shutdown(timeout=timeout, close_env=close_env)

    def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any:
        """Extract weights from a model if needed.

        For the new weight sync scheme system, weight preparation is handled
        by the scheme's prepare_weights() method. This method now only handles
        legacy weight updater cases.

        Args:
            weights: Either already-extracted weights or a model to extract from.
            model_id: The model identifier for resolving string paths.

        Returns:
            Extracted weights in the appropriate format.
        """
        # New weight sync schemes handle preparation themselves
        if self._weight_sync_schemes:
            # Just pass through - WeightSender will call scheme.prepare_weights()
            return weights
        # Legacy weight updater path
        return self._legacy_extract_weights(weights, model_id)

    def _legacy_extract_weights(self, weights: Any, model_id: str) -> Any:
        """Legacy weight extraction for old weight updater system.

        Args:
            weights: Either already-extracted weights or a model to extract from.
            model_id: The model identifier.

        Returns:
            ``weights`` if provided. Otherwise, for ``model_id == "policy"``, the
            collector's ``policy_weights`` (or the entry of ``_policy_weights_dict``
            for the first policy device); ``None`` in every other case.
        """
        if weights is None:
            if model_id == "policy" and hasattr(self, "policy_weights"):
                return self.policy_weights
            elif model_id == "policy" and hasattr(self, "_policy_weights_dict"):
                policy_device = (
                    self.policy_device
                    if not isinstance(self.policy_device, (list, tuple))
                    else self.policy_device[0]
                )
                return self._policy_weights_dict.get(policy_device)
            return None
        return weights

    @property
    def _legacy_weight_updater(self) -> bool:
        """Whether a legacy ``WeightUpdaterBase`` is attached to this collector."""
        return self._weight_updater is not None

    @staticmethod
    def _resolve_policy_or_weights(policy_or_weights, weights, policy):
        """Merge ``policy_or_weights`` with its ``weights``/``policy`` keyword aliases.

        At most one of the three may be provided.

        Returns:
            The single provided value, or ``None`` if none was given.

        Raises:
            ValueError: If more than one of them is provided.
        """
        if weights is not None:
            if policy_or_weights is not None:
                raise ValueError(
                    "Cannot specify both positional 'policy_or_weights' and keyword 'weights'"
                )
            if policy is not None:
                raise ValueError("Cannot specify both 'weights' and 'policy'")
            policy_or_weights = weights
        if policy is not None:
            if policy_or_weights is not None:
                raise ValueError(
                    "Cannot specify both positional 'policy_or_weights' and keyword 'policy'"
                )
            policy_or_weights = policy
        return policy_or_weights

    # Overloads for update_policy_weights_ to support multiple calling conventions
    @overload
    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict,
        /,
    ) -> None:
        ...

    @overload
    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict,
        /,
        *,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        model_id: str | None = None,
    ) -> None:
        ...

    @overload
    def update_policy_weights_(
        self,
        *,
        weights: TensorDictBase | dict,
        model_id: str | None = None,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
    ) -> None:
        ...

    @overload
    def update_policy_weights_(
        self,
        *,
        policy: TensorDictModuleBase | nn.Module,
        model_id: str | None = None,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
    ) -> None:
        ...

    @overload
    def update_policy_weights_(
        self,
        *,
        weights_dict: dict[
            str, TensorDictBase | TensorDictModuleBase | nn.Module | dict
        ],
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
    ) -> None:
        ...

    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase
        | TensorDictModuleBase
        | nn.Module
        | dict
        | None = None,
        *,
        weights: TensorDictBase | dict | None = None,
        policy: TensorDictModuleBase | nn.Module | None = None,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        model_id: str | None = None,
        weights_dict: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """Update policy weights for the data collector.

        This method synchronizes the policy weights used by the collector with the latest
        trained weights. It supports both local and remote weight updates, depending on
        the collector configuration.

        The method accepts weights in multiple forms for convenience:

        Examples:
            >>> # Pass policy module as positional argument
            >>> collector.update_policy_weights_(policy_module)
            >>>
            >>> # Pass TensorDict weights as positional argument
            >>> collector.update_policy_weights_(weights_tensordict)
            >>>
            >>> # Use keyword arguments for clarity
            >>> collector.update_policy_weights_(weights=weights_td, model_id="actor")
            >>> collector.update_policy_weights_(policy=actor_module, model_id="actor")
            >>>
            >>> # Update multiple models atomically
            >>> collector.update_policy_weights_(weights_dict={
            ...     "actor": actor_weights,
            ...     "critic": critic_weights,
            ... })

        Args:
            policy_or_weights: The weights to update with. Can be:
                - ``nn.Module``: A policy module whose weights will be extracted
                - ``TensorDictModuleBase``: A TensorDict module whose weights will be extracted
                - ``TensorDictBase``: A TensorDict containing weights
                - ``dict``: A regular dict containing weights
                - ``None``: Will try to get weights from server using ``_get_server_weights()``

        Keyword Args:
            weights: Alternative to positional argument. A TensorDict or dict containing
                weights to update. Cannot be used together with ``policy_or_weights`` or ``policy``.
            policy: Alternative to positional argument. An ``nn.Module`` or ``TensorDictModuleBase``
                whose weights will be extracted. Cannot be used together with ``policy_or_weights``
                or ``weights``.
            worker_ids: Identifiers for the workers to update. Relevant when the collector
                has multiple workers. Can be int, list of ints, device, or list of devices.
            model_id: The model identifier to update (default: ``"policy"``).
                Cannot be used together with ``weights_dict``.
            weights_dict: Dictionary mapping model_id to weights for updating
                multiple models atomically. Keys should match model_ids registered in
                ``weight_sync_schemes``. Cannot be used together with ``model_id``,
                ``policy_or_weights``, ``weights``, or ``policy``.

        Raises:
            ValueError: If conflicting parameters are provided, or if ``model_id`` /
                ``weights_dict`` are used with a legacy weight updater.
            KeyError: If a model id is not registered in the weight sync schemes.
            TypeError: If the registered scheme for a model is not a ``WeightSyncScheme``.
                Subclasses and weight updaters may raise additional errors
                (e.g. when ``worker_ids`` is not supported).

        .. note:: Users should extend the ``WeightUpdaterBase`` classes to customize
            the weight update logic for specific use cases.

        .. seealso:: :class:`~torchrl.collectors.weight_update.WeightUpdaterBase` and
            :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme`.
        """
        policy_or_weights = self._resolve_policy_or_weights(
            policy_or_weights, weights, policy
        )
        # A legacy updater, when present, takes precedence over sync schemes.
        if self._legacy_weight_updater:
            return self._legacy_weight_update_impl(
                policy_or_weights=policy_or_weights,
                worker_ids=worker_ids,
                model_id=model_id,
                weights_dict=weights_dict,
                **kwargs,
            )
        return self._weight_update_impl(
            policy_or_weights=policy_or_weights,
            worker_ids=worker_ids,
            model_id=model_id,
            weights_dict=weights_dict,
            **kwargs,
        )

    def _legacy_weight_update_impl(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        model_id: str | None = None,
        weights_dict: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """Forward a weight update to the legacy ``WeightUpdaterBase``.

        Raises:
            ValueError: If ``weights_dict`` or ``model_id`` is provided (unsupported
                by legacy updaters).
        """
        if weights_dict is not None:
            raise ValueError("weights_dict is not supported with legacy weight updater")
        if model_id is not None:
            raise ValueError("model_id is not supported with legacy weight updater")
        # Fall back to old weight updater system
        self.weight_updater(
            policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
        )

    def _weight_update_impl(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        model_id: str | None = None,
        weights_dict: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """Send weights through the registered weight sync schemes.

        When no sync scheme is registered, defers to :meth:`_maybe_fallback_update`.

        Raises:
            ValueError: If ``weights_dict`` is combined with ``model_id`` or
                ``policy_or_weights``.
            KeyError: If a target model id has no registered scheme.
            TypeError: If a registered scheme is not a ``WeightSyncScheme``.
            RuntimeError: If reached while a legacy weight updater is configured.
        """
        if "policy_weights" in kwargs:
            warnings.warn(
                "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
                DeprecationWarning,
            )
            policy_or_weights = kwargs.pop("policy_weights")
        if weights_dict is not None and model_id is not None:
            raise ValueError("Cannot specify both 'weights_dict' and 'model_id'")
        if weights_dict is not None and policy_or_weights is not None:
            raise ValueError(
                "Cannot specify both 'weights_dict' and 'policy_or_weights'"
            )

        if self._weight_sync_schemes:
            if model_id is None:
                model_id = "policy"
            if weights_dict is None:
                # A single-model update is a one-entry weights_dict keyed by
                # model_id (not a hard-coded "policy" key).
                weights_dict = {model_id: policy_or_weights}
            torchrl_logger.debug(
                f"Calling weight update with {model_id=} and {weights_dict.keys()=}"
            )
            for target_model_id, target_weights in weights_dict.items():
                if target_model_id not in self._weight_sync_schemes:
                    raise KeyError(
                        f"Model '{target_model_id}' not found in registered weight sync schemes. "
                        f"Available models: {list(self._weight_sync_schemes.keys())}"
                    )
                processed_weights = self._extract_weights_if_needed(
                    target_weights, target_model_id
                )
                # Use new send() API with worker_ids support
                torchrl_logger.debug("weight update -- getting scheme")
                scheme = self._weight_sync_schemes.get(target_model_id)
                if not isinstance(scheme, WeightSyncScheme):
                    raise TypeError(
                        f"Expected WeightSyncScheme for model '{target_model_id}', "
                        f"got {type(scheme).__name__}"
                    )
                torchrl_logger.debug(
                    f"calling send() on scheme {type(scheme).__name__}"
                )
                self._send_weights_scheme(
                    scheme=scheme,
                    processed_weights=processed_weights,
                    worker_ids=worker_ids,
                    model_id=target_model_id,
                )
        elif self._weight_updater is not None:
            # update_policy_weights_ dispatches to the legacy path in this case.
            raise RuntimeError(
                "_weight_update_impl was called while a legacy weight updater is "
                "configured; use update_policy_weights_ instead."
            )
        else:
            # No weight updater configured, try fallback
            torchrl_logger.debug("No weight update configured, trying fallback.")
            self._maybe_fallback_update(policy_or_weights, model_id=model_id)

    def _maybe_fallback_update(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        model_id: str | None = None,
    ) -> None:
        """Fallback weight update when no scheme is configured.

        Override in subclasses to provide custom fallback behavior.
        By default, this is a no-op.
        """

    def _send_weights_scheme(self, *, model_id, scheme, processed_weights, worker_ids):
        """Send ``processed_weights`` through ``scheme`` to ``worker_ids``.

        Override if the scheme requires an RPC call to receive the weights;
        ``model_id`` is unused by this default implementation.
        """
        scheme.send(weights=processed_weights, worker_ids=worker_ids)

    def _receive_weights_scheme(self):
        """Receive weights for all registered receiver schemes.

        scheme.receive() handles both applying weights locally and cascading
        to sub-collectors via context.update_policy_weights_().

        Raises:
            RuntimeError: If no receiver scheme was registered.
        """
        if not hasattr(self, "_receiver_schemes"):
            raise RuntimeError("No receiver schemes registered.")
        for model_id, scheme in self._receiver_schemes.items():
            # The f-string is evaluated even when debug logging is disabled, so
            # do not assume a worker index has been assigned.
            torchrl_logger.debug(
                f"Receiving weights for scheme {type(scheme).__name__} for model '{model_id}' "
                f"on worker {getattr(self, '_worker_idx', None)}"
            )
            received_weights = scheme.receive()
            torchrl_logger.debug(f"Received weights: {type(received_weights)=}")

    # Overloads for receive_weights to support multiple calling conventions
    @overload
    def receive_weights(self) -> None:
        ...

    @overload
    def receive_weights(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict,
        /,
    ) -> None:
        ...

    @overload
    def receive_weights(
        self,
        *,
        weights: TensorDictBase | dict,
    ) -> None:
        ...

    @overload
    def receive_weights(
        self,
        *,
        policy: TensorDictModuleBase | nn.Module,
    ) -> None:
        ...

    def receive_weights(
        self,
        policy_or_weights: TensorDictBase
        | TensorDictModuleBase
        | nn.Module
        | dict
        | None = None,
        *,
        weights: TensorDictBase | dict | None = None,
        policy: TensorDictModuleBase | nn.Module | None = None,
    ) -> None:
        """Receive and apply weights to the collector's policy.

        This method applies weights to the local policy. When receiver schemes are
        registered, it delegates to those schemes. Otherwise, it directly applies
        the provided weights.

        The method accepts weights in multiple forms for convenience:

        Examples:
            >>> # Receive from registered schemes (distributed collectors)
            >>> collector.receive_weights()
            >>>
            >>> # Apply weights from a policy module (positional)
            >>> collector.receive_weights(trained_policy)
            >>>
            >>> # Apply weights from a TensorDict (positional)
            >>> collector.receive_weights(weights_tensordict)
            >>>
            >>> # Use keyword arguments for clarity
            >>> collector.receive_weights(weights=weights_td)
            >>> collector.receive_weights(policy=trained_policy)

        Args:
            policy_or_weights: The weights to apply. Can be:
                - ``nn.Module``: A policy module whose weights will be extracted and applied
                - ``TensorDictModuleBase``: A TensorDict module whose weights will be extracted
                - ``TensorDictBase``: A TensorDict containing weights
                - ``dict``: A regular dict containing weights
                - ``None``: Receive from registered schemes or mirror from original policy

        Keyword Args:
            weights: Alternative to positional argument. A TensorDict or dict containing
                weights to apply. Cannot be used together with ``policy_or_weights`` or ``policy``.
            policy: Alternative to positional argument. An ``nn.Module`` or ``TensorDictModuleBase``
                whose weights will be extracted. Cannot be used together with ``policy_or_weights``
                or ``weights``.

        Raises:
            ValueError: If conflicting parameters are provided or if arguments are passed
                when receiver schemes are registered.
        """
        policy_or_weights = self._resolve_policy_or_weights(
            policy_or_weights, weights, policy
        )

        if getattr(self, "_receiver_schemes", None) is not None:
            if policy_or_weights is not None:
                raise ValueError(
                    "Cannot specify 'policy_or_weights' when using 'receiver_schemes'. Schemes should know how to get the weights."
                )
            self._receive_weights_scheme()
            return

        # No receiver scheme configured: for single-process collectors, apply
        # weights locally only if they were explicitly provided.
        if policy_or_weights is not None:
            from torchrl.weight_update.weight_sync_schemes import WeightStrategy

            # Use WeightStrategy to apply weights properly
            strategy = WeightStrategy(extract_as="tensordict")
            if isinstance(policy_or_weights, nn.Module):
                weights = strategy.extract_weights(policy_or_weights)
            else:
                weights = policy_or_weights
            # Apply to local policy
            if hasattr(self, "policy") and isinstance(self.policy, nn.Module):
                strategy.apply_weights(self.policy, weights)
        # Otherwise, no action needed - policy is local and changes are immediately visible

    def register_scheme_receiver(
        self,
        weight_recv_schemes: dict[str, WeightSyncScheme],
        *,
        synchronize_weights: bool = True,
    ):  # noqa: D417
        """Set up receiver schemes for this collector to receive weights from parent collectors.

        This method initializes receiver schemes and stores them in _receiver_schemes
        for later use by _receive_weights_scheme() and receive_weights().

        Receiver schemes enable cascading weight updates across collector hierarchies:
        - Parent collector sends weights via its weight_sync_schemes (senders)
        - Child collector receives weights via its weight_recv_schemes (receivers)
        - If child is also a parent (intermediate node), it can propagate to its own children

        Args:
            weight_recv_schemes (dict[str, WeightSyncScheme]): Dictionary of {model_id: WeightSyncScheme} to set up as receivers.
                These schemes will receive weights from parent collectors.

        Keyword Args:
            synchronize_weights (bool, optional): If True, synchronize weights immediately after registering the schemes.
                Defaults to `True`.

        Raises:
            RuntimeError: If a scheme is already initialized on the sender side, or
                if ``worker_idx`` has not been set.
        """
        # Initialize _receiver_schemes if not already present
        if not hasattr(self, "_receiver_schemes"):
            self._receiver_schemes = {}

        # Initialize each scheme on the receiver side
        for model_id, scheme in weight_recv_schemes.items():
            if not scheme.initialized_on_receiver:
                if scheme.initialized_on_sender:
                    raise RuntimeError(
                        "Weight sync scheme cannot be initialized on both sender and receiver."
                    )
                scheme.init_on_receiver(
                    model_id=model_id,
                    context=self,
                    worker_idx=self.worker_idx,
                )
            # Store the scheme for later use in receive_weights()
            self._receiver_schemes[model_id] = scheme

        # Perform initial synchronization
        if synchronize_weights:
            for model_id, scheme in weight_recv_schemes.items():
                if not scheme.synchronized_on_receiver:
                    torchrl_logger.debug(
                        f"Synchronizing weights for scheme {type(scheme).__name__} for model '{model_id}'"
                    )
                    scheme.connect(worker_idx=self.worker_idx)

    def __iter__(self) -> Iterator[TensorDictBase]:
        """Iterate over collected batches, shutting the collector down on error."""
        try:
            yield from self.iterator()
        except Exception:
            self.shutdown()
            raise

    def next(self):
        """Return the next collected batch, or ``None`` once iteration is exhausted.

        An internal iterator over ``self`` is created on the first call. Batches
        whose ``device`` compares unequal to ``"cpu"`` are copied and stripped
        of their device.
        """
        try:
            if self._iterator is None:
                self._iterator = iter(self)
            out = next(self._iterator)
            # if any, we don't want the device ref to be passed in distributed settings
            if out is not None and (out.device != "cpu"):
                out = out.copy().clear_device_()
            return out
        except StopIteration:
            return None

    # Abstract interface that every concrete collector must implement.
    @abc.abstractmethod
    def shutdown(
        self,
        timeout: float | None = None,
        close_env: bool = True,
        raise_on_error: bool = True,
    ) -> None:
        """Shut down the collector.

        Args:
            timeout (float, optional): The maximum time to wait for the collector to shutdown.
            close_env (bool, optional): If True, the collector will close the contained environment.
            raise_on_error (bool, optional): subclass-defined error handling flag.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def iterator(self) -> Iterator[TensorDictBase]:
        """Return the iterator over collected batches used by :meth:`__iter__`."""
        raise NotImplementedError

    @abc.abstractmethod
    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        raise NotImplementedError

    @abc.abstractmethod
    def state_dict(self) -> OrderedDict:
        raise NotImplementedError

    @abc.abstractmethod
    def load_state_dict(self, state_dict: OrderedDict) -> None:
        raise NotImplementedError

    def _read_compile_kwargs(self, compile_policy, cudagraph_policy):
        """Parse the ``compile_policy`` / ``cudagraph_policy`` arguments.

        Each feature is enabled unless its argument is ``False`` or ``None``; a
        mapping argument is additionally stored as that feature's kwargs
        (otherwise the kwargs are an empty dict).
        """
        self.compiled_policy = compile_policy not in (False, None)
        self.cudagraphed_policy = cudagraph_policy not in (False, None)
        self.compiled_policy_kwargs = (
            {} if not isinstance(compile_policy, typing.Mapping) else compile_policy
        )
        self.cudagraphed_policy_kwargs = (
            {} if not isinstance(cudagraph_policy, typing.Mapping) else cudagraph_policy
        )

    def __repr__(self) -> str:
        string = f"{self.__class__.__name__}()"
        return string

    def __class_getitem__(cls, index):
        # Implicitly a classmethod: the first argument is the class, not an instance.
        # Subscripting (generic aliasing) is deliberately disabled for collectors.
        raise NotImplementedError(f"{cls.__name__} does not support subscripting.")

    def __len__(self) -> int:
        """Number of batches yielded by a terminating collector.

        Computed as ``ceil(total_frames / requested_frames_per_batch)``.

        Raises:
            RuntimeError: If ``total_frames`` is not positive.
        """
        if self.total_frames > 0:
            # Integer ceiling division: -(a // -b) == ceil(a / b).
            return -(self.total_frames // -self.requested_frames_per_batch)
        raise RuntimeError("Non-terminating collectors do not have a length")

    def init_updater(self, *args, **kwargs):
        """Initialize the weight updater with custom arguments.

        This method passes the arguments to the weight updater's init method.
        If no weight updater is set, this is a no-op.

        Args:
            *args: Positional arguments for weight updater initialization
            **kwargs: Keyword arguments for weight updater initialization
        """
        if self.weight_updater is not None:
            self.weight_updater.init(*args, **kwargs)
def _make_legacy_metaclass(parent_metaclass):
"""Create a legacy metaclass for deprecated collector names.
This factory creates a metaclass that inherits from the given parent metaclass
to avoid metaclass conflicts.
"""
class _LegacyMeta(parent_metaclass):
"""Metaclass for deprecated collector class names.
Raises a deprecation warning when the old class name is instantiated,
and ensures isinstance() checks work for both old and new names.
"""
def __call__(cls, *args, **kwargs):
warnings.warn(
f"{cls.__name__} has been deprecated and will be removed in v0.13. "
f"Please use {cls.__bases__[0].__name__} instead.",
category=DeprecationWarning,
)
return super().__call__(*args, **kwargs)
def __instancecheck__(cls, instance):
if super().__instancecheck__(instance):
return True
parent_cls = cls.__bases__[0]
return isinstance(instance, parent_cls)
return _LegacyMeta
# Default legacy metaclass for classes with abc.ABCMeta. BaseCollector uses
# abc.ABCMeta, so deriving from it lets deprecated aliases subclass
# BaseCollector without a metaclass conflict.
_LegacyCollectorMeta = _make_legacy_metaclass(abc.ABCMeta)
class DataCollectorBase(BaseCollector, metaclass=_LegacyCollectorMeta):
    """Deprecated alias of :class:`~torchrl.collectors.BaseCollector`.

    Its metaclass emits a ``DeprecationWarning`` on instantiation and lets
    ``isinstance(obj, DataCollectorBase)`` succeed for ``BaseCollector`` instances.
    """