Shortcuts

Source code for torchrl.trainers.algorithms.configs.transforms

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from torchrl.trainers.algorithms.configs.common import ConfigBase


@dataclass
class TransformConfig(ConfigBase):
    """Shared parent of every transform config dataclass defined in this module."""

    def __post_init__(self) -> None:
        """Empty hook run after the dataclass ``__init__``; subclasses chain to it."""


@dataclass
class NoopResetEnvConfig(TransformConfig):
    """Structured settings for the ``NoopResetEnv`` transform (see ``_target_``)."""

    noops: int = 30
    random: bool = True
    _target_: str = "torchrl.envs.transforms.transforms.NoopResetEnv"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class StepCounterConfig(TransformConfig):
    """Structured settings for the ``StepCounter`` transform (see ``_target_``)."""

    max_steps: int | None = None
    truncated_key: str | None = "truncated"
    step_count_key: str | None = "step_count"
    update_done: bool = True
    _target_: str = "torchrl.envs.transforms.transforms.StepCounter"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class ComposeConfig(TransformConfig):
    """Structured settings for the ``Compose`` transform (see ``_target_``).

    ``transforms`` defaults to ``None`` at the class level and is normalised
    to an empty list once an instance is initialised.
    """

    transforms: list[Any] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Compose"

    def __post_init__(self) -> None:
        """Chain to the parent hook, then turn a missing ``transforms`` into ``[]``."""
        super().__post_init__()
        self.transforms = [] if self.transforms is None else self.transforms


@dataclass
class DoubleToFloatConfig(TransformConfig):
    """Structured settings for the ``DoubleToFloat`` transform (see ``_target_``)."""

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    in_keys_inv: list[str] | None = None
    out_keys_inv: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.DoubleToFloat"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class ToTensorImageConfig(TransformConfig):
    """Structured settings for the ``ToTensorImage`` transform (see ``_target_``)."""

    from_int: bool | None = None
    unsqueeze: bool = False
    dtype: str | None = None
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    shape_tolerant: bool = False
    _target_: str = "torchrl.envs.transforms.transforms.ToTensorImage"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class ClipTransformConfig(TransformConfig):
    """Structured settings for the ``ClipTransform`` transform (see ``_target_``)."""

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    in_keys_inv: list[str] | None = None
    out_keys_inv: list[str] | None = None
    low: float | None = None
    high: float | None = None
    _target_: str = "torchrl.envs.transforms.transforms.ClipTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class ResizeConfig(TransformConfig):
    """Structured settings for the ``Resize`` transform (see ``_target_``)."""

    w: int = 84
    h: int = 84
    interpolation: str = "bilinear"
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Resize"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class CenterCropConfig(TransformConfig):
    """Structured settings for the ``CenterCrop`` transform (see ``_target_``)."""

    height: int = 84
    width: int = 84
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.CenterCrop"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class FlattenObservationConfig(TransformConfig):
    """Structured settings for the ``FlattenObservation`` transform (see ``_target_``)."""

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.FlattenObservation"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class GrayScaleConfig(TransformConfig):
    """Structured settings for the ``GrayScale`` transform (see ``_target_``)."""

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.GrayScale"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class ObservationNormConfig(TransformConfig):
    """Structured settings for the ``ObservationNorm`` transform (see ``_target_``)."""

    loc: float = 0.0
    scale: float = 1.0
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    standard_normal: bool = False
    eps: float = 1e-8
    _target_: str = "torchrl.envs.transforms.transforms.ObservationNorm"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class CatFramesConfig(TransformConfig):
    """Structured settings for the ``CatFrames`` transform (see ``_target_``)."""

    N: int = 4
    dim: int = -3
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.CatFrames"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class RewardClippingConfig(TransformConfig):
    """Structured settings for the ``RewardClipping`` transform (see ``_target_``)."""

    clamp_min: float | None = None
    clamp_max: float | None = None
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.RewardClipping"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class RewardScalingConfig(TransformConfig):
    """Structured settings for the ``RewardScaling`` transform (see ``_target_``)."""

    loc: float = 0.0
    scale: float = 1.0
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    standard_normal: bool = False
    eps: float = 1e-8
    _target_: str = "torchrl.envs.transforms.transforms.RewardScaling"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class VecNormConfig(TransformConfig):
    """Structured settings for the ``VecNorm`` transform (see ``_target_``)."""

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    decay: float = 0.99
    eps: float = 1e-8
    _target_: str = "torchrl.envs.transforms.transforms.VecNorm"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class FrameSkipTransformConfig(TransformConfig):
    """Structured settings for the ``FrameSkipTransform`` transform (see ``_target_``)."""

    frame_skip: int = 4
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.FrameSkipTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class EndOfLifeTransformConfig(TransformConfig):
    """Structured settings for the ``EndOfLifeTransform`` transform.

    Unlike most configs here, ``_target_`` points into the
    ``gym_transforms`` module rather than ``transforms``.
    """

    eol_key: str = "end-of-life"
    lives_key: str = "lives"
    done_key: str = "done"
    eol_attribute: str = "unwrapped.ale.lives"
    _target_: str = "torchrl.envs.transforms.gym_transforms.EndOfLifeTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class MultiStepTransformConfig(TransformConfig):
    """Structured settings for the ``MultiStepTransform`` transform.

    ``_target_`` points into the ``rb_transforms`` module.
    """

    n_steps: int = 3
    gamma: float = 0.99
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.rb_transforms.MultiStepTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class TargetReturnConfig(TransformConfig):
    """Structured settings for the ``TargetReturn`` transform (see ``_target_``)."""

    target_return: float = 10.0
    mode: str = "reduce"
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    reset_key: str | None = None
    _target_: str = "torchrl.envs.transforms.transforms.TargetReturn"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class BinarizeRewardConfig(TransformConfig):
    """Structured settings for the ``BinarizeReward`` transform (see ``_target_``)."""

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.BinarizeReward"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class ActionDiscretizerConfig(TransformConfig):
    """Structured settings for the ``ActionDiscretizer`` transform (see ``_target_``)."""

    num_intervals: int = 10
    action_key: str = "action"
    out_action_key: str | None = None
    sampling: str | None = None
    categorical: bool = True
    _target_: str = "torchrl.envs.transforms.transforms.ActionDiscretizer"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class AutoResetTransformConfig(TransformConfig):
    """Structured settings for the ``AutoResetTransform`` transform (see ``_target_``).

    Note that ``fill_float`` is stored as the string ``"nan"``, not a float.
    """

    replace: bool | None = None
    fill_float: str = "nan"
    fill_int: int = -1
    fill_bool: bool = False
    _target_: str = "torchrl.envs.transforms.transforms.AutoResetTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class BatchSizeTransformConfig(TransformConfig):
    """Structured settings for the ``BatchSizeTransform`` transform (see ``_target_``).

    ``reshape_fn`` and ``reset_func`` are typed ``Any`` and default to ``None``.
    """

    batch_size: list[int] | None = None
    reshape_fn: Any = None
    reset_func: Any = None
    env_kwarg: bool = False
    _target_: str = "torchrl.envs.transforms.transforms.BatchSizeTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class DeviceCastTransformConfig(TransformConfig):
    """Structured settings for the ``DeviceCastTransform`` transform (see ``_target_``)."""

    device: str = "cpu"
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    in_keys_inv: list[str] | None = None
    out_keys_inv: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.DeviceCastTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class DTypeCastTransformConfig(TransformConfig):
    """Structured settings for the ``DTypeCastTransform`` transform (see ``_target_``).

    ``dtype`` is kept as a string (default ``"torch.float32"``).
    """

    dtype: str = "torch.float32"
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    in_keys_inv: list[str] | None = None
    out_keys_inv: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.DTypeCastTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class UnsqueezeTransformConfig(TransformConfig):
    """Structured settings for the ``UnsqueezeTransform`` transform (see ``_target_``)."""

    dim: int = 0
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.UnsqueezeTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class SqueezeTransformConfig(TransformConfig):
    """Structured settings for the ``SqueezeTransform`` transform (see ``_target_``)."""

    dim: int = 0
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.SqueezeTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class PermuteTransformConfig(TransformConfig):
    """Structured settings for the ``PermuteTransform`` transform (see ``_target_``).

    When ``dims`` is left as ``None`` it becomes ``[0, 2, 1]`` after init.
    """

    dims: list[int] | None = None
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.PermuteTransform"

    def __post_init__(self) -> None:
        """Chain to the parent hook, then fill in the default permutation if absent."""
        super().__post_init__()
        self.dims = [0, 2, 1] if self.dims is None else self.dims


@dataclass
class CatTensorsConfig(TransformConfig):
    """Structured settings for the ``CatTensors`` transform (see ``_target_``)."""

    dim: int = -1
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.CatTensors"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class StackConfig(TransformConfig):
    """Structured settings for the ``Stack`` transform (see ``_target_``)."""

    dim: int = 0
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Stack"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class DiscreteActionProjectionConfig(TransformConfig):
    """Structured settings for the ``DiscreteActionProjection`` transform (see ``_target_``)."""

    num_actions: int = 4
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.DiscreteActionProjection"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class TensorDictPrimerConfig(TransformConfig):
    """Structured settings for the ``TensorDictPrimer`` transform (see ``_target_``).

    ``primer_spec`` is typed ``Any`` and defaults to ``None``.
    """

    primer_spec: Any = None
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.TensorDictPrimer"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class PinMemoryTransformConfig(TransformConfig):
    """Structured settings for the ``PinMemoryTransform`` transform (see ``_target_``)."""

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.PinMemoryTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class RewardSumConfig(TransformConfig):
    """Structured settings for the ``RewardSum`` transform (see ``_target_``)."""

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.RewardSum"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class ExcludeTransformConfig(TransformConfig):
    """Structured settings for the ``ExcludeTransform`` transform (see ``_target_``).

    A ``None`` value for ``exclude_keys`` is replaced by an empty list after init.
    """

    exclude_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.ExcludeTransform"

    def __post_init__(self) -> None:
        """Chain to the parent hook, then default ``exclude_keys`` to ``[]``."""
        super().__post_init__()
        self.exclude_keys = [] if self.exclude_keys is None else self.exclude_keys


@dataclass
class SelectTransformConfig(TransformConfig):
    """Structured settings for the ``SelectTransform`` transform (see ``_target_``).

    A ``None`` value for ``include_keys`` is replaced by an empty list after init.
    """

    include_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.SelectTransform"

    def __post_init__(self) -> None:
        """Chain to the parent hook, then default ``include_keys`` to ``[]``."""
        super().__post_init__()
        self.include_keys = [] if self.include_keys is None else self.include_keys


@dataclass
class TimeMaxPoolConfig(TransformConfig):
    """Structured settings for the ``TimeMaxPool`` transform (see ``_target_``)."""

    dim: int = -1
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.TimeMaxPool"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class RandomCropTensorDictConfig(TransformConfig):
    """Structured settings for the ``RandomCropTensorDict`` transform (see ``_target_``).

    When ``crop_size`` is ``None`` it becomes ``[84, 84]`` after init.
    """

    crop_size: list[int] | None = None
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.RandomCropTensorDict"

    def __post_init__(self) -> None:
        """Chain to the parent hook, then supply the default crop size if absent."""
        super().__post_init__()
        self.crop_size = [84, 84] if self.crop_size is None else self.crop_size


@dataclass
class InitTrackerConfig(TransformConfig):
    """Structured settings for the ``InitTracker`` transform (see ``_target_``)."""

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.InitTracker"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class RenameTransformConfig(TransformConfig):
    """Structured settings for the ``RenameTransform`` transform (see ``_target_``).

    A ``None`` ``key_mapping`` is replaced by an empty dict after init.
    """

    key_mapping: dict[str, str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.RenameTransform"

    def __post_init__(self) -> None:
        """Chain to the parent hook, then default ``key_mapping`` to ``{}``."""
        super().__post_init__()
        self.key_mapping = {} if self.key_mapping is None else self.key_mapping


@dataclass
class Reward2GoTransformConfig(TransformConfig):
    """Structured settings for the ``Reward2GoTransform`` transform (see ``_target_``)."""

    gamma: float = 0.99
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Reward2GoTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class ActionMaskConfig(TransformConfig):
    """Structured settings for the ``ActionMask`` transform (see ``_target_``)."""

    mask_key: str = "action_mask"
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.ActionMask"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class VecGymEnvTransformConfig(TransformConfig):
    """Structured settings for the ``VecGymEnvTransform`` transform (see ``_target_``)."""

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.VecGymEnvTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class BurnInTransformConfig(TransformConfig):
    """Structured settings for the ``BurnInTransform`` transform (see ``_target_``)."""

    burn_in: int = 10
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.BurnInTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class SignTransformConfig(TransformConfig):
    """Structured settings for the ``SignTransform`` transform (see ``_target_``)."""

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.SignTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class RemoveEmptySpecsConfig(TransformConfig):
    """Structured settings for the ``RemoveEmptySpecs`` transform; only ``_target_`` is set."""

    _target_: str = "torchrl.envs.transforms.transforms.RemoveEmptySpecs"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class TrajCounterConfig(TransformConfig):
    """Structured settings for the ``TrajCounter`` transform (see ``_target_``)."""

    out_key: str = "traj_count"
    repeats: int | None = None
    _target_: str = "torchrl.envs.transforms.transforms.TrajCounter"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class LineariseRewardsConfig(TransformConfig):
    """Structured settings for the ``LineariseRewards`` transform (see ``_target_``).

    Only ``in_keys`` is normalised after init (``None`` becomes ``[]``);
    ``out_keys`` and ``weights`` keep whatever value they were given.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    weights: list[float] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.LineariseRewards"

    def __post_init__(self) -> None:
        """Chain to the parent hook, then default ``in_keys`` to ``[]``."""
        super().__post_init__()
        self.in_keys = [] if self.in_keys is None else self.in_keys


@dataclass
class ConditionalSkipConfig(TransformConfig):
    """Structured settings for the ``ConditionalSkip`` transform (see ``_target_``).

    ``cond`` is typed ``Any`` and defaults to ``None``.
    """

    cond: Any = None
    _target_: str = "torchrl.envs.transforms.transforms.ConditionalSkip"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class MultiActionConfig(TransformConfig):
    """Structured settings for the ``MultiAction`` transform (see ``_target_``)."""

    dim: int = 1
    stack_rewards: bool = True
    stack_observations: bool = False
    _target_: str = "torchrl.envs.transforms.transforms.MultiAction"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class TimerConfig(TransformConfig):
    """Structured settings for the ``Timer`` transform (see ``_target_``)."""

    out_keys: list[str] | None = None
    time_key: str = "time"
    _target_: str = "torchrl.envs.transforms.transforms.Timer"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class ConditionalPolicySwitchConfig(TransformConfig):
    """Structured settings for the ``ConditionalPolicySwitch`` transform (see ``_target_``).

    Both ``policy`` and ``condition`` are typed ``Any`` and default to ``None``.
    """

    policy: Any = None
    condition: Any = None
    _target_: str = "torchrl.envs.transforms.transforms.ConditionalPolicySwitch"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class KLRewardTransformConfig(TransformConfig):
    """Structured settings for the ``KLRewardTransform`` transform.

    ``_target_`` points into the ``llm`` transforms module.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.llm.KLRewardTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class R3MTransformConfig(TransformConfig):
    """Structured settings for the ``R3MTransform`` transform.

    ``_target_`` points into the ``r3m`` transforms module.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    model_name: str = "resnet18"
    device: str = "cpu"
    _target_: str = "torchrl.envs.transforms.r3m.R3MTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class VC1TransformConfig(TransformConfig):
    """Structured settings for the ``VC1Transform`` transform.

    ``_target_`` points into the ``vc1`` transforms module.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    device: str = "cpu"
    _target_: str = "torchrl.envs.transforms.vc1.VC1Transform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class VIPTransformConfig(TransformConfig):
    """Structured settings for the ``VIPTransform`` transform.

    ``_target_`` points into the ``vip`` transforms module.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    device: str = "cpu"
    _target_: str = "torchrl.envs.transforms.vip.VIPTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class VIPRewardTransformConfig(TransformConfig):
    """Structured settings for the ``VIPRewardTransform`` transform.

    ``_target_`` points into the ``vip`` transforms module.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    device: str = "cpu"
    _target_: str = "torchrl.envs.transforms.vip.VIPRewardTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class VecNormV2Config(TransformConfig):
    """Structured settings for the ``VecNormV2`` transform.

    ``_target_`` points into the ``vecnorm`` transforms module.
    """

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    decay: float = 0.99
    eps: float = 1e-8
    _target_: str = "torchrl.envs.transforms.vecnorm.VecNormV2"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class FiniteTensorDictCheckConfig(TransformConfig):
    """Structured settings for the ``FiniteTensorDictCheck`` transform (see ``_target_``)."""

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.FiniteTensorDictCheck"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class UnaryTransformConfig(TransformConfig):
    """Structured settings for the ``UnaryTransform`` transform (see ``_target_``).

    ``fn`` is typed ``Any`` and defaults to ``None``.
    """

    fn: Any = None
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.UnaryTransform"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class HashConfig(TransformConfig):
    """Structured settings for the ``Hash`` transform (see ``_target_``)."""

    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Hash"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class TokenizerConfig(TransformConfig):
    """Structured settings for the ``Tokenizer`` transform (see ``_target_``)."""

    vocab_size: int = 1000
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Tokenizer"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class CropConfig(TransformConfig):
    """Structured settings for the ``Crop`` transform (see ``_target_``)."""

    top: int = 0
    left: int = 0
    height: int = 84
    width: int = 84
    in_keys: list[str] | None = None
    out_keys: list[str] | None = None
    _target_: str = "torchrl.envs.transforms.transforms.Crop"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()


@dataclass
class FlattenTensorDictConfig(TransformConfig):
    """Structured settings for the ``FlattenTensorDict`` transform (see ``_target_``).

    The transform gives the tensordict a flat batch dimension during the
    inverse pass. This helps replay buffers that must store data with a
    flat batch structure. Apart from ``_target_``, no options are exposed.
    """

    _target_: str = "torchrl.envs.transforms.transforms.FlattenTensorDict"

    def __post_init__(self) -> None:
        """Forward to the inherited hook; no extra preparation is needed."""
        super().__post_init__()

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources