Shortcuts

Source code for pyvers.implement_for

# Dynamic dispatch implementation based on module versions.
from __future__ import annotations

import collections
import inspect
import logging
import sys
import warnings
from collections.abc import Callable
from copy import copy
from functools import wraps
from importlib import import_module
from typing import Any, TypeVar

from packaging.version import parse

# Type variable bound to callables so that ``implement_for.__call__`` is typed as
# returning the same kind of object it receives.
T = TypeVar("T", bound=Callable)

# Global switch for extra diagnostics: when True, implement_for warns on backend
# conflicts and logs an info message on ``reset``.
VERBOSE = False

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class implement_for:  # noqa: N801
    """A version decorator that checks version compatibility and implements functions.

    If the specified module is missing or there is no fitting implementation, calling
    the decorated function raises an explicit :class:`ModuleNotFoundError`.
    In case of intersecting ranges, the last fitting implementation is used.

    This wrapper also works to implement different backends for the same function
    (e.g. gym vs gymnasium, numpy vs jax-numpy, etc.).

    Args:
        module_name (str or callable): version is checked for the module with this
            name (e.g. "gym"). If a callable is provided, it should return the module.
        from_version: version from which the implementation is compatible
            (inclusive). Can be open (None).
        to_version: version from which the implementation is no longer compatible
            (exclusive). Can be open (None).

    Keyword Args:
        class_method (bool, optional): if ``True``, the function will be written as a
            class method. Defaults to ``False``.
        compilable (bool, optional): If ``False``, the module import happens only on
            the first call to the wrapped function. If ``True``, the module import
            happens when the function is decorated. Defaults to ``False``.

    Examples:
        >>> @implement_for("gym", "0.13", "0.14")
        ... def fun(self, x):
        ...     # Older gym versions will return x + 1
        ...     return x + 1
        ...
        >>> @implement_for("gym", "0.14", "0.23")
        ... def fun(self, x):
        ...     # More recent gym versions will return x + 2
        ...     return x + 2
        ...
        >>> @implement_for(lambda: import_module("gym"), "0.23", None)
        ... def fun(self, x):
        ...     # More recent gym versions will return x + 2
        ...     return x + 2
        ...
        >>> @implement_for("gymnasium", None, "1.0.0")
        ... def fun(self, x):
        ...     # If gymnasium is to be used instead of gym, x+3 will be returned
        ...     return x + 3
        ...

        The first implementation covers gym versions in ``[0.13, 0.14)``: it is
        compatible with gym 0.13+, but not with gym 0.14+.
    """

    # Stores pointers to fitting implementations: dict[func_name] = setter instance.
    _implementations: dict[str, implement_for] = {}
    # Every implement_for instance created, in creation order.
    _setters: list[implement_for] = []
    # Modules imported by ``import_module`` when they were not already in sys.modules.
    _cache_modules: dict[str, Any] = {}
    # Qualified function name -> ``_call`` bound methods of every candidate
    # implementation registered under that name, in decoration order.
    _lazy_impl: dict[str, list[Callable[[], Callable]]] = collections.defaultdict(list)

    def __init__(
        self,
        module_name: str | Callable[[], Any],
        from_version: str | None = None,
        to_version: str | None = None,
        *,
        class_method: bool = False,
        compilable: bool = False,
    ):
        self.module_name = module_name
        self.from_version = from_version
        self.to_version = to_version
        self.class_method = class_method
        self._compilable = compilable
        self.fn: Callable | None = None
        self.func_name: str | None = None
        self.do_set: bool = False
        implement_for._setters.append(self)

    @staticmethod
    def check_version(
        version: str, from_version: str | None, to_version: str | None
    ) -> bool:
        """Return whether ``version`` lies in ``[from_version, to_version)``.

        Only the release segment of ``version`` is compared: pre-release, post-release,
        dev and local labels are stripped first (``"2.1.0rc1"`` compares as ``"2.1.0"``).
        A ``None`` bound is open.
        """
        version = parse(".".join([str(v) for v in parse(version).release]))
        return (from_version is None or version >= parse(from_version)) and (
            to_version is None or version < parse(to_version)
        )

    @staticmethod
    def get_class_that_defined_method(f: Callable) -> Any | None:
        """Returns the class of a method, if it is defined, and None otherwise.

        The lookup is the first component of ``f.__qualname__`` in ``f.__globals__``;
        for a top-level function this may return a previously defined function of the
        same name rather than a class (callers handle that case).
        """
        out = f.__globals__.get(f.__qualname__.split(".")[0], None)
        return out

    @classmethod
    def get_func_name(cls, fn: Callable) -> str:
        """Return a unique name like ``module.Class.method`` or ``module.function``.

        Bound methods are resolved through their underlying function, so the
        instance repr (which may itself contain dots) never leaks into the name.

        Raises:
            RuntimeError: if ``fn`` is neither a Python function nor a bound method.
        """
        if inspect.ismethod(fn):
            target = fn.__func__
        elif inspect.isfunction(fn):
            target = fn
        else:
            raise RuntimeError(f"Unknown func representation {fn}")
        # ``__qualname__`` is exactly what a function's repr embeds, so names for
        # plain functions are unchanged from the previous repr-parsing approach.
        return ".".join([fn.__module__, target.__qualname__])

    def _get_cls(self, fn: Callable) -> Any | None:
        """Return the object ``fn`` should be installed on, or None if not yet known.

        Returns the defining class for methods, the defining module when the global
        lookup finds a plain function, and None when the class is not defined yet.
        """
        cls = self.get_class_that_defined_method(fn)
        if cls is None:
            # class not yet defined
            return None
        if cls.__class__.__name__ == "function":
            cls = inspect.getmodule(fn)
        return cls

    def module_set(self) -> None:
        """Sets the function in its module, if it exists already.

        Registers ``self`` as the current implementation (un-marking the previous one),
        then replaces the attribute on the owning class/module with ``self.fn``
        (wrapped in ``classmethod`` when ``class_method`` is set).
        """
        if self.fn is None:
            return
        func_name = self.get_func_name(self.fn)
        registry = type(self)._implementations
        prev_setter = registry.get(func_name, None)
        if prev_setter is not None:
            prev_setter.do_set = False
        registry[func_name] = self
        cls = self._get_cls(self.fn)
        if cls is None:
            # class not yet defined
            return
        name = self.fn.__name__
        try:
            delattr(cls, name)
        except AttributeError:
            pass
        fn = classmethod(self.fn) if self.class_method else self.fn
        setattr(cls, name, fn)

    @classmethod
    def import_module(cls, module_name: str | Callable[[], Any]) -> str:
        """Imports module and returns its version.

        Args:
            module_name: a module name, or a zero-argument callable returning the module.

        Returns:
            The module's ``__version__`` attribute.

        Raises:
            ModuleNotFoundError: if the named module cannot be imported.
        """
        if not callable(module_name):
            module = cls._cache_modules.get(module_name, None)
            if module is None:
                if module_name in sys.modules:
                    # Already imported elsewhere: re-read it from sys.modules each time
                    # rather than caching our own reference.
                    sys.modules[module_name] = module = import_module(module_name)
                else:
                    cls._cache_modules[module_name] = module = import_module(
                        module_name
                    )
        else:
            module = module_name()
        return module.__version__

    def _delazify(self, func_name: str) -> Callable | None:
        """Run every candidate registered under ``func_name`` and return the last result.

        Candidates run in decoration order; a fitting one installs itself via
        :meth:`module_set`, and a non-fitting one running afterwards returns the
        already-registered implementation, so the last fitting one wins.
        """
        out = None
        for local_call in implement_for._lazy_impl[func_name]:
            out = local_call()
        return out

    def __call__(self, fn: T) -> T:
        """Register ``fn`` as a candidate implementation and return its dispatcher."""
        # function names are unique
        self.func_name = self.get_func_name(fn)
        self.fn = fn
        implement_for._lazy_impl[self.func_name].append(self._call)
        if self._compilable:
            # Resolve now, so imports happen at decoration time instead of on call.
            _call_fn = self._delazify(self.func_name)
            if self.class_method and _call_fn is not None:
                return classmethod(_call_fn)  # type: ignore
            return _call_fn if _call_fn is not None else fn

        @wraps(fn)
        def _lazy_call_fn(*args: Any, **kwargs: Any) -> Any:
            # first time we call the function, we also do the replacement.
            # This will cause the imports to occur only during the first call to fn
            result = self._delazify(self.func_name)
            if result is not None:
                return result(*args, **kwargs)
            return fn(*args, **kwargs)

        if self.class_method:
            return classmethod(_lazy_call_fn)  # type: ignore
        return _lazy_call_fn  # type: ignore

    def _check_backend_conflict(self, version: str, func_name: str) -> bool:
        """Return whether ``version`` fits this candidate's range.

        A ``True`` result means this candidate overrides an already-registered
        implementation; when ``VERBOSE`` is set, a warning reports it.
        """
        if self.check_version(version, self.from_version, self.to_version):
            if VERBOSE:
                module = (
                    import_module(self.module_name)
                    if isinstance(self.module_name, str)
                    else self.module_name()
                )
                msg = (
                    f"Got multiple backends for {func_name}. "
                    f"Using last queried ({module}, version {version})."
                )
                warnings.warn(msg, stacklevel=2)
            return True
        return False

    def _handle_existing_implementation(
        self, func_name: str, implementations: dict[str, implement_for]
    ) -> Callable | None:
        """Handle the case where an implementation already exists.

        Returns the registered implementation's function to keep using it, or None
        when this candidate fits and should replace it (``self.do_set`` is then True).
        """
        try:
            version = self.import_module(self.module_name)
            if self._check_backend_conflict(version, func_name):
                self.do_set = True
            if not self.do_set:
                return implementations[func_name].fn
        except ModuleNotFoundError:
            # Our module is absent, so there is no conflict: keep the existing one.
            return implementations[func_name].fn
        return None

    def _handle_new_implementation(self) -> bool:
        """Return whether this first candidate for its name fits the installed module."""
        try:
            version = self.import_module(self.module_name)
            return self.check_version(version, self.from_version, self.to_version)
        except ModuleNotFoundError:
            return False

    def _call(self) -> Callable:
        """Select the implementation to use for this candidate's function name.

        Returns ``self.fn`` (after installing it with :meth:`module_set`) when its
        module version fits; otherwise the previously registered implementation if any;
        otherwise a stub that raises ``ModuleNotFoundError`` when called.

        Raises:
            RuntimeError: if the candidate was never applied to a function.
        """
        if self.fn is None:
            raise RuntimeError("Function not set")
        fn = self.fn
        func_name = self.func_name
        if func_name is None:
            raise RuntimeError("Function name not set")
        implementations = implement_for._implementations

        @wraps(fn)
        def unsupported(*args: Any, **kwargs: Any) -> Any:
            raise ModuleNotFoundError(
                f"Supported version of '{func_name}' has not been found."
            )

        self.do_set = False
        if func_name in implementations:
            result = self._handle_existing_implementation(func_name, implementations)
            if result is not None:
                return result
        else:
            self.do_set = self._handle_new_implementation()
            if not self.do_set:
                return unsupported
        if self.do_set:
            self.module_set()
            return fn
        return unsupported

    @classmethod
    def reset(cls, setters_dict: dict[str, implement_for] | None = None) -> None:
        """Resets the setters in setter_dict.

        Args:
            setters_dict: A copy of implementations. We iterate through its values and
                call :meth:`module_set` for each. Defaults to a shallow copy of the
                current registry.
        """
        if VERBOSE:
            logger.info("resetting implement_for")
        if setters_dict is None:
            setters_dict = copy(cls._implementations)
        for setter in setters_dict.values():
            setter.module_set()

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(module_name={self.module_name}, "
            f"from_version={self.from_version}, to_version={self.to_version})"
        )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources