isin
- class tensordict.utils.isin(input: TensorDictBase, reference: TensorDictBase, key: NestedKey, dim: int = 0)
Tests whether each element of the key entry of input, taken along dim, is also present in the reference.
This function returns a boolean tensor of length input.batch_size[dim] that is True for elements in the entry key that are also present in the reference. This function assumes that both input and reference have the same batch size and contain the specified entry; otherwise an error will be raised.
- Parameters:
input (TensorDictBase) – Input TensorDict.
reference (TensorDictBase) – Target TensorDict against which to test.
key (NestedKey) – The key to test.
dim (int, optional) – The dimension along which to test. Defaults to 0.
- Returns:
A boolean tensor of length input.batch_size[dim] that is True for elements in the input key tensor that are also present in the reference.
- Return type:
out (Tensor)
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.utils import isin
>>> td = TensorDict(
...     {
...         "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]),
...         "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]),
...     },
...     batch_size=[4],
... )
>>> td_ref = TensorDict(
...     {
...         "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [10, 11, 12]]),
...         "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]),
...     },
...     batch_size=[3],
... )
>>> # Rows 0, 1 and 2 of td["tensor1"] also appear in td_ref["tensor1"]; row 3 does not.
>>> in_reference = isin(td, td_ref, key="tensor1")
>>> expected_in_reference = torch.tensor([True, True, True, False])
>>> torch.testing.assert_close(in_reference, expected_in_reference)
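
The example above treats each slice along dim as one element: a row of input is flagged True when an identical row exists in the reference. The plain-Python sketch below illustrates that reading on the same "tensor1" data. It is only an illustration of the semantics, not how tensordict implements isin; rows_isin is a hypothetical helper name, and the sketch assumes dim=0 and nested lists in place of tensors.

```python
def rows_isin(input_rows, reference_rows):
    """Return one bool per input row: True if an identical row is in reference_rows.

    Hypothetical helper for illustration only; not part of tensordict.
    """
    # Rows are converted to tuples so they are hashable and can be stored in a set.
    reference_set = {tuple(row) for row in reference_rows}
    return [tuple(row) in reference_set for row in input_rows]


# Same values as the "tensor1" entries in the example above.
tensor1 = [[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]
tensor1_ref = [[1, 2, 3], [4, 5, 6], [10, 11, 12]]

mask = rows_isin(tensor1, tensor1_ref)
print(mask)  # [True, True, True, False]
```

The result has one entry per row of the input (here 4), matching the documented output length input.batch_size[dim], and duplicate input rows are each tested independently.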