Transforms
Transforms apply preprocessing to data samples, such as normalization or
augmentation. They can be chained using the .map() method on datasets.
Transform (Base Class)
The base class for transforms that operate on individual examples. Subclass this to create custom per-example transforms.
-
template<typename Input, typename Output>
class Transform : public torch::data::transforms::BatchTransform<std::vector<Input>, std::vector<Output>>
A transformation of individual input examples to individual output examples.
Just like a `Dataset` is a `BatchDataset`, a `Transform` is a `BatchTransform` that can operate on the level of individual examples rather than entire batches. The batch-level transform is implemented (by default) in terms of the example-level transform, though this can be customized.
BatchTransform (Base Class)
Base class for transforms that operate on entire batches.
-
template<typename InputBatch, typename OutputBatch>
class BatchTransform
A transformation of a batch to a new batch.
Subclassed by torch::data::transforms::Transform< Example< Tensor, Tensor >, Example< Tensor, Tensor > >, torch::data::transforms::Transform< Input, Input >, torch::data::transforms::Stack< Example<> >, torch::data::transforms::Stack< TensorExample >
Public Functions
-
virtual ~BatchTransform() = default
-
virtual OutputBatch apply_batch(InputBatch input_batch) = 0
Applies the transformation to the given `input_batch`.
TensorTransform
Base class for per-example transforms that operate on `Example<Tensor, Tensor>` samples.
-
template<typename Target = Tensor>
class TensorTransform : public torch::data::transforms::Transform<Example<Tensor, Tensor>, Example<Tensor, Tensor>>
A `Transform` that is specialized for the typical `Example<Tensor, Tensor>` combination. It exposes a single `operator()` interface hook (for subclasses), and calls this function on input `Example` objects.
Public Functions
-
inline virtual OutputType apply(InputType input) override
Implementation of `Transform::apply` that calls `operator()`.
Normalize
Normalizes tensors with a given mean and standard deviation.
-
template<typename Target = Tensor>
struct Normalize : public torch::data::transforms::TensorTransform<Tensor>
Normalizes input tensors by subtracting the supplied mean and dividing by the given standard deviation, i.e. `output = (input - mean) / stddev`.
Stack
Collates a batch of examples into a single example by stacking their tensors into batched tensors.
Example:
auto dataset = torch::data::datasets::MNIST("./data")
.map(torch::data::transforms::Normalize<>(0.5, 0.5))
.map(torch::data::transforms::Stack<>());
Lambda
-
template<typename Input, typename Output = Input>
class Lambda : public torch::data::transforms::Transform<Input, Input>
A `Transform` that applies a user-provided function object to individual examples.
Public Functions
-
inline explicit Lambda(FunctionType function)
Constructs the `Lambda` from the given `function` object.
-
inline virtual OutputType apply(InputType input) override
Applies the user-provided function object to the given `input`.
TensorLambda
-
template<typename Target = Tensor>
class TensorLambda : public torch::data::transforms::TensorTransform<Tensor>
A `Lambda` specialized for the typical `Example<Tensor, Tensor>` input type.
Public Functions
-
inline explicit TensorLambda(FunctionType function)
Creates a `TensorLambda` from the given `function`.
BatchLambda
-
template<typename Input, typename Output = Input>
class BatchLambda : public torch::data::transforms::BatchTransform<Input, Input>
A `BatchTransform` that applies a user-provided functor to a batch.
Public Types
-
using FunctionType = std::function<OutputBatchType(InputBatchType)>
The type of the user-provided function object: it takes an input batch and returns an output batch.
Public Functions
-
inline explicit BatchLambda(FunctionType function)
Constructs the `BatchLambda` from the given `function` object.
-
inline virtual OutputBatchType apply_batch(InputBatchType input_batch) override
Applies the user-provided function object to the given `input_batch`.
Chaining Transforms
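Transforms are chained by calling `.map()` repeatedly on a dataset; each transform is applied to the output of the previous one. In the example below, `Normalize` is applied to every example before `Stack` combines each batch: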
auto dataset = torch::data::datasets::MNIST("./data")
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
.map(torch::data::transforms::Stack<>());