tensordict.nn.add_custom_mapping

tensordict.nn.add_custom_mapping(name: str, mapping: Callable[[Tensor], Tensor])

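Registers a custom mapping under a name, so that classes accepting a mapping argument (for example, the scale_mapping argument of NormalParamExtractor) can select it by that name.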

Parameters:
  • name (str) – the name under which the mapping is registered.

  • mapping (callable) – a callable that takes a tensor as input and outputs a tensor with the same shape.

Examples

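>>> import torch
>>> from tensordict.nn import add_custom_mapping, NormalParamExtractor
>>> # Register a mapping that replaces every value with zero.
>>> add_custom_mapping("my_mapping", lambda x: torch.zeros_like(x))
>>> # Select the mapping by name; scale_lb=0.0 keeps the zero scale from being clamped upward.
>>> npe = NormalParamExtractor(scale_mapping="my_mapping", scale_lb=0.0)
>>> # The 10-element input is split into a location and a scale of 5 elements each.
>>> assert (npe(torch.randn(10))[1] == torch.zeros(5)).all()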
