torch.func.linearize
- torch.func.linearize(func, *primals)
- Returns the value of `func` at `primals` and the linear approximation at `primals`.
- Parameters:
- func (Callable) – A Python function that takes one or more arguments. 
- primals (Tensors) – Positional arguments to `func` that must all be Tensors. These are the values at which the function is linearly approximated.
 
- Returns:
- Returns a `(output, jvp_fn)` tuple containing the output of `func` applied to `primals` and a function that computes the jvp of `func` evaluated at `primals`.
- Return type: Tuple[Any, Callable]
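Because `func` may take several positional arguments, the returned `jvp_fn` expects one tangent per primal, in the same order and with the same shapes. The sketch below assumes a hypothetical two-argument function `f(x, y) = x * y + sin(x)`, chosen only because its directional derivative, `(y + cos(x)) * tx + x * ty`, is easy to check by hand.

```python
import torch
from torch.func import linearize

# Hypothetical two-argument function used only for illustration.
def f(x, y):
    return x * y + torch.sin(x)

x = torch.tensor([0.0, 1.0, 2.0])
y = torch.tensor([3.0, 4.0, 5.0])

# Pass one primal per positional argument of f.
output, jvp_fn = linearize(f, x, y)

# jvp_fn takes one tangent per primal, matching the primals' order and shapes.
tx = torch.ones(3)
ty = torch.zeros(3)
jvp_out = jvp_fn(tx, ty)

# Closed-form directional derivative of f at (x, y) along (tx, ty).
expected = (y + torch.cos(x)) * tx + x * ty
print(torch.allclose(jvp_out, expected))
```

Here `output` is simply `f(x, y)`; the tangent `(tx, ty) = (1, 0)` picks out the partial derivative with respect to `x`.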
- linearize is useful if the jvp is to be computed multiple times at `primals`. However, to achieve this, linearize saves intermediate computation and has higher memory requirements than directly applying jvp. So, if all the `tangents` are known in advance, it may be more efficient to compute vmap(jvp) instead of using linearize.
- Note: linearize evaluates `func` twice. Please file an issue for an implementation with a single evaluation.
- Example::

    >>> import torch
    >>> from torch.func import linearize
    >>> def fn(x):
    ...     return x.sin()
    ...
    >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
    >>> jvp_fn(torch.ones(3, 3))
    tensor([[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]])
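The trade-off between reusing `jvp_fn` and batching `jvp` with `vmap` can be sketched as follows. The toy function `fn(x) = sin(x)`, the primal `x`, and the five random tangents are illustrative assumptions; both routes should produce the same directional derivatives, `cos(x) * t` for each tangent `t`.

```python
import torch
from torch.func import jvp, linearize, vmap

def fn(x):
    return x.sin()

x = torch.linspace(0.0, 1.0, 4)  # the single primal
torch.manual_seed(0)
tangents = torch.randn(5, 4)     # five tangent directions, one per row

# Route 1: linearize once at x, then call jvp_fn for each tangent.
_, jvp_fn = linearize(fn, x)
via_linearize = torch.stack([jvp_fn(t) for t in tangents])

# Route 2: all tangents are known up front, so batch jvp over them with vmap.
def jvp_at_x(t):
    # jvp returns (fn(x), directional derivative); keep only the derivative.
    return jvp(fn, (x,), (t,))[1]

via_vmap = vmap(jvp_at_x)(tangents)

# Both should agree with the analytic derivative of sin: cos(x) * t.
expected = torch.cos(x) * tangents
print(torch.allclose(via_linearize, expected), torch.allclose(via_vmap, expected))
```

Route 1 fits when tangents arrive one at a time (for example, inside an iterative solver) and the extra memory is acceptable; Route 2 avoids keeping the saved intermediates around when every tangent is available at once.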
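Since `func` is evaluated more than once, any side effects inside it also happen more than once. The sketch below uses a hypothetical call counter (not part of the torch.func API) to observe how many times `func` runs during `linearize` in your PyTorch version, and checks that calling `jvp_fn` afterwards does not run `func` again. Keeping `func` free of side effects avoids surprises either way.

```python
import torch
from torch.func import linearize

# Hypothetical counter used only to observe how often fn is evaluated.
calls = {"n": 0}

def fn(x):
    calls["n"] += 1  # side effect: runs on every evaluation of fn
    return x.sin()

x = torch.zeros(3)
output, jvp_fn = linearize(fn, x)
calls_after_linearize = calls["n"]
print("fn calls during linearize:", calls_after_linearize)

# Reuse jvp_fn at the same primal; the counter is expected to stay unchanged.
jvp_fn(torch.ones(3))
jvp_fn(torch.full((3,), 2.0))
print("fn calls after two jvp_fn calls:", calls["n"])
```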