Tensor Indexing
The PyTorch C++ API provides tensor indexing that closely mirrors Python's. The index types live in the `torch::indexing` namespace, so bring it into scope first:

```cpp
using namespace torch::indexing;
```

The main difference from Python is that the C++ API has no `[]` indexing operator. It uses two methods instead:

- `torch::Tensor::index`: reads elements (Python `tensor[...]`)
- `torch::Tensor::index_put_`: writes elements in place (Python `tensor[...] = value`)
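Index Types
The `TensorIndex` class accepts six types of indices through implicit constructors. Plain values can therefore be written directly inside the braces passed to `index` and `index_put_`:

| Type | C++ | Python equivalent |
|---|---|---|
| None (unsqueeze) | `None` | `None` |
| Ellipsis | `Ellipsis` or `"..."` | `...` or `Ellipsis` |
| Integer | `1` | `1` |
| Boolean | `true` / `false` | `True` / `False` |
| Slice | `Slice(1, None, 2)` | `1::2` |
| Tensor | `torch::tensor({1, 2})` | `torch.tensor([1, 2])` |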
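Getter Operations

| Python | C++ (assuming `using namespace torch::indexing`) |
|---|---|
| `tensor[None]` | `tensor.index({None})` |
| `tensor[...]` or `tensor[Ellipsis]` | `tensor.index({Ellipsis})` or `tensor.index({"..."})` |
| `tensor[1, 2]` | `tensor.index({1, 2})` |
| `tensor[True, False]` | `tensor.index({true, false})` |
| `tensor[1::2]` | `tensor.index({Slice(1, None, 2)})` |
| `tensor[torch.tensor([1, 2])]` | `tensor.index({torch::tensor({1, 2})})` |
| `tensor[..., 0, True, 1::2, torch.tensor([1, 2])]` | `tensor.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})})` |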
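Setter Operations

| Python | C++ (assuming `using namespace torch::indexing`) |
|---|---|
| `tensor[None] = 1` | `tensor.index_put_({None}, 1)` |
| `tensor[...] = 1` or `tensor[Ellipsis] = 1` | `tensor.index_put_({Ellipsis}, 1)` or `tensor.index_put_({"..."}, 1)` |
| `tensor[1, 2] = 1` | `tensor.index_put_({1, 2}, 1)` |
| `tensor[True, False] = 1` | `tensor.index_put_({true, false}, 1)` |
| `tensor[1::2] = 1` | `tensor.index_put_({Slice(1, None, 2)}, 1)` |
| `tensor[torch.tensor([1, 2])] = 1` | `tensor.index_put_({torch::tensor({1, 2})}, 1)` |
| `tensor[..., 0, True, 1::2, torch.tensor([1, 2])] = 1` | `tensor.index_put_({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})}, 1)` |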
The `index_put_` method also accepts an optional `accumulate` parameter. When it is `true`, the new values are added to the existing values instead of replacing them, and values that target the same position are summed:

```cpp
tensor.index_put_({mask}, values, /*accumulate=*/true);
```
Slice Syntax
The `Slice` constructor signature is:

```cpp
Slice(
    std::optional<c10::SymInt> start = std::nullopt,
    std::optional<c10::SymInt> stop = std::nullopt,
    std::optional<c10::SymInt> step = std::nullopt);
```

Pass `None` (or omit trailing arguments) for open-ended bounds, where Python would leave a slice field empty:
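| Python | C++ |
|---|---|
| `:` or `::` | `Slice()` or `Slice(None, None)` |
| `1:` | `Slice(1, None)` |
| `:3` | `Slice(None, 3)` |
| `::2` | `Slice(None, None, 2)` |
| `1:3` | `Slice(1, 3)` |
| `1:3:2` | `Slice(1, 3, 2)` |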
Full Example

```cpp
#include <torch/torch.h>

#include <iostream>

using namespace torch::indexing;

int main() {
  auto tensor = torch::arange(2 * 3 * 4).reshape({2, 3, 4});

  // Basic indexing
  auto row = tensor.index({0});         // tensor[0]
  auto elem = tensor.index({1, 2, 3});  // tensor[1, 2, 3] (0-dim tensor)

  // Slicing
  auto sliced = tensor.index({Slice(), Slice(0, 2)});  // tensor[:, 0:2]

  // None (unsqueeze) and Ellipsis
  auto unsqueezed = tensor.index({None});        // tensor[None]
  auto last_dim = tensor.index({Ellipsis, -1});  // tensor[..., -1]

  // Boolean mask indexing
  auto mask = tensor > 10;
  auto selected = tensor.index({mask});  // tensor[tensor > 10]

  // Integer tensor (fancy) indexing
  auto idx = torch::tensor({0, 2});
  auto gathered = tensor.index({Slice(), idx});  // tensor[:, [0, 2]]

  // Setting values (in place)
  tensor.index_put_({0, Slice(), 0}, 99);  // tensor[0, :, 0] = 99
  tensor.index_put_({mask}, 0);            // tensor[mask] = 0 (mask predates the write above)

  std::cout << tensor << std::endl;
}
```