Shortcuts

Source code for torchrl.trainers.algorithms.configs.data

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

from omegaconf import MISSING

from torchrl.trainers.algorithms.configs.common import ConfigBase


@dataclass
class WriterConfig(ConfigBase):
    """Base configuration class for replay buffer writers."""

    _target_: str = "torchrl.data.replay_buffers.Writer"

    def __post_init__(self) -> None:
        """Post-initialization hook for writer configurations."""


[docs]@dataclass class RoundRobinWriterConfig(WriterConfig): """Configuration for round-robin writer that distributes data across multiple storages.""" _target_: str = "torchrl.data.replay_buffers.RoundRobinWriter" compilable: bool = False def __post_init__(self) -> None: """Post-initialization hook for round-robin writer configurations.""" super().__post_init__()
@dataclass class SamplerConfig(ConfigBase): """Base configuration class for replay buffer samplers.""" _target_: str = "torchrl.data.replay_buffers.Sampler" def __post_init__(self) -> None: """Post-initialization hook for sampler configurations."""
[docs]@dataclass class RandomSamplerConfig(SamplerConfig): """Configuration for random sampling from replay buffer.""" _target_: str = "torchrl.data.replay_buffers.RandomSampler" def __post_init__(self) -> None: """Post-initialization hook for random sampler configurations.""" super().__post_init__()
@dataclass class WriterEnsembleConfig(WriterConfig): """Configuration for ensemble writer that combines multiple writers.""" _target_: str = "torchrl.data.replay_buffers.WriterEnsemble" writers: list[Any] = field(default_factory=list) p: Any = None @dataclass class TensorDictMaxValueWriterConfig(WriterConfig): """Configuration for TensorDict max value writer.""" _target_: str = "torchrl.data.replay_buffers.TensorDictMaxValueWriter" rank_key: Any = None reduction: str = "sum" @dataclass class TensorDictRoundRobinWriterConfig(WriterConfig): """Configuration for TensorDict round-robin writer.""" _target_: str = "torchrl.data.replay_buffers.TensorDictRoundRobinWriter" compilable: bool = False @dataclass class ImmutableDatasetWriterConfig(WriterConfig): """Configuration for immutable dataset writer.""" _target_: str = "torchrl.data.replay_buffers.ImmutableDatasetWriter" @dataclass class SamplerEnsembleConfig(SamplerConfig): """Configuration for ensemble sampler that combines multiple samplers.""" _target_: str = "torchrl.data.replay_buffers.SamplerEnsemble" samplers: list[Any] = field(default_factory=list) p: Any = None @dataclass class PrioritizedSliceSamplerConfig(SamplerConfig): """Configuration for prioritized slice sampling from replay buffer.""" num_slices: int | None = None slice_len: int | None = None end_key: Any = None traj_key: Any = None ends: Any = None trajectories: Any = None cache_values: bool = False truncated_key: Any = ("next", "truncated") strict_length: bool = True compile: Any = False span: Any = False use_gpu: Any = False max_capacity: int | None = None alpha: float | None = None beta: float | None = None eps: float | None = None reduction: str | None = None _target_: str = "torchrl.data.replay_buffers.PrioritizedSliceSampler"
[docs]@dataclass class SliceSamplerWithoutReplacementConfig(SamplerConfig): """Configuration for slice sampling without replacement.""" _target_: str = "torchrl.data.replay_buffers.SliceSamplerWithoutReplacement" num_slices: int | None = None slice_len: int | None = None end_key: Any = None traj_key: Any = None ends: Any = None trajectories: Any = None cache_values: bool = False truncated_key: Any = ("next", "truncated") strict_length: bool = True compile: Any = False span: Any = False use_gpu: Any = False
[docs]@dataclass class SliceSamplerConfig(SamplerConfig): """Configuration for slice sampling from replay buffer.""" _target_: str = "torchrl.data.replay_buffers.SliceSampler" num_slices: int | None = None slice_len: int | None = None end_key: Any = None traj_key: Any = None ends: Any = None trajectories: Any = None cache_values: bool = False truncated_key: Any = ("next", "truncated") strict_length: bool = True compile: Any = False span: Any = False use_gpu: Any = False
[docs]@dataclass class PrioritizedSamplerConfig(SamplerConfig): """Configuration for prioritized sampling from replay buffer.""" max_capacity: int | None = None alpha: float | None = None beta: float | None = None eps: float | None = None reduction: str | None = None _target_: str = "torchrl.data.replay_buffers.PrioritizedSampler"
[docs]@dataclass class SamplerWithoutReplacementConfig(SamplerConfig): """Configuration for sampling without replacement.""" _target_: str = "torchrl.data.replay_buffers.SamplerWithoutReplacement" drop_last: bool = False shuffle: bool = True
@dataclass class StorageConfig(ConfigBase): """Base configuration class for replay buffer storage.""" _partial_: bool = False _target_: str = "torchrl.data.replay_buffers.Storage" def __post_init__(self) -> None: """Post-initialization hook for storage configurations."""
[docs]@dataclass class TensorStorageConfig(StorageConfig): """Configuration for tensor-based storage in replay buffer.""" _target_: str = "torchrl.data.replay_buffers.TensorStorage" max_size: int | None = None storage: Any = None device: Any = None ndim: int | None = None compilable: bool = False def __post_init__(self) -> None: """Post-initialization hook for tensor storage configurations.""" super().__post_init__()
[docs]@dataclass class ListStorageConfig(StorageConfig): """Configuration for list-based storage in replay buffer.""" _target_: str = "torchrl.data.replay_buffers.ListStorage" max_size: int | None = None compilable: bool = False
[docs]@dataclass class StorageEnsembleWriterConfig(StorageConfig): """Configuration for storage ensemble writer.""" _target_: str = "torchrl.data.replay_buffers.StorageEnsembleWriter" writers: list[Any] = MISSING transforms: list[Any] = MISSING
[docs]@dataclass class LazyStackStorageConfig(StorageConfig): """Configuration for lazy stack storage.""" _target_: str = "torchrl.data.replay_buffers.LazyStackStorage" max_size: int | None = None compilable: bool = False stack_dim: int = 0
[docs]@dataclass class StorageEnsembleConfig(StorageConfig): """Configuration for storage ensemble.""" _target_: str = "torchrl.data.replay_buffers.StorageEnsemble" storages: list[Any] = MISSING transforms: list[Any] = MISSING
[docs]@dataclass class LazyMemmapStorageConfig(StorageConfig): """Configuration for lazy memory-mapped storage.""" _target_: str = "torchrl.data.replay_buffers.LazyMemmapStorage" max_size: int | None = None device: Any = None ndim: int = 1 compilable: bool = False
[docs]@dataclass class LazyTensorStorageConfig(StorageConfig): """Configuration for lazy tensor storage.""" _target_: str = "torchrl.data.replay_buffers.LazyTensorStorage" max_size: int | None = None device: Any = None ndim: int = 1 compilable: bool = False
@dataclass class ReplayBufferBaseConfig(ConfigBase): """Base configuration class for replay buffers.""" _partial_: bool = False def __post_init__(self) -> None: """Post-initialization hook for replay buffer configurations."""
[docs]@dataclass class TensorDictReplayBufferConfig(ReplayBufferBaseConfig): """Configuration for TensorDict-based replay buffer.""" _target_: str = "torchrl.data.replay_buffers.TensorDictReplayBuffer" sampler: Any = None storage: Any = None writer: Any = None transform: Any = None batch_size: int | None = None def __post_init__(self) -> None: """Post-initialization hook for TensorDict replay buffer configurations.""" super().__post_init__()
[docs]@dataclass class ReplayBufferConfig(ReplayBufferBaseConfig): """Configuration for generic replay buffer.""" _target_: str = "torchrl.data.replay_buffers.ReplayBuffer" sampler: Any = None storage: Any = None writer: Any = None transform: Any = None batch_size: int | None = None

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources