Source code for torchvision.transforms.v2._utils
from __future__ import annotations
import collections.abc
import numbers
from collections.abc import Sequence
from contextlib import suppress
from typing import Any, Callable, Literal
import PIL.Image
import torch
from torchvision import tv_tensors
from torchvision._utils import sequence_to_str
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]:
if not isinstance(arg, (int, float, Sequence)):
raise TypeError(f"{name} should be a number or a sequence of numbers. Got {type(arg)}")
if isinstance(arg, Sequence) and len(arg) not in (1, 2):
raise ValueError(f"If {name} is a sequence its length should be 1 or 2. Got {len(arg)}")
if isinstance(arg, Sequence):
for element in arg:
if not isinstance(element, (int, float)):
raise ValueError(f"{name} should be a sequence of numbers. Got {type(element)}")
if isinstance(arg, (int, float)):
arg = [float(arg), float(arg)]
elif isinstance(arg, Sequence):
if len(arg) == 1:
arg = [float(arg[0]), float(arg[0])]
else:
arg = [float(arg[0]), float(arg[1])]
return arg
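
# Illustrative usage (sketch, not part of the original torchvision source):
# a scalar or a 1-element sequence is broadcast to two floats.
#
#     >>> _setup_number_or_seq(3, "sigma")
#     [3.0, 3.0]
#     >>> _setup_number_or_seq((1, 5), "sigma")
#     [1.0, 5.0]
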
def _check_fill_arg(fill: _FillType | dict[type | str, _FillType]) -> None:
if isinstance(fill, dict):
for value in fill.values():
_check_fill_arg(value)
else:
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
def _convert_fill_arg(fill: _FillType) -> _FillTypeJIT:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0
# if fill is None:
# fill = 0
if fill is None:
return fill
if not isinstance(fill, (int, float)):
fill = [float(v) for v in list(fill)]
return fill # type: ignore[return-value]
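
# Illustrative behavior (sketch, not part of the original torchvision source):
# None is passed through, scalars are kept as-is, and sequences become float lists.
#
#     >>> _convert_fill_arg(None) is None
#     True
#     >>> _convert_fill_arg(1)
#     1
#     >>> _convert_fill_arg((12, 12, 12))
#     [12.0, 12.0, 12.0]
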
def _setup_fill_arg(fill: _FillType | dict[type | str, _FillType]) -> dict[type | str, _FillTypeJIT]:
_check_fill_arg(fill)
if isinstance(fill, dict):
for k, v in fill.items():
fill[k] = _convert_fill_arg(v)
return fill # type: ignore[return-value]
else:
return {"others": _convert_fill_arg(fill)}
def _get_fill(fill_dict, inpt_type):
if inpt_type in fill_dict:
return fill_dict[inpt_type]
elif "others" in fill_dict:
return fill_dict["others"]
else:
        raise RuntimeError("This should never happen, please open an issue on the torchvision repo if you hit this.")
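
# Illustrative behavior (sketch, not part of the original torchvision source):
# an exact type match wins, otherwise the "others" entry is used.
#
#     >>> _get_fill({tv_tensors.Mask: 255, "others": 0}, tv_tensors.Mask)
#     255
#     >>> _get_fill({tv_tensors.Mask: 255, "others": 0}, tv_tensors.Image)
#     0
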
def _check_padding_arg(padding: int | Sequence[int]) -> None:
    err_msg = f"Padding must be an int or a 1, 2, or 4 element tuple or list, got {padding}."
if isinstance(padding, (tuple, list)):
if len(padding) not in [1, 2, 4] or not all(isinstance(p, int) for p in padding):
raise ValueError(err_msg)
elif not isinstance(padding, int):
raise ValueError(err_msg)
# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
# https://github.com/pytorch/vision/issues/6250
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
"""
This heuristic covers three cases:
1. The input is tuple or list whose second item is a labels tensor. This happens for already batched
classification inputs for MixUp and CutMix (typically after the Dataloder).
2. The input is a tuple or list whose second item is a dictionary that contains the labels tensor
under a label-like (see below) key. This happens for the inputs of detection models.
3. The input is a dictionary that is structured as the one from 2.
What is "label-like" key? We first search for an case-insensitive match of 'labels' inside the keys of the
dictionary. This is the name our detection models expect. If we can't find that, we look for a case-insensitive
match of the term 'label' anywhere inside the key, i.e. 'FooLaBeLBar'. If we can't find that either, the dictionary
contains no "label-like" key.
"""
if isinstance(inputs, (tuple, list)):
inputs = inputs[1]
# MixUp, CutMix
if is_pure_tensor(inputs):
return inputs
if not isinstance(inputs, collections.abc.Mapping):
raise ValueError(
f"When using the default labels_getter, the input passed to forward must be a dictionary or a two-tuple "
f"whose second item is a dictionary or a tensor, but got {inputs} instead."
)
candidate_key = None
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
if candidate_key is None:
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
if candidate_key is None:
raise ValueError(
"Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?"
"If there are no labels in the sample by design, pass labels_getter=None."
)
return inputs[candidate_key]
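
# Illustrative usage (sketch, not part of the original torchvision source),
# mimicking a detection-style (image, target) sample:
#
#     >>> target = {"boxes": torch.rand(2, 4), "labels": torch.tensor([1, 2])}
#     >>> _find_labels_default_heuristic((torch.rand(3, 8, 8), target))
#     tensor([1, 2])
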
def _parse_labels_getter(labels_getter: str | Callable[[Any], Any] | None) -> Callable[[Any], Any]:
if labels_getter == "default":
return _find_labels_default_heuristic
elif callable(labels_getter):
return labels_getter
elif labels_getter is None:
return lambda _: None
else:
raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.")
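
# Illustrative behavior (sketch, not part of the original torchvision source):
#
#     >>> _parse_labels_getter("default") is _find_labels_default_heuristic
#     True
#     >>> _parse_labels_getter(None)(("img", "target")) is None
#     True
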
def get_bounding_boxes(flat_inputs: list[Any]) -> tv_tensors.BoundingBoxes:
"""Return the Bounding Boxes in the input.
Assumes only one ``BoundingBoxes`` object is present.
"""
# This assumes there is only one bbox per sample as per the general convention
try:
return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.BoundingBoxes))
except StopIteration:
raise ValueError("No bounding boxes were found in the sample")
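
# Illustrative usage (sketch, not part of the original torchvision source); the
# BoundingBoxes construction below assumes the current tv_tensors keyword-only API:
#
#     >>> boxes = tv_tensors.BoundingBoxes(
#     ...     torch.tensor([[0, 0, 5, 5]]), format="XYXY", canvas_size=(10, 10)
#     ... )
#     >>> get_bounding_boxes([torch.rand(3, 10, 10), boxes]) is boxes
#     True
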
def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
"""Return Channel, Height, and Width."""
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
}
if not chws:
raise TypeError("No image or video was found in the sample")
elif len(chws) > 1:
raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
c, h, w = chws.pop()
return c, h, w
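
# Illustrative usage (sketch, not part of the original torchvision source):
# all image-like inputs in the sample must agree on their CxHxW dimensions.
#
#     >>> query_chw([tv_tensors.Image(torch.rand(3, 16, 24)), torch.rand(3, 16, 24)])
#     (3, 16, 24)
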
def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
"""Return Height and Width."""
sizes = {
tuple(get_size(inpt))
for inpt in flat_inputs
if check_type(
inpt,
(
is_pure_tensor,
tv_tensors.Image,
PIL.Image.Image,
tv_tensors.Video,
tv_tensors.Mask,
tv_tensors.BoundingBoxes,
),
)
}
if not sizes:
raise TypeError("No image, video, mask or bounding box was found in the sample")
elif len(sizes) > 1:
raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
h, w = sizes.pop()
return h, w
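
# Illustrative usage (sketch, not part of the original torchvision source):
# unlike query_chw, masks and bounding boxes also contribute their HxW size.
#
#     >>> flat = [tv_tensors.Image(torch.rand(3, 16, 24)), tv_tensors.Mask(torch.zeros(16, 24))]
#     >>> query_size(flat)
#     (16, 24)
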
def check_type(obj: Any, types_or_checks: tuple[type | Callable[[Any], bool], ...]) -> bool:
for type_or_check in types_or_checks:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
return True
return False
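
# Illustrative behavior (sketch, not part of the original torchvision source):
# plain types are matched with isinstance, callables are used as predicates.
#
#     >>> check_type(torch.rand(3, 8, 8), (tv_tensors.Image, is_pure_tensor))
#     True
#     >>> check_type(PIL.Image.new("RGB", (4, 4)), (tv_tensors.Image, is_pure_tensor))
#     False
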
def has_any(flat_inputs: list[Any], *types_or_checks: type | Callable[[Any], bool]) -> bool:
for inpt in flat_inputs:
if check_type(inpt, types_or_checks):
return True
return False
def has_all(flat_inputs: list[Any], *types_or_checks: type | Callable[[Any], bool]) -> bool:
for type_or_check in types_or_checks:
for inpt in flat_inputs:
if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt):
break
else:
return False
return True
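
# Illustrative behavior (sketch, not part of the original torchvision source):
#
#     >>> sample = [tv_tensors.Image(torch.rand(3, 8, 8)), tv_tensors.Mask(torch.zeros(8, 8))]
#     >>> has_any(sample, tv_tensors.BoundingBoxes, tv_tensors.Mask)
#     True
#     >>> has_all(sample, tv_tensors.Image, tv_tensors.BoundingBoxes)
#     False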