register_dataset
- torchrl.data.datasets.register_dataset(prefix: str, dataset: str | Callable[[...], BaseDatasetExperienceReplay], *, replace: bool = False) → None
Register a dataset factory under a source prefix for use with load_dataset().

Once registered, the prefix can be used in strings of the form "<prefix>:<dataset-id>". The dataset factory is then called as dataset(dataset_id, **kwargs).

- Parameters:
  - prefix (str) – source prefix used before the ":" separator.
  - dataset (Callable or str) – dataset factory, or an import string of the form "module:attribute" that is resolved lazily when the prefix is used.
  - replace (bool, optional) – if True, replace an existing registration. Defaults to False.
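When dataset is passed as a "module:attribute" string, nothing is imported at registration time; the string is resolved only when the prefix is used. The sketch below illustrates what that kind of lazy resolution can look like with importlib.import_module and getattr. It is a minimal illustration under assumptions, not TorchRL's implementation: the helper name resolve_import_string is hypothetical, and the standard-library fractions:Fraction target is chosen only because it accepts a single string argument.

```python
import importlib
from typing import Any


def resolve_import_string(spec: str) -> Any:
    """Hypothetical helper: resolve a "module:attribute" string to an object."""
    module_name, sep, attr_name = spec.partition(":")
    if not sep or not module_name or not attr_name:
        raise ValueError(f"expected 'module:attribute', got {spec!r}")
    # The module is imported only here, when the string is resolved,
    # not when the string was stored.
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# A registry could keep the plain string and resolve it on first use.
factory = resolve_import_string("fractions:Fraction")
print(factory("3/4"))  # → 3/4
```

Deferring the import this way keeps registration cheap and avoids importing optional dependencies until a dataset from that source is actually requested.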
Examples
>>> from torchrl.data.datasets import register_dataset, load_dataset
>>> class ToyDataset:
...     def __init__(self, dataset_id, **kwargs):
...         self.dataset_id = dataset_id
>>> register_dataset("toy", ToyDataset, replace=True)
>>> load_dataset("toy:example").dataset_id
'example'