Shortcuts

Source code for torchrl.envs.transforms.module

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from collections.abc import Callable
from contextlib import nullcontext
from typing import overload

import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModuleBase

from torchrl.data.tensor_specs import TensorSpec
from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform
from torchrl.envs.transforms.transforms import Transform


__all__ = ["ModuleTransform", "RayModuleTransform"]


class RayModuleTransform(RayTransform):
    """Ray-based ModuleTransform for distributed processing.

    This transform creates a Ray actor that wraps a ModuleTransform,
    allowing module execution in a separate Ray worker process.
    """

    def _create_actor(self, **kwargs):
        import ray

        remote = self._ray.remote(ModuleTransform)
        ray_kwargs = {}
        num_gpus = self._num_gpus
        if num_gpus is not None:
            ray_kwargs["num_gpus"] = num_gpus
        num_cpus = self._num_cpus
        if num_cpus is not None:
            ray_kwargs["num_cpus"] = num_cpus
        actor_name = self._actor_name
        if actor_name is not None:
            ray_kwargs["name"] = actor_name
        if ray_kwargs:
            remote = remote.options(**ray_kwargs)
        actor = remote.remote(**kwargs)
        # wait till the actor is ready
        ray.get(actor._ready.remote())
        return actor

    @overload
    def update_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
        ...

    @overload
    def update_weights(self, params: TensorDictBase) -> None:
        ...

    def update_weights(self, *args, **kwargs) -> None:
        import ray

        if self._update_weights_method == "tensordict":
            try:
                td = kwargs.get("params", args[0])
            except IndexError:
                raise ValueError("params must be provided")
            return ray.get(self._actor._update_weights_tensordict.remote(params=td))
        elif self._update_weights_method == "state_dict":
            try:
                state_dict = kwargs.get("state_dict", args[0])
            except IndexError:
                raise ValueError("state_dict must be provided")
            return ray.get(
                self._actor._update_weights_state_dict.remote(state_dict=state_dict)
            )
        else:
            raise ValueError(
                f"Invalid update_weights_method: {self._update_weights_method}"
            )


[docs]class ModuleTransform(Transform, metaclass=_RayServiceMetaClass): """A transform that wraps a module. Keyword Args: module (TensorDictModuleBase): The module to wrap. Exclusive with `module_factory`. At least one of `module` or `module_factory` must be provided. module_factory (Callable[[], TensorDictModuleBase]): The factory to create the module. Exclusive with `module`. At least one of `module` or `module_factory` must be provided. no_grad (bool, optional): Whether to use gradient computation. Default is `False`. inverse (bool, optional): Whether to use the inverse of the module. Default is `False`. device (torch.device, optional): The device to use. Default is `None`. use_ray_service (bool, optional): Whether to use Ray service. Default is `False`. num_gpus (int, optional): The number of GPUs to use if using Ray. Default is `None`. num_cpus (int, optional): The number of CPUs to use if using Ray. Default is `None`. actor_name (str, optional): The name of the actor to use. Default is `None`. If an actor name is provided and an actor with this name already exists, the existing actor will be used. observation_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the observation after it has been transformed by the module, or a function that modifies the existing spec. Defaults to `None` (observation specs remain unchanged). done_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the done after it has been transformed by the module, or a function that modifies the existing spec. Defaults to `None` (done specs remain unchanged). reward_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the reward after it has been transformed by the module, or a function that modifies the existing spec. Defaults to `None` (reward specs remain unchanged). state_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the state after it has been transformed by the module, or a function that modifies the existing spec. Defaults to `None` (state specs remain unchanged). action_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the action after it has been transformed by the module, or a function that modifies the existing spec. Defaults to `None` (action specs remain unchanged). """ _RayServiceClass = RayModuleTransform def __init__( self, *, module: TensorDictModuleBase | None = None, module_factory: Callable[[], TensorDictModuleBase] | None = None, no_grad: bool = False, inverse: bool = False, device: torch.device | None = None, use_ray_service: bool = False, # noqa actor_name: str | None = None, # noqa num_gpus: int | None = None, num_cpus: int | None = None, observation_spec_transform: TensorSpec | Callable[[TensorSpec], TensorSpec] | None = None, action_spec_transform: TensorSpec | Callable[[TensorSpec], TensorSpec] | None = None, reward_spec_transform: TensorSpec | Callable[[TensorSpec], TensorSpec] | None = None, done_spec_transform: TensorSpec | Callable[[TensorSpec], TensorSpec] | None = None, state_spec_transform: TensorSpec | Callable[[TensorSpec], TensorSpec] | None = None, ): super().__init__() if module is None and module_factory is None: raise ValueError( "At least one of `module` or `module_factory` must be provided." ) if module is not None and module_factory is not None: raise ValueError( "Only one of `module` or `module_factory` must be provided." ) self.module = module if module is not None else module_factory() self.no_grad = no_grad self.inverse = inverse self.device = device self.observation_spec_transform = observation_spec_transform self.action_spec_transform = action_spec_transform self.reward_spec_transform = reward_spec_transform self.done_spec_transform = done_spec_transform self.state_spec_transform = state_spec_transform @property def in_keys(self) -> list[str]: return self._in_keys() def _in_keys(self): return self.module.in_keys if not self.inverse else [] @in_keys.setter def in_keys(self, value: list[str] | None): if value is not None: raise RuntimeError(f"in_keys {value} cannot be set for ModuleTransform") @property def out_keys(self) -> list[str]: return self._out_keys() def _out_keys(self): return self.module.out_keys if not self.inverse else [] @property def in_keys_inv(self) -> list[str]: return self._in_keys_inv() def _in_keys_inv(self): return self.module.out_keys if self.inverse else [] @in_keys_inv.setter def in_keys_inv(self, value: list[str]): if value is not None: raise RuntimeError(f"in_keys_inv {value} cannot be set for ModuleTransform") @property def out_keys_inv(self) -> list[str]: return self._out_keys_inv() def _out_keys_inv(self): return self.module.in_keys if self.inverse else [] @out_keys_inv.setter def out_keys_inv(self, value: list[str] | None): if value is not None: raise RuntimeError( f"out_keys_inv {value} cannot be set for ModuleTransform" ) @out_keys.setter def out_keys(self, value: list[str] | None): if value is not None: raise RuntimeError(f"out_keys {value} cannot be set for ModuleTransform")
[docs] def forward(self, tensordict: TensorDictBase) -> TensorDictBase: return self._call(tensordict)
def _call(self, tensordict: TensorDictBase) -> TensorDictBase: if self.inverse: return tensordict with torch.no_grad() if self.no_grad else nullcontext(): with ( tensordict.to(self.device) if self.device is not None else nullcontext(tensordict) ) as td: return self.module(td) def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: if not self.inverse: return tensordict with torch.no_grad() if self.no_grad else nullcontext(): with ( tensordict.to(self.device) if self.device is not None else nullcontext(tensordict) ) as td: return self.module(td) def _update_weights_tensordict(self, params: TensorDictBase) -> None: params.to_module(self.module) def _update_weights_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None: self.module.load_state_dict(state_dict)
[docs] def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: if self.observation_spec_transform is not None: if isinstance(self.observation_spec_transform, TensorSpec): return self.observation_spec_transform else: return self.observation_spec_transform(observation_spec) return observation_spec
[docs] def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec: if self.action_spec_transform is not None: if isinstance(self.action_spec_transform, TensorSpec): return self.action_spec_transform else: return self.action_spec_transform(action_spec) return action_spec
[docs] def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: if self.reward_spec_transform is not None: if isinstance(self.reward_spec_transform, TensorSpec): return self.reward_spec_transform else: return self.reward_spec_transform(reward_spec) return reward_spec
[docs] def transform_done_spec(self, done_spec: TensorSpec) -> TensorSpec: if self.done_spec_transform is not None: if isinstance(self.done_spec_transform, TensorSpec): return self.done_spec_transform else: return self.done_spec_transform(done_spec) return done_spec
[docs] def transform_state_spec(self, state_spec: TensorSpec) -> TensorSpec: if self.state_spec_transform is not None: if isinstance(self.state_spec_transform, TensorSpec): return self.state_spec_transform else: return self.state_spec_transform(state_spec) return state_spec

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources