Shortcuts

Source code for pyvers.implement_for

# Dynamic dispatch implementation based on module versions.
from __future__ import annotations

import collections
import inspect
import logging
import re
import sys
import warnings
from collections.abc import Callable
from copy import copy
from functools import partial, update_wrapper, wraps
from importlib import import_module
from typing import TYPE_CHECKING, Any, TypeVar


def _parse_version(version: str) -> tuple[int, ...]:
    """Parse a version string into a tuple of integers for comparison.

    Handles PEP 440 versions by extracting the release segment (numeric parts).
    Examples: "1.2.3" -> (1, 2, 3), "2.0.0rc1" -> (2, 0, 0), "1.0a1" -> (1, 0)
    """
    # Extract only the numeric release parts, stopping at first non-numeric segment
    parts = re.split(r"[^0-9]+", version)
    # Filter out empty strings and convert to integers
    return tuple(int(p) for p in parts if p)

if TYPE_CHECKING:
    from typing import Self

# Module-level logger; used by ``implement_for.reset`` when VERBOSE is enabled.
logger = logging.getLogger(__name__)

# Global flag for verbose output. When True, ``implement_for`` emits a warning
# when a later backend replaces an existing implementation of the same
# function, and logs an info message on ``implement_for.reset``.
VERBOSE = False

# Type variable standing for the decorated callable in decorator signatures.
T = TypeVar("T", bound=Callable)


class _RegisterableFunction:
    """Wrapper that provides .register() for version-specific implementations.

    This class wraps a function decorated with @implement_for and provides a
    .register() method similar to functools.singledispatch, allowing users to
    register additional implementations for different version ranges without
    triggering linter warnings about function redefinition.

    Example:
        >>> @implement_for("numpy")
        ... def process_array(arr):
        ...     raise NotImplementedError("No matching implementation")
        ...
        >>> @process_array.register(from_version=None, to_version="2.0.0")
        ... def _(arr):
        ...     # numpy < 2.0 implementation
        ...     return arr * 2
        ...
        >>> @process_array.register(from_version="2.0.0")
        ... def _(arr):
        ...     # numpy >= 2.0 implementation
        ...     return arr * 3
    """

    def __init__(self, fn: Callable, implement_for_instance: implement_for) -> None:
        self._impl = implement_for_instance
        self._fn = fn
        update_wrapper(self, fn)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self._fn(*args, **kwargs)

    def __get__(self, obj: Any, objtype: type | None = None) -> Any:
        """Implement the descriptor protocol to bind self for instance methods."""
        if obj is None:
            return self
        # Return a bound method-like callable that passes obj as first argument
        return partial(self, obj)

    def register(
        self, from_version: str | None = None, to_version: str | None = None
    ) -> Callable[[T], Self]:
        """Register an implementation for a specific version range.

        This method provides a singledispatch-style API for registering
        version-specific implementations. Use ``_`` as the function name
        to avoid linter warnings about redefinition.

        Args:
            from_version: Version from which this implementation is compatible.
                Can be None for open lower bound.
            to_version: Version from which this implementation is no longer
                compatible. Can be None for open upper bound.

        Returns:
            A decorator that registers the implementation and returns self.

        Example:
            >>> @my_function.register(from_version="1.0.0", to_version="2.0.0")
            ... def _(x):
            ...     return x + 1
        """

        def decorator(impl_fn: T) -> Self:
            setter = implement_for(
                self._impl.module_name,
                from_version,
                to_version,
                class_method=self._impl.class_method,
                compilable=self._impl._compilable,
            )
            # Use the original function name for registration
            setter.func_name = self._impl.func_name
            setter.fn = impl_fn
            implement_for._lazy_impl[self._impl.func_name].append(setter._call)
            return self

        return decorator

    def __repr__(self) -> str:
        return f"<RegisterableFunction {self._impl.func_name}>"


class implement_for:  # noqa: N801
    """A version decorator that checks version compatibility and implements functions.

    If specified module is missing or there is no fitting implementation, call
    of the decorated function will lead to the explicit error.
    In case of intersected ranges, last fitting implementation is used.

    This wrapper also works to implement different backends for a same function
    (eg. gym vs gymnasium, numpy vs jax-numpy etc).

    Args:
        module_name (str or callable): version is checked for the module with
            this name (e.g. "gym"). If a callable is provided, it should return
            the module.
        from_version: version from which implementation is compatible. Can be
            open (None).
        to_version: version from which implementation is no longer compatible.
            Can be open (None).

    Keyword Args:
        class_method (bool, optional): if ``True``, the function will be
            written as a class method. Defaults to ``False``.
        compilable (bool, optional): If ``False``, the module import happens
            only on the first call to the wrapped function. If ``True``, the
            module import happens when the wrapped function is initialized.
            Defaults to ``False``.

    Examples:
        Traditional API (requires ``# noqa: F811`` on redefinitions):

        >>> @implement_for("gym", "0.13", "0.14")
        ... def fun(self, x):
        ...     # Older gym versions will return x + 1
        ...     return x + 1
        ...
        >>> @implement_for("gym", "0.14", "0.23")
        ... def fun(self, x):  # noqa: F811
        ...     # More recent gym versions will return x + 2
        ...     return x + 2

        This indicates that the function is compatible with gym 0.13+,
        but doesn't with gym 0.14+.

        Register API (recommended, no ``# noqa`` needed):

        The decorated function has a ``.register()`` method similar to
        ``functools.singledispatch``. Use ``_`` as the function name for
        registered implementations to avoid linter warnings:

        >>> @implement_for("numpy")
        ... def process_array(arr):
        ...     '''Process array with version-specific implementation.'''
        ...     raise NotImplementedError("No matching implementation")
        ...
        >>> @process_array.register(from_version=None, to_version="2.0.0")
        ... def _(arr):
        ...     # numpy < 2.0 implementation
        ...     return arr * 2
        ...
        >>> @process_array.register(from_version="2.0.0")
        ... def _(arr):
        ...     # numpy >= 2.0 implementation
        ...     return arr * 3
    """

    # Stores pointers to fitting implementations: dict[func_name] = func_pointer
    _implementations: dict[str, implement_for] = {}
    # Every instance ever created, in creation order (appended in __init__).
    _setters: list[implement_for] = []
    # Modules imported by name through ``import_module`` (only those that were
    # not already present in ``sys.modules`` at import time).
    _cache_modules: dict[str, Any] = {}
    # Per function name, the ``_call`` resolvers of all registered candidates,
    # in registration order.
    _lazy_impl: collections.defaultdict[str, list[Callable[[], Callable]]] = (
        collections.defaultdict(list)
    )

    def __init__(
        self,
        module_name: str | Callable[[], Any],
        from_version: str | None = None,
        to_version: str | None = None,
        *,
        class_method: bool = False,
        compilable: bool = False,
    ):
        self.module_name = module_name
        self.from_version = from_version
        self.to_version = to_version
        self.class_method = class_method
        self._compilable = compilable
        # Filled in when the instance is applied to a function.
        self.fn: Callable | None = None
        self.func_name: str | None = None
        # True while this instance is the one selected by the last ``_call``.
        self.do_set: bool = False
        implement_for._setters.append(self)

    @staticmethod
    def check_version(
        version: str, from_version: str | None, to_version: str | None
    ) -> bool:
        """Tell whether ``version`` lies in ``[from_version, to_version)``.

        Release tuples are zero-padded to a common length before comparison so
        that, e.g., ``"2.0"`` and ``"2.0.0"`` compare equal instead of the
        shorter tuple sorting first.

        Args:
            version: Version of the installed module.
            from_version: Inclusive lower bound, or None for no lower bound.
            to_version: Exclusive upper bound, or None for no upper bound.

        Returns:
            True if the version satisfies both bounds.
        """
        current = _parse_version(version)

        def _compare(bound: str) -> int:
            # -1 / 0 / 1 when ``current`` is below / equal to / above ``bound``.
            other = _parse_version(bound)
            width = max(len(current), len(other))
            lhs = current + (0,) * (width - len(current))
            rhs = other + (0,) * (width - len(other))
            return (lhs > rhs) - (lhs < rhs)

        if from_version is not None and _compare(from_version) < 0:
            return False
        return to_version is None or _compare(to_version) < 0

    @staticmethod
    def get_class_that_defined_method(f: Callable) -> Any | None:
        """Returns the class of a method, if it is defined, and None otherwise.

        The lookup takes the first component of ``f.__qualname__`` and fetches
        that name from ``f.__globals__``; for a module-level function this is
        whatever object is currently bound to the function's own name.
        """
        top_level_name = f.__qualname__.split(".")[0]
        return f.__globals__.get(top_level_name, None)

    @classmethod
    def get_func_name(cls, fn: Callable) -> str:
        """Return the dotted name used as the registry key for ``fn``.

        Produces a name like ``module.Class.method`` or ``module.function``,
        derived from ``str(fn)`` and ``fn.__module__``.

        Raises:
            RuntimeError: if ``str(fn)`` does not start with ``<function `` or
                ``<bound method ``.
        """
        # Unwrap _RegisterableFunction to get the underlying function
        if isinstance(fn, _RegisterableFunction):
            # Use the stored func_name from the implement_for instance
            return fn._impl.func_name
        pieces = str(fn).split(".")
        head = pieces[0]
        if head.startswith("<bound method "):
            head = head[len("<bound method ") :]
        elif head.startswith("<function "):
            head = head[len("<function ") :]
        else:
            raise RuntimeError(f"Unknown func representation {fn}")
        tail = pieces[1:]
        if tail:
            # Drop the trailing " at 0x...>" part of the repr.
            tail[-1] = tail[-1].split(" ")[0]
            qualified = [head] + tail
        else:
            qualified = [head.split(" ")[0]]
        return ".".join([fn.__module__] + qualified)

    def _get_cls(self, fn: Callable) -> Any | None:
        """Return the class or module that owns ``fn``, or None if not yet defined.

        NOTE(review): not referenced elsewhere in the visible source.
        """
        cls = self.get_class_that_defined_method(fn)
        if cls is None:
            # class not yet defined
            return None
        if cls.__class__.__name__ == "function":
            cls = inspect.getmodule(fn)
        return cls

    def module_set(self) -> None:
        """Sets the function in its module, if it exists already.

        Records this instance as the current implementation for its function
        name (clearing ``do_set`` on the previously recorded one), then rebinds
        the function's name on the owning class or module to ``self.fn``
        (wrapped in ``classmethod`` when ``class_method`` is set).
        """
        if self.fn is None:
            return
        # Use self.func_name if set (for registered implementations),
        # otherwise compute from fn
        func_name = self.func_name or self.get_func_name(self.fn)
        registry = type(self)._implementations
        prev_setter = registry.get(func_name, None)
        if prev_setter is not None:
            prev_setter.do_set = False
        registry[func_name] = self

        owner = self.get_class_that_defined_method(self.fn)
        if owner is None:
            # class not yet defined
            return
        # If owner is not a class (it's a function, _RegisterableFunction, or
        # other callable), use the module instead
        if not isinstance(owner, type):
            owner = inspect.getmodule(self.fn)

        name = self.fn.__name__
        try:
            delattr(owner, name)
        except AttributeError:
            # Nothing bound under that name on the owner yet; nothing to remove.
            pass
        setattr(owner, name, classmethod(self.fn) if self.class_method else self.fn)

    @classmethod
    def import_module(cls, module_name: str | Callable[[], Any]) -> str:
        """Imports module and returns its version.

        Args:
            module_name: a module name, or a callable returning the module.

        Returns:
            The module's ``__version__`` attribute.
        """
        if callable(module_name):
            module = module_name()
        else:
            module = cls._cache_modules.get(module_name, None)
            if module is None:
                if module_name in sys.modules:
                    # Already imported elsewhere: do not add it to our cache.
                    sys.modules[module_name] = module = import_module(module_name)
                else:
                    cls._cache_modules[module_name] = module = import_module(
                        module_name
                    )
        return module.__version__

    def _delazify(self, func_name: str) -> Callable | None:
        """Run every resolver registered for ``func_name``; return the last result.

        Returns None when no resolver is registered under that name.
        """
        out = None
        for local_call in implement_for._lazy_impl[func_name]:
            out = local_call()
        return out

    def __call__(self, fn: T) -> T | _RegisterableFunction:
        """Decorate ``fn`` and register it as a candidate implementation.

        With ``compilable=True`` the candidates are resolved immediately;
        otherwise resolution is deferred to call time of the returned wrapper.
        When ``class_method`` is set a ``classmethod`` object is returned
        instead of a ``_RegisterableFunction``.
        """
        # function names are unique
        self.func_name = self.get_func_name(fn)
        self.fn = fn
        implement_for._lazy_impl[self.func_name].append(self._call)

        if self._compilable:
            _call_fn = self._delazify(self.func_name)

            if self.class_method and _call_fn is not None:
                return classmethod(_call_fn)  # type: ignore

            result_fn = _call_fn if _call_fn is not None else fn
            return _RegisterableFunction(result_fn, self)

        @wraps(fn)
        def _lazy_call_fn(*args: Any, **kwargs: Any) -> Any:
            # first time we call the function, we also do the replacement.
            # This will cause the imports to occur only during the first call to fn
            result = self._delazify(self.func_name)
            if result is not None:
                return result(*args, **kwargs)
            return fn(*args, **kwargs)

        if self.class_method:
            return classmethod(_lazy_call_fn)  # type: ignore

        return _RegisterableFunction(_lazy_call_fn, self)

    def _check_backend_conflict(self, version: str, func_name: str) -> bool:
        """Check if there's a backend conflict and handle it.

        Returns True when ``version`` fits this instance's range, i.e. this
        instance should replace the already-registered implementation. When
        ``VERBOSE`` is enabled a warning naming the winning backend is emitted.
        """
        if not self.check_version(version, self.from_version, self.to_version):
            return False
        if VERBOSE:
            module = (
                import_module(self.module_name)
                if isinstance(self.module_name, str)
                else self.module_name()
            )
            msg = (
                f"Got multiple backends for {func_name}. "
                f"Using last queried ({module}, version {version})."
            )
            warnings.warn(msg, stacklevel=2)
        return True

    def _handle_existing_implementation(
        self, func_name: str, implementations: dict[str, implement_for]
    ) -> Callable | None:
        """Handle the case where an implementation already exists.

        Returns the already-registered function when this instance's module is
        missing or its version does not fit; returns None (after setting
        ``do_set``) when this instance should take over.
        """
        try:
            version = self.import_module(self.module_name)
            if self._check_backend_conflict(version, func_name):
                self.do_set = True
            if not self.do_set:
                return implementations[func_name].fn
        except ModuleNotFoundError:
            return implementations[func_name].fn
        return None

    def _handle_new_implementation(self) -> bool:
        """Handle the case where this is a new implementation.

        Returns True when the module can be imported and its version fits.
        """
        try:
            version = self.import_module(self.module_name)
            return self.check_version(version, self.from_version, self.to_version)
        except ModuleNotFoundError:
            return False

    def _call(self) -> Callable:
        """Handle the function call and return appropriate implementation.

        Returns this instance's function when it is selected, the previously
        registered function when it is kept, or a stub raising
        ``ModuleNotFoundError`` when nothing fits.

        Raises:
            RuntimeError: if the instance has not been applied to a function.
        """
        if self.fn is None:
            raise RuntimeError("Function not set")
        fn = self.fn
        func_name = self.func_name
        if func_name is None:
            raise RuntimeError("Function name not set")
        implementations = implement_for._implementations

        @wraps(fn)
        def unsupported(*args: Any, **kwargs: Any) -> Any:
            raise ModuleNotFoundError(
                f"Supported version of '{func_name}' has not been found."
            )

        self.do_set = False
        if func_name in implementations:
            result = self._handle_existing_implementation(func_name, implementations)
            if result is not None:
                return result
        else:
            self.do_set = self._handle_new_implementation()

        if self.do_set:
            self.module_set()
            return fn
        return unsupported

    @classmethod
    def reset(cls, setters_dict: dict[str, implement_for] | None = None) -> None:
        """Resets the setters in setter_dict.

        Args:
            setters_dict: A copy of implementations. We iterate through its
                values and call :meth:`module_set` for each. Defaults to a
                copy of the current implementations.
        """
        if VERBOSE:
            logger.info("resetting implement_for")
        if setters_dict is None:
            setters_dict = copy(cls._implementations)
        for setter in setters_dict.values():
            setter.module_set()

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(module_name={self.module_name}, "
            f"from_version={self.from_version}, to_version={self.to_version})"
        )

Docs

Lorem ipsum dolor sit amet, consectetur

View Docs

Tutorials

Lorem ipsum dolor sit amet, consectetur

View Tutorials

Resources

Lorem ipsum dolor sit amet, consectetur

View Resources