SelectKeys

class torchrl.trainers.SelectKeys(keys: Sequence[str])

Selects keys in a TensorDict batch.

Parameters:

keys (iterable of strings) – keys to be selected in the tensordict.
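
The effect of the selection can be sketched without torchrl, using a plain dict as a stand-in for a TensorDict batch. The helper select_keys below is hypothetical and only illustrates the idea. It is not part of the library, and its handling of missing keys is an assumption of this sketch.

```python
# Plain-dict stand-in for a TensorDict batch; `select_keys` is a
# hypothetical helper for illustration, not a torchrl function.
from typing import Any, Dict, Sequence


def select_keys(batch: Dict[str, Any], keys: Sequence[str]) -> Dict[str, Any]:
    """Return a new mapping that holds only the requested keys."""
    # A missing key raises KeyError here; that strictness is an
    # assumption of this sketch, not a documented guarantee.
    return {key: batch[key] for key in keys}


batch = {"first key": [0.1, 0.2, 0.3], "second key": [1.0, 2.0, 3.0]}
selected = select_keys(batch, ["first key"])
```

Only the listed entries survive in selected, and the original batch is left untouched.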

Examples

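>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.trainers import SelectKeys
>>> trainer = make_trainer()  # placeholder: build a Trainer for your setup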
>>> key1 = "first key"
>>> key2 = "second key"
>>> td = TensorDict(
...     {
...         key1: torch.randn(3),
...         key2: torch.randn(3),
...     },
...     [],
... )
>>> trainer.register_op("batch_process", SelectKeys([key1]))
>>> td_out = trainer._process_batch_hook(td)
>>> assert key1 in td_out.keys()
>>> assert key2 not in td_out.keys()

register(trainer, name='select_keys') → None

Registers the hook in the trainer at a default location.

Parameters:
  • trainer (Trainer) – the trainer where the hook must be registered.

  • name (str) – the name of the hook.

Note

To register the hook at a location other than the default, use register_op().
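
The difference between the two registration paths can be sketched with toy classes. ToyTrainer and ToySelectKeys below are hypothetical stand-ins, not torchrl classes. The default location "batch_process" is assumed from the example above, and "custom_location" is a made-up name used only for illustration.

```python
# Toy stand-ins that mimic the registration pattern; none of these
# classes come from torchrl.
from typing import Any, Callable, Dict, List, Sequence


class ToyTrainer:
    """Hypothetical trainer that keeps a list of ops per location."""

    def __init__(self) -> None:
        self.ops: Dict[str, List[Callable]] = {}

    def register_op(self, dest: str, op: Callable) -> None:
        # Append the op to the list stored under the chosen location.
        self.ops.setdefault(dest, []).append(op)


class ToySelectKeys:
    """Hypothetical hook that keeps only the requested keys of a dict batch."""

    def __init__(self, keys: Sequence[str]) -> None:
        self.keys = list(keys)

    def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        return {key: batch[key] for key in self.keys}

    def register(self, trainer: ToyTrainer, name: str = "select_keys") -> None:
        # Assumed default location; `name` is kept for signature parity
        # and is ignored by this toy.
        trainer.register_op("batch_process", self)


trainer = ToyTrainer()
hook = ToySelectKeys(["first key"])
hook.register(trainer)                        # default location
trainer.register_op("custom_location", hook)  # explicit location (made-up name)
```

In this sketch register() is a convenience wrapper around register_op() with a fixed destination, which is why a different location requires calling register_op() directly.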
