implement_for
- class torchrl.implement_for(module_name: str | Callable[[], Any], from_version: str | None = None, to_version: str | None = None, *, class_method: bool = False, compilable: bool = False)
A version decorator that checks the installed version of a module and selects the implementation registered for the matching version range.
If the specified module is missing or no implementation fits its version, calling the decorated function raises an explicit error. If ranges overlap, the last fitting implementation is used.
This wrapper can also be used to implement different backends for the same function (e.g. gym vs. gymnasium, numpy vs. jax-numpy).
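To make these dispatch rules concrete, here is a minimal, self-contained sketch of the same idea. It is not torchrl's implementation: `version_dispatch`, `_REGISTRY`, `_vtuple` and `_installed_version` are hypothetical names, versions are assumed to be purely numeric (no `rc` or `post` tags), the module version is assumed to live in `__version__`, and the implementation is re-resolved on every call for simplicity. A registration whose module cannot be imported simply does not fit, the last fitting registration wins, and an error is raised when nothing fits.

```python
import functools
import importlib
import sys
import types
from typing import Any, Callable, Optional, Union

ModuleSpec = Union[str, Callable[[], Any]]

# Hypothetical registry: function name -> [(module, from_version, to_version, impl), ...]
_REGISTRY: dict[str, list[tuple[ModuleSpec, Optional[str], Optional[str], Callable]]] = {}


def _vtuple(v: str, width: int = 4) -> tuple[int, ...]:
    # Simplification: numeric dotted versions only, zero-padded so "1.0" == "1.0.0".
    parts = [int(p) for p in v.split(".")]
    return tuple(parts + [0] * (width - len(parts)))


def _installed_version(module: ModuleSpec) -> Optional[str]:
    # A module that cannot be imported yields None, so its registrations never fit.
    try:
        mod = module() if callable(module) else importlib.import_module(module)
    except ImportError:
        return None
    return getattr(mod, "__version__", None)


def version_dispatch(module: ModuleSpec, from_version: Optional[str] = None,
                     to_version: Optional[str] = None) -> Callable[[Callable], Callable]:
    """Hypothetical stand-in for implement_for (not torchrl's code)."""
    def decorator(fn: Callable) -> Callable:
        _REGISTRY.setdefault(fn.__name__, []).append((module, from_version, to_version, fn))

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            chosen = None
            for mod, lo, hi, impl in _REGISTRY[fn.__name__]:
                version = _installed_version(mod)
                if version is None:
                    continue
                v = _vtuple(version)
                if (lo is None or v >= _vtuple(lo)) and (hi is None or v < _vtuple(hi)):
                    chosen = impl  # keep overwriting: the last fitting registration wins
            if chosen is None:
                raise ModuleNotFoundError(f"no fitting implementation of {fn.__name__!r}")
            return chosen(*args, **kwargs)

        return wrapper
    return decorator


# Usage: a fake module stands in for an installed "gym".
fake_gym = types.ModuleType("fake_gym")
fake_gym.__version__ = "0.20.0"
sys.modules["fake_gym"] = fake_gym


@version_dispatch("fake_gym", "0.13", "0.14")
def fun(x):
    return x + 1


@version_dispatch("fake_gym", "0.14", "0.23")
def fun(x):
    return x + 2


@version_dispatch("fake_gym", "0.20", None)  # deliberately overlaps the previous range
def fun(x):
    return x + 10


@version_dispatch("no_such_backend_xyz", None, "1.0.0")  # module missing: never fits
def fun(x):
    return x + 3
```

With `fake_gym` reporting 0.20.0, `fun(1)` returns 11: both the [0.14, 0.23) and the [0.20, None) registrations fit, and the later one wins.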
- Parameters:
module_name (str or callable) – the version is checked for the module with this name (e.g. “gym”). If a callable is provided, it should take no arguments and return the module.
from_version (str or None) – first version with which the implementation is compatible (inclusive). Can be left open (None).
to_version (str or None) – first version with which the implementation is no longer compatible (exclusive). Can be left open (None).
- Keyword Arguments:
class_method (bool, optional) – if True, the function will be written as a class method. Defaults to False.
compilable (bool, optional) – if False, the module import happens only on the first call to the wrapped function. If True, the module import happens when the wrapped function is initialized. Defaults to False.
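The two `compilable` settings differ only in when the module gets imported. The sketch below models that timing with a hypothetical `with_module` decorator whose `eager` flag plays the role of `compilable`; it does not reproduce torchrl's version checks, and `resolve_math`, `lazy_sqrt` and `eager_sqrt` are illustrative names.

```python
import math
from typing import Any, Callable, Dict


def with_module(resolve: Callable[[], Any], eager: bool = False) -> Callable[[Callable], Callable]:
    """Hypothetical sketch: eager=True mirrors compilable=True, eager=False mirrors compilable=False."""
    def decorator(fn: Callable) -> Callable:
        cache: Dict[str, Any] = {}
        if eager:
            cache["module"] = resolve()  # import happens when the wrapper is created

        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if "module" not in cache:
                cache["module"] = resolve()  # lazy: import deferred to the first call
            return fn(cache["module"], *args, **kwargs)

        return wrapper
    return decorator


calls = []  # records every time the resolver "imports" the module


def resolve_math() -> Any:
    calls.append("import")
    return math


@with_module(resolve_math, eager=False)
def lazy_sqrt(m: Any, x: float) -> float:
    return m.sqrt(x)


@with_module(resolve_math, eager=True)
def eager_sqrt(m: Any, x: float) -> float:
    return m.sqrt(x)
```

With `eager=False` nothing is imported until `lazy_sqrt` is first called; with `eager=True` the import has already happened at decoration time.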
Examples
>>> from importlib import import_module
>>> from torchrl import implement_for
>>> @implement_for("gym", "0.13", "0.14")
... def fun(self, x):
...     # Older gym versions will return x + 1
...     return x + 1
...
>>> @implement_for("gym", "0.14", "0.23")
... def fun(self, x):
...     # More recent gym versions will return x + 2
...     return x + 2
...
>>> @implement_for(lambda: import_module("gym"), "0.23", None)
... def fun(self, x):
...     # More recent gym versions will return x + 2
...     return x + 2
...
>>> @implement_for("gymnasium", None, "1.0.0")
... def fun(self, x):
...     # If gymnasium is to be used instead of gym, x + 3 will be returned
...     return x + 3
...
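This indicates that the first implementation is used with gym 0.13+ but not with gym 0.14+. The others cover gym 0.14 up to (but excluding) 0.23, gym 0.23 and later, and any gymnasium version below 1.0.0.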
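The ranges in this example are half-open: `from_version` is the first compatible version and `to_version` the first incompatible one, with None leaving that side open. A minimal sketch of that check, assuming purely numeric version strings (`fits` and `version_key` are hypothetical helpers, not torchrl functions), applied to the ranges of the example:

```python
from typing import Optional


def version_key(v: str, width: int = 4) -> tuple[int, ...]:
    # Assumption: numeric dotted versions only; zero-pad so "0.14" == "0.14.0".
    parts = [int(p) for p in v.split(".")]
    return tuple(parts + [0] * (width - len(parts)))


def fits(version: str, from_version: Optional[str], to_version: Optional[str]) -> bool:
    """from_version is inclusive, to_version is exclusive, None leaves that side open."""
    v = version_key(version)
    lower_ok = from_version is None or v >= version_key(from_version)
    upper_ok = to_version is None or v < version_key(to_version)
    return lower_ok and upper_ok


# Ranges taken from the gym/gymnasium example.
print(fits("0.13.1", "0.13", "0.14"))  # True
print(fits("0.14.0", "0.13", "0.14"))  # False: 0.14 is the first excluded version
print(fits("0.26.2", "0.23", None))    # True: open upper bound
print(fits("0.29.1", None, "1.0.0"))   # True: open lower bound
```

A production version would parse versions with a full parser (pre-releases, local tags) rather than this integer-tuple shortcut.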
- static get_class_that_defined_method(f: Callable) → Any | None
Returns the class that defines the method f, if there is one, and None otherwise.
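One plausible way to recover the defining class is to walk the function's `__qualname__` through its module globals. This is a sketch under assumptions (the class is defined at module level, not inside a function) and not necessarily how torchrl implements it; `Outer`, `Inner` and `plain_function` are illustrative.

```python
from typing import Any, Callable, Optional


def get_class_that_defined_method(f: Callable) -> Optional[type]:
    """Hypothetical sketch: resolve the defining class from f.__qualname__."""
    parts = getattr(f, "__qualname__", "").split(".")
    # Plain functions ("func") and functions defined in a local scope are not resolved.
    if len(parts) < 2 or "<locals>" in parts:
        return None
    obj: Any = getattr(f, "__globals__", {}).get(parts[0])
    for name in parts[1:-1]:  # walk nested classes, e.g. "Outer.Inner.method"
        obj = getattr(obj, name, None)
    return obj if isinstance(obj, type) else None


class Outer:
    def method(self):
        pass

    class Inner:
        def method(self):
            pass


def plain_function():
    pass
```

Note that this only works once the class exists in the module namespace, so a decorator running at class-definition time would have to defer the lookup.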
- classmethod import_module(module_name: str | Callable[[], Any]) → str
Imports the module and returns its version.
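A sketch of the same contract (accept a module name or a zero-argument callable that returns the module, then report its version), assuming the version is exposed as `__version__`. `import_module_version` is a hypothetical name, and a fake module registered in `sys.modules` stands in for a real package such as gym.

```python
import importlib
import sys
import types
from typing import Any, Callable, Union


def import_module_version(module_name: Union[str, Callable[[], Any]]) -> str:
    """Hypothetical sketch: resolve a module name or callable, then return its version."""
    module = module_name() if callable(module_name) else importlib.import_module(module_name)
    # Assumption: the module exposes its version string as __version__.
    return module.__version__


# A fake module registered in sys.modules stands in for a real package.
fake = types.ModuleType("fake_backend")
fake.__version__ = "0.26.2"
sys.modules["fake_backend"] = fake
```

Both `import_module_version("fake_backend")` and `import_module_version(lambda: importlib.import_module("fake_backend"))` resolve to the same module here, because `importlib.import_module` returns the entry already present in `sys.modules`.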
- classmethod reset(setters_dict: dict[str, implement_for] | None = None) → None
Resets the setters in setters_dict.
- Parameters:
setters_dict – A copy of implementations. We iterate through its values and call module_set() for each.
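The reset loop amounts to calling `module_set()` on every value of the dictionary. In this sketch `FakeSetter` is a hypothetical stand-in that only counts calls; what a real setter does inside `module_set()` is not shown here.

```python
class FakeSetter:
    """Hypothetical stand-in for an implement_for instance."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.times_set = 0

    def module_set(self) -> None:
        # Placeholder: a real setter would (re)install its implementation here.
        self.times_set += 1


def reset(setters_dict: dict[str, FakeSetter]) -> None:
    # Iterate through the values and call module_set() on each one.
    for setter in setters_dict.values():
        setter.module_set()


setters = {"fun-a": FakeSetter("a"), "fun-b": FakeSetter("b")}
reset(setters)
```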