torch.Tensor

A torch.Tensor is a multi-dimensional matrix containing elements of a single data type.

Torch defines seven CPU tensor types and eight GPU tensor types:

Data type                 CPU tensor          GPU tensor
------------------------  ------------------  -----------------------
32-bit floating point     torch.FloatTensor   torch.cuda.FloatTensor
64-bit floating point     torch.DoubleTensor  torch.cuda.DoubleTensor
16-bit floating point     N/A                 torch.cuda.HalfTensor
8-bit integer (unsigned)  torch.ByteTensor    torch.cuda.ByteTensor
8-bit integer (signed)    torch.CharTensor    torch.cuda.CharTensor
16-bit integer (signed)   torch.ShortTensor   torch.cuda.ShortTensor
32-bit integer (signed)   torch.IntTensor     torch.cuda.IntTensor
64-bit integer (signed)   torch.LongTensor    torch.cuda.LongTensor

The torch.Tensor constructor is an alias for the default tensor type (torch.FloatTensor).

A tensor can be constructed from a Python list or sequence:

>>> torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
1  2  3
4  5  6
[torch.FloatTensor of size 2x3]

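A tensor can also be constructed by specifying only its size. Its contents are then uninitialized, which is why the example calls zero_() to fill it with zeros: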

>>> torch.IntTensor(2, 4).zero_()
0  0  0  0
0  0  0  0
[torch.IntTensor of size 2x4]

The contents of a tensor can be accessed and modified using Python’s indexing and slicing notation:

>>> x = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
>>> print(x[1][2])
6.0
>>> x[0][1] = 8
>>> print(x)
 1  8  3
 4  5  6
[torch.FloatTensor of size 2x3]

Each tensor has an associated torch.Storage, which holds its data. The tensor class provides a multi-dimensional, strided view of a storage and defines numeric operations on it.
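To make the idea of a strided view concrete, here is a plain-Python sketch of the index arithmetic only. It assumes C-order (row-major) strides for a freshly created tensor; contiguous_strides and storage_index are hypothetical helpers for illustration, not part of the torch API.

```python
# Plain-Python sketch of strided indexing (hypothetical helpers, not torch API).

def contiguous_strides(size):
    """Return C-order (row-major) strides for a tensor of the given size."""
    strides = []
    step = 1
    for dim_size in reversed(size):
        strides.append(step)
        step *= dim_size
    return tuple(reversed(strides))


def storage_index(index, strides, storage_offset=0):
    """Return the position in the flat storage of the element at `index`."""
    return storage_offset + sum(i * s for i, s in zip(index, strides))


# The 2x3 tensor [[1, 2, 3], [4, 5, 6]] laid out contiguously in one storage.
storage = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
strides = contiguous_strides((2, 3))                 # (3, 1)
print(storage[storage_index((1, 2), strides)])       # x[1][2] -> 6.0

# A transposed view shares the same storage; only the strides are swapped.
t_strides = (strides[1], strides[0])                 # (1, 3)
print(storage[storage_index((2, 1), t_strides)])     # t[2][1] -> 6.0
```

Because only the strides change, an operation such as a transpose can return a view without copying data; this is also why a transposed tensor is generally no longer contiguous.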

Note

Methods which mutate a tensor are marked with an underscore suffix. For example, torch.FloatTensor.abs_() computes the absolute value in-place and returns the modified tensor, while torch.FloatTensor.abs() computes the result in a new tensor.

class torch.Tensor
class torch.Tensor(*sizes)
class torch.Tensor(size)
class torch.Tensor(sequence)
class torch.Tensor(ndarray)
class torch.Tensor(tensor)
class torch.Tensor(storage)

Creates a new tensor from an optional size or data.

If no arguments are given, an empty zero-dimensional tensor is returned. If a numpy.ndarray, torch.Tensor, or torch.Storage is given, a new tensor that shares the same data is returned. If a Python sequence is given, a new tensor is created from a copy of the sequence.

abs() → Tensor

See torch.abs()

abs_() → Tensor

In-place version of abs()

acos() → Tensor

See torch.acos()

acos_() → Tensor

In-place version of acos()

add(value)

See torch.add()

add_(value)

In-place version of add()

addbmm(beta=1, mat, alpha=1, batch1, batch2) → Tensor

See torch.addbmm()

addbmm_(beta=1, mat, alpha=1, batch1, batch2) → Tensor

In-place version of addbmm()

addcdiv(value=1, tensor1, tensor2) → Tensor

See torch.addcdiv()

addcdiv_(value=1, tensor1, tensor2) → Tensor

In-place version of addcdiv()

addcmul(value=1, tensor1, tensor2) → Tensor

See torch.addcmul()

addcmul_(value=1, tensor1, tensor2) → Tensor

In-place version of addcmul()

addmm(beta=1, mat, alpha=1, mat1, mat2) → Tensor

See torch.addmm()

addmm_(beta=1, mat, alpha=1, mat1, mat2) → Tensor

In-place version of addmm()

addmv(beta=1, tensor, alpha=1, mat, vec) → Tensor

See torch.addmv()

addmv_(beta=1, tensor, alpha=1, mat, vec) → Tensor

In-place version of addmv()

addr(beta=1, alpha=1, vec1, vec2) → Tensor

See torch.addr()

addr_(beta=1, alpha=1, vec1, vec2) → Tensor

In-place version of addr()

apply_(callable) → Tensor

Applies the function callable to each element in the tensor, replacing each element with the value returned by callable.

Note

This function only works with CPU tensors and should not be used in code sections that require high performance.

asin() → Tensor

See torch.asin()

asin_() → Tensor

In-place version of asin()

atan() → Tensor

See torch.atan()

atan2(other) → Tensor

See torch.atan2()

atan2_(other) → Tensor

In-place version of atan2()

atan_() → Tensor

In-place version of atan()

baddbmm(beta=1, alpha=1, batch1, batch2) → Tensor

See torch.baddbmm()

baddbmm_(beta=1, alpha=1, batch1, batch2) → Tensor

In-place version of baddbmm()

bernoulli() → Tensor

See torch.bernoulli()

bernoulli_() → Tensor

In-place version of bernoulli()

bmm(batch2) → Tensor

See torch.bmm()

byte()

Casts this tensor to byte type

cauchy_(median=0, sigma=1, *, generator=None) → Tensor

Fills the tensor with numbers drawn from the Cauchy distribution:

\[P(x) = \dfrac{1}{\pi} \dfrac{\sigma}{(x - \text{median})^2 + \sigma^2}\]
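As an illustration of the distribution itself (not necessarily of how cauchy_() is implemented), Cauchy samples can be drawn by inverse-transform sampling: the CDF is F(x) = 1/2 + arctan((x - median) / σ) / π, so x = median + σ tan(π(u - 1/2)) for u uniform on [0, 1). The function names below are hypothetical.

```python
import math
import random


def cauchy_quantile(u, median=0.0, sigma=1.0):
    """Inverse CDF of the Cauchy distribution with the given median and sigma."""
    return median + sigma * math.tan(math.pi * (u - 0.5))


def cauchy_sample(median=0.0, sigma=1.0, rng=random):
    """Draw one sample by feeding a uniform [0, 1) number to the inverse CDF."""
    return cauchy_quantile(rng.random(), median, sigma)


rng = random.Random(0)
samples = [cauchy_sample(median=0.0, sigma=1.0, rng=rng) for _ in range(5)]
print(cauchy_quantile(0.5, median=2.0, sigma=0.5))   # the median itself: 2.0
```

The quartiles sit at median - sigma and median + sigma, which is one way to read σ as the scale parameter of the distribution.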
ceil() → Tensor

See torch.ceil()

ceil_() → Tensor

In-place version of ceil()

char()

Casts this tensor to char type

chunk(n_chunks, dim=0)

Splits this tensor into a tuple of tensors.

See torch.chunk().

clamp(min, max) → Tensor

See torch.clamp()

clamp_(min, max) → Tensor

In-place version of clamp()

clone() → Tensor

Returns a copy of the tensor. The copy has the same size and data type as the original tensor.

contiguous() → Tensor

Returns a contiguous Tensor containing the same data as this tensor. If this tensor is contiguous, this function returns the original tensor.

copy_(src, async=False) → Tensor

Copies the elements from src into this tensor and returns this tensor.

The source tensor should have the same number of elements as this tensor. It may be of a different data type or reside on a different device.

Parameters:
  • src (Tensor) – Source tensor to copy
  • async (bool) – If True and this copy is between CPU and GPU, then the copy may occur asynchronously with respect to the host. For other copies, this argument has no effect.
cos() → Tensor

See torch.cos()

cos_() → Tensor

In-place version of cos()

cosh() → Tensor

See torch.cosh()

cosh_() → Tensor

In-place version of cosh()

cpu()

Returns a CPU copy of this tensor if it’s not already on the CPU

cross(other, dim=-1) → Tensor

See torch.cross()

cuda(device=None, async=False)

Returns a copy of this object in CUDA memory.

If this object is already in CUDA memory and on the correct device, then no copy is performed and the original object is returned.

Parameters:
  • device (int) – The destination GPU id. Defaults to the current device.
  • async (bool) – If True and the source is in pinned memory, the copy will be asynchronous with respect to the host. Otherwise, the argument has no effect.
cumprod(dim) → Tensor

See torch.cumprod()

cumsum(dim) → Tensor

See torch.cumsum()

data_ptr() → int

Returns the address of the first element of this tensor.

diag(diagonal=0) → Tensor

See torch.diag()

dim() → int

Returns the number of dimensions of this tensor.

dist(other, p=2) → Tensor

See torch.dist()

div(value)

See torch.div()

div_(value)

In-place version of div()

dot(tensor2) → float

See torch.dot()

double()

Casts this tensor to double type

eig(eigenvectors=False) -> (Tensor, Tensor)

See torch.eig()

element_size() → int

Returns the size in bytes of an individual element.

Example

>>> torch.FloatTensor().element_size()
4
>>> torch.ByteTensor().element_size()
1
eq(other) → Tensor

See torch.eq()

eq_(other) → Tensor

In-place version of eq()

equal(other) → bool

See torch.equal()

exp() → Tensor

See torch.exp()

exp_() → Tensor

In-place version of exp()

expand(*sizes) → Tensor

Returns a new view of the tensor with singleton dimensions expanded to a larger size.

The tensor can also be expanded to a larger number of dimensions; the new dimensions are added at the front.

Expanding a tensor does not allocate new memory, but only creates a new view on the existing tensor where a dimension of size one is expanded to a larger size by setting the stride to 0. Any dimension of size 1 can be expanded to an arbitrary value without allocating new memory.

Parameters: *sizes (torch.Size or int...) – The desired expanded size

Example

>>> x = torch.Tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.expand(3, 4)
 1  1  1  1
 2  2  2  2
 3  3  3  3
[torch.FloatTensor of size 3x4]
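The zero-stride mechanism described above can be sketched in plain Python. expand_strides and read are hypothetical helpers; the sketch assumes contiguous input strides and only models which storage element each position of the view reads.

```python
# Plain-Python sketch of expand() via zero strides (hypothetical helpers).

def expand_strides(size, strides, new_size):
    """Return the strides of an expanded view; no data is copied."""
    lead = len(new_size) - len(size)
    if lead < 0:
        raise ValueError("cannot expand to fewer dimensions")
    # New dimensions are added at the front, as described for expand().
    size = (1,) * lead + tuple(size)
    strides = (0,) * lead + tuple(strides)
    out = []
    for old, new, stride in zip(size, new_size, strides):
        if old == new:
            out.append(stride)
        elif old == 1:
            out.append(0)  # a singleton dimension re-reads the same element
        else:
            raise ValueError("only dimensions of size 1 can be expanded")
    return tuple(out)


def read(storage, index, strides, offset=0):
    """Read the element at `index` from a flat storage through `strides`."""
    return storage[offset + sum(i * s for i, s in zip(index, strides))]


storage = [1.0, 2.0, 3.0]                          # x = [[1], [2], [3]], size (3, 1)
strides = expand_strides((3, 1), (1, 1), (3, 4))   # -> (1, 0)
rows = [[read(storage, (i, j), strides) for j in range(4)] for i in range(3)]
print(rows)  # [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
```

The storage still holds only three values even though the view exposes twelve positions.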
expand_as(tensor)

Expands this tensor to the size of the specified tensor.

This is equivalent to:

self.expand(tensor.size())
exponential_(lambd=1, *, generator=None) → Tensor

Fills this tensor with elements drawn from the exponential distribution:

\[P(x) = \lambda e^{-\lambda x}\]
fill_(value) → Tensor

Fills this tensor with the specified value.

float()

Casts this tensor to float type

floor() → Tensor

See torch.floor()

floor_() → Tensor

In-place version of floor()

fmod(divisor) → Tensor

See torch.fmod()

fmod_(divisor) → Tensor

In-place version of fmod()

frac() → Tensor

See torch.frac()

frac_() → Tensor

In-place version of frac()

gather(dim, index) → Tensor

See torch.gather()

ge(other) → Tensor

See torch.ge()

ge_(other) → Tensor

In-place version of ge()

gels(A) → Tensor

See torch.gels()

geometric_(p, *, generator=None) → Tensor

Fills this tensor with elements drawn from the geometric distribution:

\[P(X=k) = (1 - p)^{k - 1} p\]
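This formula counts the number of Bernoulli(p) trials up to and including the first success, so k starts at 1. Below is a plain-Python sketch of the probability mass function and a naive trial-counting sampler, assuming 0 < p ≤ 1; the function names are hypothetical, and geometric_() is not necessarily implemented this way.

```python
import random


def geometric_pmf(k, p):
    """P(X = k) = (1 - p)**(k - 1) * p for k = 1, 2, 3, ..."""
    if k < 1:
        return 0.0
    return (1 - p) ** (k - 1) * p


def geometric_sample(p, rng=random):
    """Count Bernoulli(p) trials up to and including the first success.
    Assumes 0 < p <= 1; p == 0 would loop forever."""
    k = 1
    while rng.random() >= p:  # failure with probability 1 - p
        k += 1
    return k


print(geometric_pmf(3, 0.5))  # 0.5**2 * 0.5 = 0.125
```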
geqrf() -> (Tensor, Tensor)

See torch.geqrf()

ger(vec2) → Tensor

See torch.ger()

gesv(A) -> (Tensor, Tensor)

See torch.gesv()

gt(other) → Tensor

See torch.gt()

gt_(other) → Tensor

In-place version of gt()

half()

Casts this tensor to half-precision float type

histc(bins=100, min=0, max=0) → Tensor

See torch.histc()

index(m) → Tensor

Selects elements from this tensor using a binary mask or along a given dimension. The expression tensor.index(m) is equivalent to tensor[m].

Parameters: m (int or ByteTensor or slice) – The dimension or mask used to select elements
index_add_(dim, index, tensor) → Tensor

Accumulates the elements of tensor into this tensor by adding them at the indices given in index: along dimension dim, the i-th slice of tensor is added to the index[i]-th slice of this tensor. The shape of tensor must exactly match the elements indexed, or an error is raised.

Parameters:
  • dim (int) – Dimension along which to index
  • index (LongTensor) – Indices of this tensor to add to
  • tensor (Tensor) – Tensor containing values to add

Example

>>> x = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
>>> t = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> index = torch.LongTensor([0, 2, 1])
>>> x.index_add_(0, index, t)
>>> x
  2   3   4
  8   9  10
  5   6   7
[torch.FloatTensor of size 3x3]
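The example above can be reproduced with a plain-Python sketch of the dim = 0 case on nested lists. index_add_rows is a hypothetical helper, not part of torch; it also shows that repeated indices accumulate rather than overwrite.

```python
def index_add_rows(x, index, t):
    """Sketch of x.index_add_(0, index, t) on 2-D lists: row i of t is added
    to row index[i] of x. Repeated indices accumulate. Modifies x in place."""
    if len(index) != len(t):
        raise ValueError("t must have exactly one row per entry in index")
    for i, target in enumerate(index):
        x[target] = [a + b for a, b in zip(x[target], t[i])]
    return x


x = [[1, 1, 1] for _ in range(3)]
t = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(index_add_rows(x, [0, 2, 1], t))  # [[2, 3, 4], [8, 9, 10], [5, 6, 7]]
```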
index_copy_(dim, index, tensor) → Tensor

Copies the elements of tensor into this tensor at the indices given in index: along dimension dim, the i-th slice of tensor is copied to the index[i]-th slice of this tensor. The shape of tensor must exactly match the elements indexed, or an error is raised.

Parameters:
  • dim (int) – Dimension along which to index
  • index (LongTensor) – Indices of this tensor to copy into
  • tensor (Tensor) – Tensor containing values to copy

Example

>>> x = torch.Tensor(3, 3)
>>> t = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> index = torch.LongTensor([0, 2, 1])
>>> x.index_copy_(0, index, t)
>>> x
 1  2  3
 7  8  9
 4  5  6
[torch.FloatTensor of size 3x3]
index_fill_(dim, index, val) → Tensor

Fills the elements of this tensor with the value val at the indices given in index along dimension dim.

Parameters:
  • dim (int) – Dimension along which to index
  • index (LongTensor) – Indices of this tensor to fill
  • val (float) – Value to fill

Example

>>> x = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> index = torch.LongTensor([0, 2])
>>> x.index_fill_(1, index, -1)
>>> x
-1  2 -1
-1  5 -1
-1  8 -1
[torch.FloatTensor of size 3x3]
index_select(dim, index) → Tensor

See torch.index_select()

int()

Casts this tensor to int type

inverse() → Tensor

See torch.inverse()

is_contiguous() → bool

Returns True if this tensor is contiguous in memory in C order.

is_cuda
is_pinned()

Returns True if this tensor resides in pinned memory.

is_set_to(tensor) → bool

Returns True if this object refers to the same THTensor object from the Torch C API as the given tensor.

is_signed()
kthvalue(k, dim=None) -> (Tensor, LongTensor)

See torch.kthvalue()

le(other) → Tensor

See torch.le()

le_(other) → Tensor

In-place version of le()

lerp(start, end, weight)

See torch.lerp()

lerp_(start, end, weight)

In-place version of lerp()

log() → Tensor

See torch.log()

log1p() → Tensor

See torch.log1p()

log1p_() → Tensor

In-place version of log1p()

log_() → Tensor

In-place version of log()

log_normal_(mean=1, std=2, *, generator=None)

Fills this tensor with numbers sampled from the log-normal distribution parameterized by the given mean (µ) and standard deviation (σ). Note that mean and std are the mean and standard deviation of the underlying normal distribution, not of the returned distribution:

\[P(x) = \dfrac{1}{x \sigma \sqrt{2\pi}} e^{-\dfrac{(\ln x - \mu)^2}{2\sigma^2}}\]
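The distinction between the parameters and the moments of the returned distribution can be checked numerically. This plain-Python sketch assumes the standard construction of a log-normal variable as exp of a normal one; log_normal_sample and log_normal_mean are hypothetical names, not torch functions.

```python
import math
import random


def log_normal_sample(mean=1.0, std=2.0, rng=random):
    """exp of a normal sample: mean and std describe the underlying normal."""
    return math.exp(rng.gauss(mean, std))


def log_normal_mean(mean, std):
    """Mean of the log-normal distribution itself: exp(mean + std**2 / 2)."""
    return math.exp(mean + std ** 2 / 2)


rng = random.Random(0)
logs = [math.log(log_normal_sample(1.0, 2.0, rng)) for _ in range(10000)]
print(sum(logs) / len(logs))       # close to 1.0, the underlying normal's mean
print(log_normal_mean(1.0, 2.0))   # exp(3), roughly 20.09, not 1.0
```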
long()

Casts this tensor to long type

lt(other) → Tensor

See torch.lt()

lt_(other) → Tensor

In-place version of lt()

map_(tensor, callable)

Applies callable to each pair of corresponding elements of this tensor and the given tensor, and stores the results in this tensor. The callable should have the signature:

def callable(a, b) -> number
masked_copy_(mask, source)

Copies elements from source into this tensor at positions where the mask is one. The mask should have the same number of elements as this tensor. The source should have at least as many elements as the number of ones in mask.

Parameters:
  • mask (ByteTensor) – The binary mask
  • source (Tensor) – The tensor to copy from

Note

The mask operates on the self tensor, not on the given source tensor.
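The note above is easiest to see on flattened data: the mask is laid over this tensor, while source is simply read from its beginning. Below is a plain-Python sketch, assuming both tensors are flattened in row-major order; masked_copy is a hypothetical helper, not part of torch.

```python
def masked_copy(dest, mask, source):
    """Sketch of dest.masked_copy_(mask, source) on flat lists. Positions of
    dest where mask is 1 receive consecutive elements of source, in order."""
    if len(mask) != len(dest):
        raise ValueError("mask must have as many elements as dest")
    needed = sum(1 for m in mask if m)
    if len(source) < needed:
        raise ValueError("source has fewer elements than ones in mask")
    src = iter(source)
    for i, m in enumerate(mask):
        if m:
            dest[i] = next(src)
    return dest


print(masked_copy([0, 0, 0, 0], [1, 0, 1, 0], [7, 8, 9]))  # [7, 0, 8, 0]
```

Unused trailing elements of source (here 9) are ignored.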

masked_fill_(mask, value)

Fills elements of this tensor with value where mask is one. The mask should have the same number of elements as this tensor, but the shape may differ.

Parameters:
  • mask (ByteTensor) – The binary mask
  • value (float) – The value to fill
masked_select(mask) → Tensor

See torch.masked_select()

max(dim=None) -> float or (Tensor, Tensor)

See torch.max()

mean(dim=None) -> float or Tensor

See torch.mean()

median(dim=-1, values=None, indices=None) -> (Tensor, LongTensor)

See torch.median()

min(dim=None) -> float or (Tensor, Tensor)

See torch.min()

mm(mat2) → Tensor

See torch.mm()

mode(dim=-1, values=None, indices=None) -> (Tensor, LongTensor)

See torch.mode()

mul(value) → Tensor

See torch.mul()

mul_(value)

In-place version of mul()

multinomial(num_samples, replacement=False, *, generator=None)

See torch.multinomial()

mv(vec) → Tensor

See torch.mv()

narrow(dimension, start, length) → Tensor

Returns a new tensor that is a narrowed version of this tensor: along the given dimension, only the length elements starting at index start are kept (indices start to start + length - 1). The returned tensor and this tensor share the same underlying storage.

Parameters:
  • dimension (int) – The dimension along which to narrow
  • start (int) – The starting index along that dimension
  • length (int) – The number of elements to keep

Example

>>> x = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> x.narrow(0, 0, 2)
 1  2  3
 4  5  6
[torch.FloatTensor of size 2x3]
>>> x.narrow(1, 1, 2)
 2  3
 5  6
 8  9
[torch.FloatTensor of size 3x2]
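Since the narrowed tensor shares storage with the original, narrowing only changes metadata. A plain-Python sketch of that bookkeeping, assuming the tensor is described by (size, strides, storage offset); narrow_view is a hypothetical helper, not part of torch.

```python
def narrow_view(size, strides, offset, dim, start, length):
    """Return (size, strides, storage_offset) of x.narrow(dim, start, length).
    Only metadata changes; the storage itself is shared, not copied."""
    if start < 0 or length < 0 or start + length > size[dim]:
        raise ValueError("narrowed range is out of bounds")
    new_size = list(size)
    new_size[dim] = length
    return tuple(new_size), tuple(strides), offset + start * strides[dim]


# The 3x3 example above, stored contiguously: strides (3, 1), offset 0.
print(narrow_view((3, 3), (3, 1), 0, 1, 1, 2))  # ((3, 2), (3, 1), 1)
```

The strides are unchanged, so the narrowed columns view of a contiguous matrix is itself non-contiguous.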
ndimension() → int

Alias for dim()

ne(other) → Tensor

See torch.ne()

ne_(other) → Tensor

In-place version of ne()

neg() → Tensor

See torch.neg()

neg_() → Tensor

In-place version of neg()

nelement() → int

Alias for numel()

new(*args, **kwargs)

Constructs a new tensor of the same data type.

nonzero() → LongTensor

See torch.nonzero()

norm(p=2) → float

See torch.norm()

normal_(mean=0, std=1, *, generator=None)

Fills this tensor with elements sampled from the normal distribution parameterized by mean and std.

numel() → int

See torch.numel()

numpy() → ndarray

Returns this tensor as a NumPy ndarray. This tensor and the returned ndarray share the same underlying storage. Changes to this tensor will be reflected in the ndarray and vice versa.

orgqr(input2) → Tensor

See torch.orgqr()

ormqr(input2, input3, left=True, transpose=False) → Tensor

See torch.ormqr()

permute(*dims)

Permute the dimensions of this tensor.

Parameters: *dims (int...) – The desired ordering of dimensions

Example

>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(2, 0, 1).size()
torch.Size([5, 2, 3])
pin_memory()

Copies the tensor to pinned memory, if it’s not already pinned.

potrf(upper=True) → Tensor

See torch.potrf()

potri(upper=True) → Tensor

See torch.potri()

potrs(input2, upper=True) → Tensor

See torch.potrs()

pow(exponent)

See torch.pow()

pow_(exponent)

In-place version of pow()

prod() → float

See torch.prod()

pstrf(upper=True, tol=-1) -> (Tensor, IntTensor)

See torch.pstrf()

qr() -> (Tensor, Tensor)

See torch.qr()

random_(from=0, to=None, *, generator=None)

Fills this tensor with numbers sampled from the discrete uniform distribution over [from, to - 1]. If to is not specified, the values are bounded only by this tensor’s data type.

reciprocal() → Tensor

See torch.reciprocal()

reciprocal_() → Tensor

In-place version of reciprocal()

remainder(divisor) → Tensor

See torch.remainder()

remainder_(divisor) → Tensor

In-place version of remainder()

renorm(p, dim, maxnorm) → Tensor

See torch.renorm()

renorm_(p, dim, maxnorm) → Tensor

In-place version of renorm()

repeat(*sizes)

Repeats this tensor along the specified dimensions.

Unlike expand(), this function copies the tensor’s data.

Parameters: *sizes (torch.Size or int...) – The number of times to repeat this tensor along each dimension

Example

>>> x = torch.Tensor([1, 2, 3])
>>> x.repeat(4, 2)
 1  2  3  1  2  3
 1  2  3  1  2  3
 1  2  3  1  2  3
 1  2  3  1  2  3
[torch.FloatTensor of size 4x6]
>>> x.repeat(4, 2, 1).size()
torch.Size([4, 2, 3])
resize_(*sizes)

Resizes this tensor to the specified size. If the number of elements is larger than the current storage size, then the underlying storage is resized to fit the new number of elements. If the number of elements is smaller, the underlying storage is not changed. Existing elements are preserved but any new memory is uninitialized.

Parameters: *sizes (torch.Size or int...) – The desired size

Example

>>> x = torch.Tensor([[1, 2], [3, 4], [5, 6]])
>>> x.resize_(2, 2)
>>> x
 1  2
 3  4
[torch.FloatTensor of size 2x2]
resize_as_(tensor)

Resizes the current tensor to be the same size as the specified tensor. This is equivalent to:

self.resize_(tensor.size())
round() → Tensor

See torch.round()

round_() → Tensor

In-place version of round()

rsqrt() → Tensor

See torch.rsqrt()

rsqrt_() → Tensor

In-place version of rsqrt()

scatter_(dim, index, src) → Tensor

Writes all values from the tensor src into this tensor at the indices specified in the index tensor. The indices are specified with respect to the given dimension, dim, in the manner described in gather().

Note that, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive, and all values in a row along the specified dimension must be unique.

Parameters:
  • dim (int) – The axis along which to index
  • index (LongTensor) – The indices of elements to scatter
  • src (Tensor or float) – The source element(s) to scatter

Example:

>>> x = torch.rand(2, 5)
>>> x

 0.4319  0.6500  0.4080  0.8760  0.2355
 0.2609  0.4711  0.8486  0.8573  0.1029
[torch.FloatTensor of size 2x5]

>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

 0.4319  0.4711  0.8486  0.8760  0.2355
 0.0000  0.6500  0.0000  0.8573  0.0000
 0.2609  0.0000  0.4080  0.0000  0.1029
[torch.FloatTensor of size 3x5]

>>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23)
>>> z

 0.0000  0.0000  1.2300  0.0000
 0.0000  0.0000  0.0000  1.2300
[torch.FloatTensor of size 2x4]
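For two-dimensional tensors the rule reduces to self[index[i][j]][j] = src[i][j] when dim is 0 and self[i][index[i][j]] = src[i][j] when dim is 1. The plain-Python sketch below reproduces both examples above; scatter_2d is a hypothetical helper that handles only the 2-D case.

```python
def scatter_2d(self, dim, index, src):
    """Plain-Python sketch of self.scatter_(dim, index, src) for 2-D lists.

    dim == 0: self[index[i][j]][j] = src[i][j]
    dim == 1: self[i][index[i][j]] = src[i][j]
    `src` may also be a single number, as in scatter_(1, index, 1.23).
    """
    if dim not in (0, 1):
        raise ValueError("this sketch only handles 2-D inputs")
    for i, row in enumerate(index):
        for j, k in enumerate(row):
            value = src if isinstance(src, (int, float)) else src[i][j]
            if dim == 0:
                self[k][j] = value
            else:
                self[i][k] = value
    return self


x = [[0.4319, 0.6500, 0.4080, 0.8760, 0.2355],
     [0.2609, 0.4711, 0.8486, 0.8573, 0.1029]]
out = scatter_2d([[0.0] * 5 for _ in range(3)], 0,
                 [[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], x)
z = scatter_2d([[0.0] * 4 for _ in range(2)], 1, [[2], [3]], 1.23)
print(z)  # [[0.0, 0.0, 1.23, 0.0], [0.0, 0.0, 0.0, 1.23]]
```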
select(dim, index) → Tensor or number

Slices the tensor along the selected dimension at the given index. If this tensor is one dimensional, this function returns a number. Otherwise, it returns a tensor with the given dimension removed.

Parameters:
  • dim (int) – Dimension to slice
  • index (int) – Index to select

Note

select() is equivalent to slicing. For example, tensor.select(0, index) is equivalent to tensor[index] and tensor.select(2, index) is equivalent to tensor[:,:,index].

set_(source=None, storage_offset=0, size=None, stride=None)

Sets the underlying storage, size, and strides. If source is a tensor, this tensor will share the same storage and have the same size and strides as the given tensor. Changes to elements in one tensor will be reflected in the other.

If source is a Storage, the method sets the underlying storage, offset, size, and stride.

Parameters:
  • source (Tensor or Storage) – The tensor or storage to use
  • storage_offset (int) – The offset in the storage
  • size (torch.Size) – The desired size. Defaults to the size of the source.
  • stride (tuple) – The desired stride. Defaults to C-contiguous strides.
share_memory_()

Moves the underlying storage to shared memory.

This is a no-op for CUDA tensors and when the underlying storage is already in shared memory. Tensors in shared memory cannot be resized.

short()

Casts this tensor to short type

sigmoid() → Tensor

See torch.sigmoid()

sigmoid_() → Tensor

In-place version of sigmoid()

sign() → Tensor

See torch.sign()

sign_() → Tensor

In-place version of sign()

sin() → Tensor

See torch.sin()

sin_() → Tensor

In-place version of sin()

sinh() → Tensor

See torch.sinh()

sinh_() → Tensor

In-place version of sinh()

size() → torch.Size

Returns the size of the tensor. The returned value is a subclass of tuple.

Example

>>> torch.Tensor(3, 4, 5).size()
torch.Size([3, 4, 5])
sort(dim=None, descending=False) -> (Tensor, LongTensor)

See torch.sort()

split(split_size, dim=0)

Splits this tensor into a tuple of tensors.

See torch.split().

sqrt() → Tensor

See torch.sqrt()

sqrt_() → Tensor

In-place version of sqrt()

squeeze(dim=None)

See torch.squeeze()

squeeze_(dim=None)

In-place version of squeeze()

std() → float

See torch.std()

storage() → torch.Storage

Returns the underlying storage

storage_offset() → int

Returns this tensor’s offset in the underlying storage in terms of number of storage elements (not bytes).

Example

>>> x = torch.Tensor([1, 2, 3, 4, 5])
>>> x.storage_offset()
0
>>> x[3:].storage_offset()
3
classmethod storage_type()
stride() → tuple

Returns the stride of the tensor.

sub(value, other) → Tensor

Subtracts a scalar or tensor from this tensor. If both value and other are specified, each element of other is scaled by value before being used.

sub_(x) → Tensor

In-place version of sub()

sum(dim=None) → float

See torch.sum()

svd(some=True) -> (Tensor, Tensor, Tensor)

See torch.svd()

symeig(eigenvectors=False, upper=True) -> (Tensor, Tensor)

See torch.symeig()

t() → Tensor

See torch.t()

t_() → Tensor

In-place version of t()

tan() → Tensor

See torch.tan()

tan_() → Tensor

In-place version of tan()

tanh() → Tensor

See torch.tanh()

tanh_() → Tensor

In-place version of tanh()

tolist()

Returns a nested list representation of this tensor.

topk(k, dim=None, largest=True, sorted=True) -> (Tensor, LongTensor)

See torch.topk()

trace() → float

See torch.trace()

transpose(dim0, dim1) → Tensor

See torch.transpose()

transpose_(dim0, dim1) → Tensor

In-place version of transpose()

tril(k=0) → Tensor

See torch.tril()

tril_(k=0) → Tensor

In-place version of tril()

triu(k=0) → Tensor

See torch.triu()

triu_(k=0) → Tensor

In-place version of triu()

trtrs(A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)

See torch.trtrs()

trunc() → Tensor

See torch.trunc()

trunc_() → Tensor

In-place version of trunc()

type(new_type=None, async=False)

Casts this object to the specified type.

If this is already of the correct type, no copy is performed and the original object is returned.

Parameters:
  • new_type (type or string) – The desired type
  • async (bool) – If True, and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
type_as(tensor)

Returns this tensor cast to the type of the given tensor.

This is a no-op if the tensor is already of the correct type. This is equivalent to:

self.type(tensor.type())
Parameters: tensor (Tensor) – The tensor which has the desired type
unfold(dim, size, step) → Tensor

Returns a tensor which contains all slices of size size in the dimension dim.

The step between two slices is given by step.

If sizedim is the original size of dimension dim, the size of dimension dim in the returned tensor will be floor((sizedim - size) / step) + 1.

An additional dimension of size size is appended in the returned tensor.

Parameters:
  • dim (int) – dimension in which unfolding happens
  • size (int) – size of each slice that is unfolded
  • step (int) – the step between each slice

Example:

>>> x = torch.arange(1, 8)
>>> x

 1
 2
 3
 4
 5
 6
 7
[torch.FloatTensor of size 7]

>>> x.unfold(0, 2, 1)

 1  2
 2  3
 3  4
 4  5
 5  6
 6  7
[torch.FloatTensor of size 6x2]

>>> x.unfold(0, 2, 2)

 1  2
 3  4
 5  6
[torch.FloatTensor of size 3x2]
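The size formula and both examples can be checked with a plain-Python sketch over a 1-D list. unfold_1d is a hypothetical helper; unlike unfold(), it builds new lists rather than a view that shares storage.

```python
def unfold_1d(values, size, step):
    """Sketch of x.unfold(0, size, step) for a 1-D list: windows of `size`
    consecutive elements, one starting every `step` elements."""
    count = (len(values) - size) // step + 1  # floor((sizedim - size) / step) + 1
    return [values[i * step:i * step + size] for i in range(count)]


x = list(range(1, 8))        # [1, 2, ..., 7], like torch.arange(1, 8)
print(unfold_1d(x, 2, 2))    # [[1, 2], [3, 4], [5, 6]]
```

With step 2 the last element, 7, does not start a full window, which is why the result has 3 rows rather than 4.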
uniform_(from=0, to=1) → Tensor

Fills this tensor with numbers sampled from the continuous uniform distribution:

\[P(x) = \dfrac{1}{\text{to} - \text{from}}\]

unsqueeze(dim)

See torch.unsqueeze()

unsqueeze_(dim)

In-place version of unsqueeze()

var() → float

See torch.var()

view(*args) → Tensor

Returns a new tensor with the same data but different size.

The returned tensor shares the same data and must have the same number of elements, but may have a different size. This tensor must be contiguous (see contiguous()) to be viewed.

Parameters: *args (torch.Size or int...) – The desired size

Example

>>> x = torch.randn(4, 4)
>>> x.size()
torch.Size([4, 4])
>>> y = x.view(16)
>>> y.size()
torch.Size([16])
>>> z = x.view(-1, 8)  # the size -1 is inferred from other dimensions
>>> z.size()
torch.Size([2, 8])
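The size inference for -1 is simple arithmetic on the element count. Below is a plain-Python sketch that models only that arithmetic, not the contiguity requirement; infer_view_size is a hypothetical helper, not part of torch.

```python
from math import prod


def infer_view_size(numel, *sizes):
    """Resolve the sizes passed to view(), replacing a single -1 by the value
    that keeps the total number of elements equal to numel."""
    if sizes.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    known = prod(s for s in sizes if s != -1)
    if -1 in sizes:
        if known == 0 or numel % known != 0:
            raise ValueError("shape is invalid for this number of elements")
        return tuple(numel // known if s == -1 else s for s in sizes)
    if known != numel:
        raise ValueError("view() must keep the same number of elements")
    return sizes


print(infer_view_size(16, -1, 8))  # (2, 8), matching x.view(-1, 8) above
```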
view_as(tensor)

Returns this tensor viewed with the same size as the specified tensor.

This is equivalent to:

self.view(tensor.size())
zero_()

Fills this tensor with zeros.