Source code for ``torchrl.envs.transforms.vecnorm``.

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import math
import uuid
import warnings
from copy import copy

from typing import Any, OrderedDict, Sequence

import torch
from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
from tensordict.utils import _zip_strict
from torch import multiprocessing as mp
from torchrl.data.tensor_specs import Bounded, Composite, Unbounded

from torchrl.envs.common import EnvBase
from torchrl.envs.transforms.transforms import Compose, ObservationNorm, Transform

from torchrl.envs.transforms.utils import _set_missing_tolerance


class VecNormV2(Transform):
    """A class for normalizing vectorized observations and rewards in reinforcement learning environments.

    `VecNormV2` can operate in either a stateful or stateless mode. In stateful mode, it
    maintains internal statistics (mean and variance) to normalize inputs. In stateless
    mode, it requires external statistics to be provided for normalization.

    .. note:: This class is designed to be an almost drop-in replacement for
        :class:`~torchrl.envs.transforms.VecNorm`. It should not be constructed directly,
        but rather with the :class:`~torchrl.envs.transforms.VecNorm` transform using the
        `new_api=True` keyword argument. In v0.10, the
        :class:`~torchrl.envs.transforms.VecNorm` transform will be switched to the new
        api by default.

    Stateful vs. Stateless:

        Stateful Mode (`stateful=True`):

        - Maintains internal statistics (`loc`, `var`, `count`) for normalization.
        - Updates statistics with each call unless frozen.
        - `state_dict` returns the current statistics.
        - `load_state_dict` updates the internal statistics with the provided state.

        Stateless Mode (`stateful=False`):

        - Requires external statistics to be provided for normalization.
        - Does not maintain or update internal statistics.
        - `state_dict` returns an empty dictionary.
        - `load_state_dict` does not affect internal state.

    Args:
        in_keys (Sequence[NestedKey]): The input keys for the data to be normalized.
        out_keys (Sequence[NestedKey] | None): The output keys for the normalized data.
            Defaults to `in_keys` if not provided.
        lock (mp.Lock, optional): A lock for thread safety.
        stateful (bool, optional): Whether the `VecNorm` is stateful. Stateless versions
            of this transform requires the data to be carried within the input/output
            tensordicts. Defaults to `True`.
        decay (float, optional): The decay rate for updating statistics. Defaults to
            `0.9999`. If `decay=1` is used, the normalizing statistics have an infinite
            memory (each item is weighed identically). Lower values weigh recent data more
            than old ones.
        eps (float, optional): A small value to prevent division by zero. Defaults to
            `1e-4`.
        shared_data (TensorDictBase | None, optional): Shared data for initialization.
            Defaults to `None`.
        reduce_batch_dims (bool, optional): If `True`, the batch dimensions are reduced by
            averaging the data before updating the statistics. This is useful when samples
            are received in batches, as it allows the moving average to be computed over
            the entire batch rather than individual elements. Note that this option is
            only supported in stateful mode (`stateful=True`). Defaults to `False`.

    Attributes:
        stateful (bool): Indicates whether the VecNormV2 is stateful or stateless.
        lock (mp.Lock): A multiprocessing lock to ensure thread safety when updating
            statistics.
        decay (float): The decay rate for updating statistics.
        eps (float): A small value to prevent division by zero during normalization.
        frozen (bool): Indicates whether the VecNormV2 is frozen, preventing updates to
            statistics.
        _cast_int_to_float (bool): Indicates whether integer inputs should be cast to
            float.

    Methods:
        freeze(): Freezes the VecNorm, preventing updates to statistics.
        unfreeze(): Unfreezes the VecNorm, allowing updates to statistics.
        frozen_copy(): Returns a frozen copy of the VecNorm.
        clone(): Returns a clone of the VecNorm.
        transform_observation_spec(observation_spec): Transforms the observation
            specification.
        transform_reward_spec(reward_spec, observation_spec): Transforms the reward
            specification.
        transform_output_spec(output_spec): Transforms the output specification.
        to_observation_norm(): Converts the VecNorm to an ObservationNorm transform.
        set_extra_state(state): Sets the extra state for the VecNorm.
        get_extra_state(): Gets the extra state of the VecNorm.
        loc: Returns the location (mean) for normalization.
        scale: Returns the scale (standard deviation) for normalization.
        standard_normal: Indicates whether the normalization follows the standard normal
            distribution.

    State Dict Behavior:

        - In stateful mode, `state_dict` returns a dictionary containing the current
          `loc`, `var`, and `count`. These can be used to share the tensors across
          processes (this method is automatically triggered by
          :class:`~torchrl.envs.VecNorm` to share the VecNorm states across processes).
        - In stateless mode, `state_dict` returns an empty dictionary as no internal
          state is maintained.

    Load State Dict Behavior:

        - In stateful mode, `load_state_dict` updates the internal `loc`, `var`, and
          `count` with the provided state.
        - In stateless mode, `load_state_dict` does not modify any internal state as
          there is none to update.

    .. seealso:: :class:`~torchrl.envs.transforms.VecNorm` for the first version of this
        transform.

    Examples:
        >>> import torch
        >>> from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, SerialEnv, VecNormV2
        >>>
        >>> torch.manual_seed(0)
        >>> env = GymEnv("Pendulum-v1")
        >>> env_trsf = env.append_transform(
        >>>     VecNormV2(in_keys=["observation", "reward"], out_keys=["observation_norm", "reward_norm"])
        >>> )
        >>> r = env_trsf.rollout(10)
        >>> print("Unnormalized rewards", r["next", "reward"])
        Unnormalized rewards tensor([[ -1.7967],
                [ -2.1238],
                [ -2.5911],
                [ -3.5275],
                [ -4.8585],
                [ -6.5028],
                [ -8.2505],
                [-10.3169],
                [-12.1332],
                [-13.1235]])
        >>> print("Normalized rewards", r["next", "reward_norm"])
        Normalized rewards tensor([[-1.6596e-04],
                [-8.3072e-02],
                [-1.9170e-01],
                [-3.9255e-01],
                [-5.9131e-01],
                [-7.4671e-01],
                [-8.3760e-01],
                [-9.2058e-01],
                [-9.3484e-01],
                [-8.6185e-01]])
        >>> # Aggregate values when using batched envs
        >>> env = SerialEnv(2, [lambda: GymEnv("Pendulum-v1")] * 2)
        >>> env_trsf = env.append_transform(
        >>>     VecNormV2(
        >>>         in_keys=["observation", "reward"],
        >>>         out_keys=["observation_norm", "reward_norm"],
        >>>         # Use reduce_batch_dims=True to aggregate values across batch elements
        >>>         reduce_batch_dims=True,
        >>>     )
        >>> )
        >>> r = env_trsf.rollout(10)
        >>> print("Unnormalized rewards", r["next", "reward"])
        Unnormalized rewards tensor([[[-0.1456],
                 [-0.1862],
                 [-0.2053],
                 [-0.2605],
                 [-0.4046],
                 [-0.5185],
                 [-0.8023],
                 [-1.1364],
                 [-1.6183],
                 [-2.5406]],
                [[-0.0920],
                 [-0.1492],
                 [-0.2702],
                 [-0.3917],
                 [-0.5001],
                 [-0.7947],
                 [-1.0160],
                 [-1.3347],
                 [-1.9082],
                 [-2.9679]]])
        >>> print("Normalized rewards", r["next", "reward_norm"])
        Normalized rewards tensor([[[-0.2199],
                 [-0.2918],
                 [-0.1668],
                 [-0.2083],
                 [-0.4981],
                 [-0.5046],
                 [-0.7950],
                 [-0.9791],
                 [-1.1484],
                 [-1.4182]],
                [[ 0.2201],
                 [-0.0403],
                 [-0.5206],
                 [-0.7791],
                 [-0.8282],
                 [-1.2306],
                 [-1.2279],
                 [-1.2907],
                 [-1.4929],
                 [-1.7793]]])
        >>> print("Loc / scale", env_trsf.transform.loc["reward"], env_trsf.transform.scale["reward"])
        Loc / scale tensor([-0.8626]) tensor([1.1832])
        >>>
        >>> # Share values between workers
        >>> def make_env():
        ...     env = GymEnv("Pendulum-v1")
        ...     env_trsf = env.append_transform(
        ...         VecNormV2(in_keys=["observation", "reward"], out_keys=["observation_norm", "reward_norm"])
        ...     )
        ...     return env_trsf
        ...
        >>> if __name__ == "__main__":
        ...     # EnvCreator will share the loc/scale vals
        ...     make_env = EnvCreator(make_env)
        ...     # Create a local env to track the loc/scale
        ...     local_env = make_env()
        ...     env = ParallelEnv(2, [make_env] * 2)
        ...     r = env.rollout(10)
        ...     # Non-zero loc and scale testify that the sub-envs share their summary stats with us
        ...     print("Remotely updated loc / scale", local_env.transform.loc["reward"], local_env.transform.scale["reward"])
        Remotely updated loc / scale tensor([-0.4307]) tensor([0.9613])
        ...     env.close()
    """

    # TODO:
    #  - test 2 different vecnorms, one for reward one for obs and that they don't collide
    #  - test that collision is spotted
    #  - customize the vecnorm keys in stateless

    def __init__(
        self,
        in_keys: Sequence[NestedKey],
        out_keys: Sequence[NestedKey] | None = None,
        *,
        lock: mp.Lock = None,
        stateful: bool = True,
        decay: float = 0.9999,
        eps: float = 1e-4,
        shared_data: TensorDictBase | None = None,
        reduce_batch_dims: bool = False,
    ) -> None:
        # `stateful` must be set before super().__init__ because the `in_keys`
        # property (invoked during Transform init) reads it.
        self.stateful = stateful
        if lock is None:
            lock = mp.Lock()
        if out_keys is None:
            out_keys = copy(in_keys)
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.lock = lock
        self.decay = decay
        self.eps = eps
        self.frozen = False
        self._cast_int_to_float = False
        if self.stateful:
            # `initialized` is a buffer so it travels with state_dict/device moves.
            self.register_buffer("initialized", torch.zeros((), dtype=torch.bool))
            if shared_data:
                # Adopt tensors shared by another process (see get/set_extra_state).
                self._loc = shared_data["loc"]
                self._var = shared_data["var"]
                self._count = shared_data["count"]
            else:
                # Lazily created on the first call (see _maybe_stateful_init).
                self._loc = None
                self._var = None
                self._count = None
        else:
            self.initialized = False
            if shared_data:
                # FIXME: shared data is not supported in stateless mode.
                raise NotImplementedError
        if reduce_batch_dims and not stateful:
            raise RuntimeError(
                "reduce_batch_dims=True and stateful=False are not supported."
            )
        self.reduce_batch_dims = reduce_batch_dims

    @property
    def in_keys(self) -> Sequence[NestedKey]:
        """Input keys; in stateless mode the carried stats keys are appended."""
        in_keys = self._in_keys
        if not self.stateful:
            # Stateless mode reads loc/var/count from the tensordict itself.
            in_keys = in_keys + [
                f"{self.prefix}_count",
                f"{self.prefix}_loc",
                f"{self.prefix}_var",
            ]
        return in_keys

    @in_keys.setter
    def in_keys(self, in_keys: Sequence[NestedKey]):
        self._in_keys = in_keys

    def set_container(self, container: Transform | EnvBase) -> None:
        """Attach the transform to a container and eagerly initialize / name the stats."""
        super().set_container(container)
        if self.stateful:
            parent = getattr(self, "parent", None)
            if parent is not None and isinstance(parent, EnvBase):
                if not parent.batch_locked:
                    warnings.warn(
                        f"Support of {type(self).__name__} for unbatched container is experimental and subject to change."
                    )
                if parent.batch_size:
                    warnings.warn(
                        f"Support of {type(self).__name__} for containers with non-empty batch-size is experimental and subject to change."
                    )
                # init: use a fake rollout step to shape the running statistics.
                data = parent.fake_tensordict().get("next")
                self._maybe_stateful_init(data)
        else:
            parent = getattr(self, "parent", None)
            if parent is not None and isinstance(parent, EnvBase):
                # Pick a collision-free prefix for the carried stats entries.
                self._make_prefix(parent.output_spec)
[docs] def freeze(self) -> VecNormV2: """Freezes the VecNorm, avoiding the stats to be updated when called. See :meth:`~.unfreeze`. """ self.frozen = True return self
[docs] def unfreeze(self) -> VecNormV2: """Unfreezes the VecNorm. See :meth:`~.freeze`. """ self.frozen = False return self
[docs] def frozen_copy(self): """Returns a copy of the Transform that keeps track of the stats but does not update them.""" if not self.stateful: raise RuntimeError("Cannot create a frozen copy of a statelss VecNorm.") if self._loc is None: raise RuntimeError( "Make sure the VecNorm has been initialized before creating a frozen copy." ) clone = self.clone() if self.stateful: # replace values clone._var = self._var.clone() clone._loc = self._loc.clone() clone._count = self._count.clone() # freeze return clone.freeze()
    def clone(self) -> VecNormV2:
        """Return a deep copy of the transform with independent statistics."""
        other = super().clone()
        if self.stateful:
            # `initialized` is a registered buffer; drop the shared one and register
            # an independent copy so the clone does not alias this instance's flag.
            delattr(other, "initialized")
            other.register_buffer("initialized", self.initialized.clone())
            if self._loc is not None:
                # Stats exist: mark the clone initialized and give it its own tensors.
                other.initialized.fill_(True)
                other._loc = self._loc.clone()
                other._var = self._var.clone()
                other._count = self._count.clone()
        return other
    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        """Normalize the reset data, tolerating keys (e.g. reward) absent at reset."""
        # TODO: remove this decorator when trackers are in data
        with _set_missing_tolerance(self, True):
            return self._step(tensordict_reset, tensordict_reset)
        # NOTE(review): unreachable — the `with` block above always returns.
        return tensordict_reset

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        """Update running stats from `next_tensordict` and write normalized values back.

        The whole update+normalize sequence is guarded by `self.lock` (when set) so
        that concurrent workers sharing the stats tensors do not interleave updates.
        """
        if self.lock is not None:
            self.lock.acquire()

        try:
            if self.stateful:
                # Stateful: stats live on the module; lazily init, update in place,
                # then normalize the selected entries.
                self._maybe_stateful_init(next_tensordict)
                next_tensordict_select = next_tensordict.select(
                    *self.in_keys, strict=not self.missing_tolerance
                )
                if self.missing_tolerance and next_tensordict_select.is_empty():
                    # Nothing to normalize at this step (e.g. reset without reward).
                    return next_tensordict
                self._stateful_update(next_tensordict_select)
                next_tensordict_norm = self._stateful_norm(next_tensordict_select)
            else:
                # Stateless: stats are carried in the tensordict under
                # f"{self.prefix}_loc"/"_var"/"_count" and travel with the data.
                self._maybe_stateless_init(tensordict)
                next_tensordict_select = next_tensordict.select(
                    *self._in_keys_safe, strict=not self.missing_tolerance
                )
                if self.missing_tolerance and next_tensordict_select.is_empty():
                    return next_tensordict

                loc = tensordict[f"{self.prefix}_loc"]
                var = tensordict[f"{self.prefix}_var"]
                count = tensordict[f"{self.prefix}_count"]

                loc, var, count = self._stateless_update(
                    next_tensordict_select, loc, var, count
                )
                next_tensordict_norm = self._stateless_norm(
                    next_tensordict_select, loc, var, count
                )
                # updates have been done in-place, we're good
                next_tensordict_norm.set(f"{self.prefix}_loc", loc)
                next_tensordict_norm.set(f"{self.prefix}_var", var)
                next_tensordict_norm.set(f"{self.prefix}_count", count)

            next_tensordict.update(next_tensordict_norm)
        finally:
            if self.lock is not None:
                self.lock.release()

        return next_tensordict

    def _maybe_cast_to_float(self, data):
        """Cast non-floating entries to the default float dtype when the spec
        transformation flagged integer inputs (see `_maybe_convert_bounded`)."""
        if self._cast_int_to_float:
            dtype = torch.get_default_dtype()
            data = data.apply(
                lambda x: x.to(dtype) if not x.dtype.is_floating_point else x
            )
        return data

    @staticmethod
    def _maybe_make_float(x):
        """Return `x` unchanged if floating point, else cast to the default float dtype."""
        if x.dtype.is_floating_point:
            return x
        return x.to(torch.get_default_dtype())

    def _maybe_stateful_init(self, data):
        """Lazily create zeroed `loc`/`var`/`count` tensordicts on first use."""
        if not self.initialized:
            self.initialized.copy_(True)
            # Some keys (specifically rewards) may be missing, but we can use the
            # specs for them
            try:
                data_select = data.select(*self._in_keys_safe, strict=True)
            except KeyError:
                data_select = self.parent.full_observation_spec.zero().update(
                    self.parent.full_reward_spec.zero()
                )
                data_select = data_select.update(data)
                data_select = data_select.select(*self._in_keys_safe, strict=True)
            if self.reduce_batch_dims and data_select.ndim:
                # collapse the batch-dims
                # NOTE(review): reduces over `data.ndim` dims while reducing
                # `data_select` — presumably both have the same batch ndim; confirm.
                data_select = data_select.mean(dim=tuple(range(data.ndim)))
            # For the count, we must use a TD because some keys (eg Reward) may be missing at some steps (eg, reset)
            # We use mean() to eliminate all dims - since it's local we don't need to expand the shape
            count = (
                torch.zeros_like(data_select, dtype=torch.float32)
                .mean()
                .to(torch.int64)
            )
            # create loc
            loc = torch.zeros_like(data_select.apply(self._maybe_make_float))
            # create var
            var = torch.zeros_like(data_select.apply(self._maybe_make_float))
            self._loc = loc
            self._var = var
            self._count = count

    @property
    def _in_keys_safe(self):
        """The user-facing input keys, excluding the three carried stats keys that
        the `in_keys` property appends in stateless mode."""
        if not self.stateful:
            return self.in_keys[:-3]
        return self.in_keys

    def _norm(self, data, loc, var, count):
        """Affine-normalize `data` given running first/second moments and counts.

        `loc` and `var` are exponentially-weighted sums of x and x**2 starting from
        zero, so `1 - decay**count` corrects the zero-initialization bias
        (Adam-style); `var - loc**2` is the E[x^2] - E[x]^2 variance estimate.
        """
        if self.missing_tolerance:
            # Only normalize the entries actually present in `data`.
            loc = loc.select(*data.keys(True, True))
            var = var.select(*data.keys(True, True))
            count = count.select(*data.keys(True, True))
            if loc.is_empty():
                return data

        if self.decay < 1.0:
            bias_correction = 1 - (count * math.log(self.decay)).exp()
            bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
        else:
            bias_correction = 1

        # NOTE(review): `var - loc**2` is computed before bias correction, so the
        # corrected variance is (E[x^2] - E[x]^2)/c rather than
        # E[x^2]/c - (E[x]/c)^2 — the two agree as count grows; confirm intended.
        var = var - loc.pow(2)
        loc = loc / bias_correction
        var = var / bias_correction

        scale = var.sqrt().clamp_min(self.eps)

        data_update = (data - loc) / scale
        if self.out_keys[: len(self.in_keys)] != self.in_keys:
            # map names
            for in_key, out_key in _zip_strict(self._in_keys_safe, self.out_keys):
                if in_key in data_update:
                    data_update.rename_key_(in_key, out_key)
        else:
            pass
        return data_update

    def _stateful_norm(self, data):
        """Normalize `data` with the module-held statistics."""
        return self._norm(data, self._loc, self._var, self._count)

    def _stateful_update(self, data):
        """Fold `data` into the module-held statistics in place (no-op when frozen)."""
        if self.frozen:
            return
        if self.missing_tolerance:
            var = self._var.select(*data.keys(True, True))
            loc = self._loc.select(*data.keys(True, True))
            count = self._count.select(*data.keys(True, True))
        else:
            var = self._var
            loc = self._loc
            count = self._count
        data = self._maybe_cast_to_float(data)
        if self.reduce_batch_dims and data.ndim:
            # The naive way to do this would be to convert the data to a list and iterate over it, but (1) that is
            # slow, and (2) it makes the value of the loc/var conditioned on the order we take to iterate over the data.
            # The second approach would be to average the data, but that would mean that having one vecnorm per batched
            # env or one per sub-env will lead to different results as a batch of N elements will actually be
            # considered as a single one.
            # What we go for instead is to average the data (and its squared value) then do the moving average with
            # adapted decay.
            n = data.numel()
            count += n
            data2 = data.pow(2).mean(dim=tuple(range(data.ndim)))
            data_mean = data.mean(dim=tuple(range(data.ndim)))
            if self.decay != 1.0:
                # decay**n: n single-sample updates collapsed into one.
                weight = 1 - self.decay**n
            else:
                weight = n / count
        else:
            count += 1
            data2 = data.pow(2)
            data_mean = data
            if self.decay != 1.0:
                weight = 1 - self.decay
            else:
                # decay == 1: running (unweighted) average.
                weight = 1 / count
        # In-place lerp mutates the shared stats tensors directly.
        loc.lerp_(end=data_mean, weight=weight)
        var.lerp_(end=data2, weight=weight)

    def _maybe_stateless_init(self, data):
        """Seed zeroed carried stats entries in `data` when absent."""
        if not self.initialized or f"{self.prefix}_loc" not in data.keys():
            self.initialized = True
            # select all except vecnorm
            # Some keys (specifically rewards) may be missing, but we can use the
            # specs for them
            try:
                data_select = data.select(*self._in_keys_safe, strict=True)
            except KeyError:
                data_select = self.parent.full_observation_spec.zero().update(
                    self.parent.full_reward_spec.zero()
                )
                data_select = data_select.update(data)
                data_select = data_select.select(*self._in_keys_safe, strict=True)

            data[f"{self.prefix}_count"] = torch.zeros_like(
                data_select, dtype=torch.int64
            )
            # create loc
            loc = torch.zeros_like(data_select.apply(self._maybe_make_float))
            # create var
            var = torch.zeros_like(data_select.apply(self._maybe_make_float))
            data[f"{self.prefix}_loc"] = loc
            data[f"{self.prefix}_var"] = var

    def _stateless_norm(self, data, loc, var, count):
        """Normalize `data` with externally-carried statistics."""
        data = self._norm(data, loc, var, count)
        return data

    def _stateless_update(self, data, loc, var, count):
        """Return updated (loc, var, count) from `data`; out-of-place, no-op when frozen."""
        if self.frozen:
            return loc, var, count
        count = count + 1
        data = self._maybe_cast_to_float(data)
        if self.decay != 1.0:
            weight = 1 - self.decay
        else:
            weight = 1 / count
        # Out-of-place lerp: the carried stats are rewritten into the tensordict by _step.
        loc = loc.lerp(end=data, weight=weight)
        var = var.lerp(end=data.pow(2), weight=weight)
        return loc, var, count
[docs] def transform_observation_spec(self, observation_spec: Composite) -> Composite: return self._transform_spec(observation_spec)
[docs] def transform_reward_spec( self, reward_spec: Composite, observation_spec ) -> Composite: return self._transform_spec(reward_spec, observation_spec)
    def transform_output_spec(self, output_spec: Composite) -> Composite:
        """Transform the full output spec (observation, reward and done specs).

        This is a copy-paste of the parent method to ensure that we correct the
        reward spec properly: the reward spec must be transformed with the already
        transformed observation spec in hand (stateless stats live there).
        """
        output_spec = output_spec.clone()
        observation_spec = self.transform_observation_spec(
            output_spec["full_observation_spec"]
        )
        if "full_reward_spec" in output_spec.keys():
            output_spec["full_reward_spec"] = self.transform_reward_spec(
                output_spec["full_reward_spec"], observation_spec
            )
        output_spec["full_observation_spec"] = observation_spec
        if "full_done_spec" in output_spec.keys():
            output_spec["full_done_spec"] = self.transform_done_spec(
                output_spec["full_done_spec"]
            )
        # Warn about out_keys that no spec accounts for: rollouts need every new
        # tensordict entry registered in the specs.
        output_spec_keys = [
            unravel_key(k[1:]) for k in output_spec.keys(True) if isinstance(k, tuple)
        ]
        out_keys = {unravel_key(k) for k in self.out_keys}
        in_keys = {unravel_key(k) for k in self.in_keys}
        for key in out_keys - in_keys:
            if unravel_key(key) not in output_spec_keys:
                warnings.warn(
                    f"The key '{key}' is unaccounted for by the transform (expected keys {output_spec_keys}). "
                    f"Every new entry in the tensordict resulting from a call to a transform must be "
                    f"registered in the specs for torchrl rollouts to be consistently built. "
                    f"Make sure transform_output_spec/transform_observation_spec/... is coded correctly. "
                    "This warning will trigger a KeyError in v0.9, make sure to adapt your code accordingly.",
                    category=FutureWarning,
                )
        return output_spec
    def _maybe_convert_bounded(self, in_spec):
        """Recursively turn bounded/integer specs into unbounded float specs.

        Normalized outputs are real-valued and unbounded; integer inputs flip
        `_cast_int_to_float` so the data itself is cast at step time.
        """
        if isinstance(in_spec, Composite):
            return Composite(
                {
                    key: self._maybe_convert_bounded(value)
                    for key, value in in_spec.items()
                }
            )
        dtype = in_spec.dtype
        if dtype is not None and not dtype.is_floating_point:
            # we need to cast the tensor and spec to a float type
            in_spec = in_spec.clone()
            in_spec.dtype = torch.get_default_dtype()
            self._cast_int_to_float = True
        if isinstance(in_spec, Bounded):
            in_spec = Unbounded(
                shape=in_spec.shape, device=in_spec.device, dtype=in_spec.dtype
            )
        return in_spec

    @property
    def prefix(self):
        """Key prefix for the carried stats entries (default "_vecnorm")."""
        prefix = getattr(self, "_prefix", "_vecnorm")
        return prefix

    def _make_prefix(self, output_spec):
        """Choose and cache a stats-key prefix, uniquified if "_vecnorm_loc" is taken
        (e.g. by another VecNorm instance on the same env)."""
        prefix = getattr(self, "_prefix", None)
        if prefix is not None:
            return prefix
        if (
            "_vecnorm_loc" in output_spec["full_observation_spec"].keys()
            or "_vecnorm_loc" in output_spec["full_reward_spec"].keys()
        ):
            prefix = "_vecnorm" + str(uuid.uuid1())
        else:
            prefix = "_vecnorm"
        self._prefix = prefix
        return prefix

    def _proc_count_spec(self, count_spec, parent_shape=None):
        """Recursively rebuild leaf specs of the count entry as int64 Unbounded specs."""
        if isinstance(count_spec, Composite):
            for key, spec in count_spec.items():
                spec = self._proc_count_spec(spec, parent_shape=count_spec.shape)
                count_spec[key] = spec
            return count_spec
        if count_spec.dtype:
            count_spec = Unbounded(
                shape=count_spec.shape, dtype=torch.int64, device=count_spec.device
            )
        return count_spec

    def _transform_spec(
        self, spec: Composite, obs_spec: Composite | None = None
    ) -> Composite:
        """Rewrite `spec` for normalized out_keys; in stateless mode also register the
        carried loc/var/count specs on `obs_spec` (mutated in place)."""
        in_specs = {}
        for in_key, out_key in zip(self._in_keys_safe, self.out_keys):
            if unravel_key(in_key) in spec.keys(True):
                in_spec = spec.get(in_key).clone()
                in_spec = self._maybe_convert_bounded(in_spec)
                spec.set(out_key, in_spec)
                in_specs[in_key] = in_spec
        if not self.stateful and in_specs:
            if obs_spec is None:
                obs_spec = spec
            loc_spec = obs_spec.get(f"{self.prefix}_loc", default=None)
            var_spec = obs_spec.get(f"{self.prefix}_var", default=None)
            count_spec = obs_spec.get(f"{self.prefix}_count", default=None)
            if loc_spec is None:
                loc_spec = Composite(shape=obs_spec.shape, device=obs_spec.device)
                var_spec = Composite(shape=obs_spec.shape, device=obs_spec.device)
                count_spec = Composite(shape=obs_spec.shape, device=obs_spec.device)
            loc_spec.update(in_specs)
            # should we clone?
            var_spec.update(in_specs)
            count_spec = count_spec.update(in_specs)
            count_spec = self._proc_count_spec(count_spec)
            obs_spec[f"{self.prefix}_loc"] = loc_spec
            obs_spec[f"{self.prefix}_var"] = var_spec
            obs_spec[f"{self.prefix}_count"] = count_spec
        return spec
    def to_observation_norm(self) -> Compose | ObservationNorm:
        """Convert this transform into one or more frozen :class:`ObservationNorm`.

        Returns a single ``ObservationNorm`` for one in_key, a ``Compose`` of them
        otherwise.

        NOTE(review): if ``in_keys`` is empty, ``local_result`` is never bound and
        the final ``return`` raises ``NameError`` — confirm whether empty in_keys
        should be rejected earlier.
        """
        if not self.stateful:
            # FIXME: stateless conversion is not implemented.
            raise NotImplementedError()
        result = []
        loc, scale = self._get_loc_scale()
        for key, key_out in _zip_strict(self.in_keys, self.out_keys):
            local_result = ObservationNorm(
                loc=loc.get(key),
                scale=scale.get(key),
                standard_normal=True,
                in_keys=key,
                out_keys=key_out,
                eps=self.eps,
            )
            result += [local_result]
        if len(self.in_keys) > 1:
            return Compose(*result)
        return local_result
    def _get_loc_scale(self, loc_only: bool = False) -> tuple:
        """Return bias-corrected (loc, scale) tensordicts from the running stats.

        Args:
            loc_only: when True, skip the variance/scale computation and return
                ``(loc, None)``.

        Raises:
            RuntimeError: if called on a stateless instance (no stats held locally).
        """
        if self.stateful:
            loc = self._loc
            count = self._count
            if self.decay != 1.0:
                # 1 - decay**count: corrects the zero-initialization bias of the
                # exponentially weighted sums (same math as in _norm).
                bias_correction = 1 - (count * math.log(self.decay)).exp()
                bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), loc)
            else:
                bias_correction = 1
            if loc_only:
                return loc / bias_correction, None
            var = self._var
            # E[x^2] - E[x]^2, then bias-correct both moments.
            var = var - loc.pow(2)
            loc = loc / bias_correction
            var = var / bias_correction
            scale = var.sqrt().clamp_min(self.eps)
            return loc, scale
        else:
            raise RuntimeError("_get_loc_scale() called on stateless vecnorm.")

    def __getstate__(self) -> dict[str, Any]:
        # mp.Lock objects cannot be pickled: drop the lock and leave a marker so
        # __setstate__ knows to recreate one.
        state = super().__getstate__()
        _lock = state.pop("lock", None)
        if _lock is not None:
            state["lock_placeholder"] = None
        return state

    def __setstate__(self, state: dict[str, Any]):
        # Recreate a fresh lock in the unpickling process when one existed before.
        if "lock_placeholder" in state:
            state.pop("lock_placeholder")
            _lock = mp.Lock()
            state["lock"] = _lock
        super().__setstate__(state)

    # Separator used to flatten nested tensordict keys into state_dict keys.
    SEP = ".-|-."
    def set_extra_state(self, state: OrderedDict) -> None:
        """Restore the running statistics from a state-dict payload.

        No-op in stateless mode. An empty ``state`` is only valid on an
        uninitialized instance. The restored tensordict is moved to shared memory
        so the stats can be shared across processes.
        """
        if not self.stateful:
            return
        if not state:
            if self._loc is None:
                # we're good, not init yet
                return
            raise RuntimeError(
                "set_extra_state() called with a void state-dict while the instance is initialized."
            )
        td = TensorDict(state).unflatten_keys(self.SEP)
        if self._loc is None and not all(v.is_shared() for v in td.values(True, True)):
            warnings.warn(
                "VecNorm wasn't initialized and the tensordict is not shared. In single "
                "process settings, this is ok, but if you need to share the statistics "
                "between workers this should require some attention. "
                "Make sure that the content of VecNorm is transmitted to the workers "
                "after calling load_state_dict and not before, as other workers "
                "may not have access to the loaded TensorDict."
            )
            td.share_memory_()
        self._loc = td["loc"]
        self._var = td["var"]
        self._count = td["count"]
    def get_extra_state(self) -> OrderedDict:
        """Return the running statistics as a flat dict for ``state_dict``.

        Returns an empty dict in stateless mode or before initialization (with a
        warning in the latter case, since loading that back on an initialized
        instance will fail).
        """
        if not self.stateful:
            return {}
        if self._loc is None:
            warnings.warn(
                "Querying state_dict on an uninitialized VecNorm transform will "
                "return a `None` value for the summary statistics. "
                "Loading such a state_dict on an initialized VecNorm will result in "
                "an error."
            )
            return {}
        td = TensorDict(
            loc=self._loc,
            var=self._var,
            count=self._count,
        )
        # Flatten nested keys with SEP so the payload is a plain mapping.
        return td.flatten_keys(self.SEP).to_dict()
@property def loc(self): """Returns a TensorDict with the loc to be used for an affine transform.""" if not self.stateful: raise RuntimeError("loc cannot be computed with stateless vecnorm.") # We can't cache that value bc the summary stats could be updated by a different process loc, _ = self._get_loc_scale(loc_only=True) return loc @property def scale(self): """Returns a TensorDict with the scale to be used for an affine transform.""" if not self.stateful: raise RuntimeError("scale cannot be computed with stateless vecnorm.") # We can't cache that value bc the summary stats could be updated by a different process _, scale = self._get_loc_scale() return scale @property def standard_normal(self): """Whether the affine transform given by `loc` and `scale` follows the standard normal equation. Similar to :class:`~torchrl.envs.ObservationNorm` standard_normal attribute. Always returns ``True``. """ return True

Docs: access comprehensive developer documentation for PyTorch — View Docs.

Tutorials: in-depth tutorials for beginners and advanced developers — View Tutorials.

Resources: find development resources and get your questions answered — View Resources.