Shortcuts

Source code for torchvision.models.convnext

from collections.abc import Sequence
from functools import partial
from typing import Any, Callable, Optional

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from ..ops.misc import Conv2dNormActivation, Permute
from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


# Names exported by this module: the model class, one weights enum per
# variant, and one builder function per variant.
__all__ = [
    "ConvNeXt",
    "ConvNeXt_Tiny_Weights",
    "ConvNeXt_Small_Weights",
    "ConvNeXt_Base_Weights",
    "ConvNeXt_Large_Weights",
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    "convnext_large",
]


class LayerNorm2d(nn.LayerNorm):
    """Layer normalization for 4-D inputs whose normalized axis is dim 1.

    The input is permuted so that dim 1 becomes the last dim, normalized with
    this module's ``normalized_shape``/``weight``/``bias``/``eps``, and then
    permuted back to the original layout.
    """

    def forward(self, x: Tensor) -> Tensor:
        # (d0, d1, d2, d3) -> (d0, d2, d3, d1): layer_norm works on trailing dims.
        trailing = x.permute(0, 2, 3, 1)
        normalized = F.layer_norm(trailing, self.normalized_shape, self.weight, self.bias, self.eps)
        # Restore the caller's layout.
        return normalized.permute(0, 3, 1, 2)


class CNBlock(nn.Module):
    """Residual block: conv -> norm -> linear -> GELU -> linear, scaled and added to the input.

    Args:
        dim: Number of input (and output) channels of the block.
        layer_scale: Initial value of the learnable per-channel scale applied
            to the block output before the residual addition.
        stochastic_depth_prob: Probability passed to ``StochasticDepth``
            (constructed in ``"row"`` mode).
        norm_layer: Factory called as ``norm_layer(dim)``. Defaults to
            ``nn.LayerNorm`` with ``eps=1e-6``.
    """

    def __init__(
        self,
        dim,
        layer_scale: float,
        stochastic_depth_prob: float,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        norm_factory = norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
        hidden_dim = 4 * dim

        stages = [
            # groups=dim makes this a per-channel (depthwise) 7x7 convolution.
            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
            # Move dim 1 last so the norm and linear layers act on it.
            Permute([0, 2, 3, 1]),
            norm_factory(dim),
            nn.Linear(in_features=dim, out_features=hidden_dim, bias=True),
            nn.GELU(),
            nn.Linear(in_features=hidden_dim, out_features=dim, bias=True),
            # Undo the earlier permutation.
            Permute([0, 3, 1, 2]),
        ]
        self.block = nn.Sequential(*stages)
        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, input: Tensor) -> Tensor:
        scaled = self.layer_scale * self.block(input)
        result = self.stochastic_depth(scaled)
        # Residual connection.
        result += input
        return result


class CNBlockConfig:
    """Per-stage settings (see Section 3 of the ConvNeXt paper).

    Attributes are stored exactly as given:
        input_channels: channel count used by the stage's blocks.
        out_channels: channel count after the stage, or ``None``.
        num_layers: number of blocks in the stage.
    """

    def __init__(
        self,
        input_channels: int,
        out_channels: Optional[int],
        num_layers: int,
    ) -> None:
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

    def __repr__(self) -> str:
        field_names = ("input_channels", "out_channels", "num_layers")
        rendered = ", ".join(f"{name}={getattr(self, name)}" for name in field_names)
        return f"{self.__class__.__name__}({rendered})"


class ConvNeXt(nn.Module):
    """ConvNeXt classifier assembled from a list of per-stage configurations.

    The network is: a 4x4 stride-4 stem convolution with normalization; for
    each entry of ``block_setting`` a sequence of ``num_layers`` blocks,
    followed (when ``out_channels`` is not ``None``) by a normalization and a
    2x2 stride-2 convolution; adaptive average pooling to 1x1; and a
    classifier made of normalization, flatten and a linear layer.

    Args:
        block_setting: Non-empty sequence of ``CNBlockConfig``.
        stochastic_depth_prob: Stochastic-depth probability of the last block;
            earlier blocks get a linearly smaller value, starting at 0.
        layer_scale: Initial layer-scale value passed to every block.
        num_classes: Output size of the final linear layer.
        block: Factory called as ``block(channels, layer_scale, sd_prob)``.
            Defaults to ``CNBlock``.
        norm_layer: Factory called as ``norm_layer(channels)``. Defaults to
            ``LayerNorm2d`` with ``eps=1e-6``.
        **kwargs: Accepted for interface compatibility; not used here.

    Raises:
        ValueError: If ``block_setting`` is empty.
        TypeError: If ``block_setting`` is not a sequence of ``CNBlockConfig``.
    """

    def __init__(
        self,
        block_setting: list[CNBlockConfig],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (isinstance(block_setting, Sequence) and all(isinstance(s, CNBlockConfig) for s in block_setting)):
            raise TypeError("The block_setting should be List[CNBlockConfig]")

        if block is None:
            block = CNBlock

        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)

        layers: list[nn.Module] = []

        # Stem
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,
            )
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        stage_block_id = 0
        for cnf in block_setting:
            # Bottlenecks
            stage: list[nn.Module] = []
            for _ in range(cnf.num_layers):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = self._stochastic_depth_prob(stochastic_depth_prob, stage_block_id, total_stage_blocks)
                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                # Downsampling
                layers.append(
                    nn.Sequential(
                        norm_layer(cnf.input_channels),
                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
                    )
                )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # The classifier width is the channel count leaving the last stage.
        lastblock = block_setting[-1]
        lastconv_output_channels = (
            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
        )
        self.classifier = nn.Sequential(
            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
        )

        # Truncated-normal weights and zero biases for every conv/linear layer.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _stochastic_depth_prob(max_prob: float, block_index: int, total_blocks: int) -> float:
        """Return the linearly scaled stochastic-depth probability of one block.

        Block 0 gets 0 and block ``total_blocks - 1`` gets ``max_prob``. The
        denominator is clamped to 1 so a network with a single block yields 0
        instead of raising ``ZeroDivisionError``.
        """
        return max_prob * block_index / max(total_blocks - 1.0, 1.0)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Run feature extraction, pooling and classification on ``x``."""
        return self._forward_impl(x)


def _convnext(
    block_setting: list[CNBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ConvNeXt:
    """Build a ``ConvNeXt`` and, when ``weights`` is given, load its state dict.

    When ``weights`` is provided, ``num_classes`` in ``kwargs`` is set from the
    number of categories in the weights' metadata before the model is built.
    """
    if weights is None:
        return ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

    category_count = len(weights.meta["categories"])
    _ovewrite_named_param(kwargs, "num_classes", category_count)

    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
    state_dict = weights.get_state_dict(progress=progress, check_hash=True)
    model.load_state_dict(state_dict)
    return model


# Metadata entries common to the ConvNeXt weights: minimum input size,
# category names, a link to the training recipe, and a docs blurb.
_COMMON_META = {
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
    "_docs": """
        These weights improve upon the results of the original paper by using a modified version of TorchVision's
        `new training recipe
        <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
    """,
}


class ConvNeXt_Tiny_Weights(WeightsEnum):
    # Pretrained checkpoint for the Tiny variant, with its preprocessing
    # preset and recorded metadata.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=236),
        meta={
            **_COMMON_META,
            "num_params": 28589128,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.520,
                    "acc@5": 96.146,
                }
            },
            "_ops": 4.456,
            "_file_size": 109.119,
        },
    )
    # Alias used when callers ask for the default weights.
    DEFAULT = IMAGENET1K_V1
class ConvNeXt_Small_Weights(WeightsEnum):
    # Pretrained checkpoint for the Small variant, with its preprocessing
    # preset and recorded metadata.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=230),
        meta={
            **_COMMON_META,
            "num_params": 50223688,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.616,
                    "acc@5": 96.650,
                }
            },
            "_ops": 8.684,
            "_file_size": 191.703,
        },
    )
    # Alias used when callers ask for the default weights.
    DEFAULT = IMAGENET1K_V1
class ConvNeXt_Base_Weights(WeightsEnum):
    # Pretrained checkpoint for the Base variant, with its preprocessing
    # preset and recorded metadata.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88591464,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.062,
                    "acc@5": 96.870,
                }
            },
            "_ops": 15.355,
            "_file_size": 338.064,
        },
    )
    # Alias used when callers ask for the default weights.
    DEFAULT = IMAGENET1K_V1
class ConvNeXt_Large_Weights(WeightsEnum):
    # Pretrained checkpoint for the Large variant, with its preprocessing
    # preset and recorded metadata.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 197767336,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.414,
                    "acc@5": 96.976,
                }
            },
            "_ops": 34.361,
            "_file_size": 754.537,
        },
    )
    # Alias used when callers ask for the default weights.
    DEFAULT = IMAGENET1K_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Tiny model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Tiny_Weights
        :members:
    """
    weights = ConvNeXt_Tiny_Weights.verify(weights)

    # A caller-supplied ``stochastic_depth_prob`` overrides the variant default.
    sd_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    # (input_channels, out_channels, num_layers) for each stage.
    stage_specs = [(96, 192, 3), (192, 384, 3), (384, 768, 9), (768, None, 3)]
    block_setting = [CNBlockConfig(*spec) for spec in stage_specs]
    return _convnext(block_setting, sd_prob, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
def convnext_small(
    *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt Small model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Small_Weights
        :members:
    """
    weights = ConvNeXt_Small_Weights.verify(weights)

    # A caller-supplied ``stochastic_depth_prob`` overrides the variant default.
    sd_prob = kwargs.pop("stochastic_depth_prob", 0.4)
    # (input_channels, out_channels, num_layers) for each stage.
    stage_specs = [(96, 192, 3), (192, 384, 3), (384, 768, 27), (768, None, 3)]
    block_setting = [CNBlockConfig(*spec) for spec in stage_specs]
    return _convnext(block_setting, sd_prob, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Base model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Base_Weights
        :members:
    """
    weights = ConvNeXt_Base_Weights.verify(weights)

    # A caller-supplied ``stochastic_depth_prob`` overrides the variant default.
    sd_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    # (input_channels, out_channels, num_layers) for each stage.
    stage_specs = [(128, 256, 3), (256, 512, 3), (512, 1024, 27), (1024, None, 3)]
    block_setting = [CNBlockConfig(*spec) for spec in stage_specs]
    return _convnext(block_setting, sd_prob, weights, progress, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def convnext_large(
    *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt Large model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Large_Weights
        :members:
    """
    weights = ConvNeXt_Large_Weights.verify(weights)

    # A caller-supplied ``stochastic_depth_prob`` overrides the variant default.
    sd_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    # (input_channels, out_channels, num_layers) for each stage.
    stage_specs = [(192, 384, 3), (384, 768, 3), (768, 1536, 27), (1536, None, 3)]
    block_setting = [CNBlockConfig(*spec) for spec in stage_specs]
    return _convnext(block_setting, sd_prob, weights, progress, **kwargs)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources