Program Listing for File lambda.h
↰ Return to documentation for file (torch/csrc/api/include/torch/data/transforms/lambda.h)
#pragma once
#include <torch/data/transforms/base.h>
#include <functional>
#include <utility>
#include <vector>
namespace torch::data::transforms {
// A `BatchTransform` that applies a user-provided functor to an entire batch.
//
// The functor is type-erased into a `std::function` that takes an
// `InputBatchType` and returns an `OutputBatchType`, both of which are
// inherited from `BatchTransform<Input, Output>`. `Output` defaults to
// `Input` when it is not specified.
template <typename Input, typename Output = Input>
class BatchLambda : public BatchTransform<Input, Output> {
 public:
  // Pull in the base class's dependent batch types so they can be named
  // without qualification inside this template.
  using typename BatchTransform<Input, Output>::InputBatchType;
  using typename BatchTransform<Input, Output>::OutputBatchType;

  // Signature that every wrapped callable must be convertible to.
  using FunctionType = std::function<OutputBatchType(InputBatchType)>;

  // Takes ownership of `function` by moving it into the transform.
  // `explicit` keeps a bare callable from implicitly converting into a
  // `BatchLambda`.
  explicit BatchLambda(FunctionType function)
      : batch_fn_{std::move(function)} {}

  // Hands `input_batch` to the stored callable (moved, not copied) and
  // returns whatever the callable produces.
  auto apply_batch(InputBatchType input_batch) -> OutputBatchType override {
    return batch_fn_(std::move(input_batch));
  }

 private:
  // The user-supplied callable invoked by `apply_batch`.
  FunctionType batch_fn_;
};
// A `Transform` that runs a user-provided functor on one example at a time.
//
// The functor is type-erased into a `std::function<Output(Input)>`.
// `Output` defaults to `Input` when it is not specified.
template <typename Input, typename Output = Input>
class Lambda : public Transform<Input, Output> {
 public:
  // Pull in the base class's dependent example types so they can be named
  // without qualification inside this template.
  using typename Transform<Input, Output>::InputType;
  using typename Transform<Input, Output>::OutputType;

  // Signature that every wrapped callable must be convertible to.
  using FunctionType = std::function<Output(Input)>;

  // Takes ownership of `function` by moving it into the transform.
  // `explicit` keeps a bare callable from implicitly converting into a
  // `Lambda`.
  explicit Lambda(FunctionType function) : example_fn_{std::move(function)} {}

  // Hands `input` to the stored callable (moved, not copied) and returns
  // whatever the callable produces.
  auto apply(InputType input) -> OutputType override {
    return example_fn_(std::move(input));
  }

 private:
  // The user-supplied callable invoked by `apply`.
  FunctionType example_fn_;
};
} // namespace torch::data::transforms