TransformedEnv¶
- class torchrl.envs.transforms.TransformedEnv(*args, **kwargs)[source]¶
A transformed environment.
- Parameters:
base_env (EnvBase) – original environment to be transformed.
transform (Transform or callable, optional) –
transform to apply to the tensordict resulting from
base_env.step(td). If none is provided, an empty Compose placeholder in eval mode is used.
Note
If transform is a callable, it must receive a single tensordict as input and output a tensordict as well. The callable will be called at step and reset time: if it acts on the reward (which is absent at reset time), a check needs to be implemented to ensure that the transform will run smoothly:
>>> def add_1(data):
...     if "reward" in data.keys():
...         return data.set("reward", data.get("reward") + 1)
...     return data
>>> env = TransformedEnv(base_env, add_1)
cache_specs (bool, optional) – if True, the specs will be cached once and for all after the first call (i.e. the specs will be transformed only once). If the transform changes during training, the original spec transform may not be valid anymore, in which case this value should be set to False. Default is True.
- Keyword Arguments:
auto_unwrap (bool, optional) –
if True, wrapping a transformed env in another transformed env unwraps the transforms of the inner TransformedEnv into the outer one (the new instance). Defaults to True.
Note
This behavior will switch to False in v0.9.
See also
Examples
>>> env = GymEnv("Pendulum-v0")
>>> transform = RewardScaling(0.0, 1.0)
>>> transformed_env = TransformedEnv(env, transform)
>>> # check auto-unwrap
>>> transformed_env = TransformedEnv(transformed_env, StepCounter())
>>> # The inner env has been unwrapped
>>> assert isinstance(transformed_env.base_env, GymEnv)
Note
The first argument was renamed from env to base_env for clarity. The old env argument is still supported for backward compatibility but will be removed in v0.12. A deprecation warning will be shown when using the old argument name.
- add_truncated_keys() TransformedEnv[source]¶
Adds truncated keys to the environment.
- append_transform(transform: torchrl.envs.transforms.transforms.Transform | collections.abc.Callable[[tensordict.base.TensorDictBase], tensordict.base.TensorDictBase]) TransformedEnv[source]¶
Appends a transform to the env.
Transform instances or callables are accepted.
- property batch_locked: bool¶
Whether the environment can be used with a batch size different from the one it was initialized with or not.
If True, the env needs to be used with a tensordict having the same batch size as the env. batch_locked is an immutable property.
- property batch_size: Size¶
Number of envs batched in this environment instance, organised in a torch.Size() object.
Environments may be similar or different, but it is assumed that they have little if any interaction between them (e.g., multi-task or batched execution in parallel).
- empty_cache()[source]¶
Erases all the cached values.
For regular envs, the key lists (reward, done, etc.) are cached, but in some cases they may change during the execution of the code (e.g., when adding a transform).
- eval() TransformedEnv[source]¶
Set the module in evaluation mode.
This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, i.e. whether they are affected, e.g. Dropout, BatchNorm, etc.
This is equivalent to self.train(False).
See Locally disabling gradient computation for a comparison between .eval() and several similar mechanisms that may be confused with it.
- Returns:
self
- Return type:
Module
- property input_spec: TensorSpec¶
Input spec of the transformed environment.
- insert_transform(index: int, transform: Transform) TransformedEnv[source]¶
Inserts a transform to the env at the desired index.
Transform instances or callables are accepted.
- load_state_dict(state_dict: OrderedDict, **kwargs) None[source]¶
Copy parameters and buffers from state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Warning
If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
assign (bool, optional) – When set to False, the properties of the tensors in the current module are preserved whereas setting it to True preserves properties of the Tensors in the state dict. The only exception is the requires_grad field of Parameter for which the value from the module is preserved. Default: False
- Returns:
missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
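The strict and non-strict behaviours described above can be sketched on a plain `torch.nn.Linear`. This is an illustration of the generic `torch.nn.Module` method only: the `"extra"` key is hypothetical, and whether `TransformedEnv.load_state_dict` returns the same named tuple is not confirmed here (its signature above is annotated as returning None).

```python
import torch
from torch import nn

model = nn.Linear(2, 2)

# A state dict that lacks "bias" and carries an extra, unknown key.
partial_state = {
    "weight": torch.zeros(2, 2),
    "extra": torch.zeros(1),  # hypothetical key, not part of nn.Linear
}

# With strict=False, mismatches are reported instead of raising.
result = model.load_state_dict(partial_state, strict=False)
missing = list(result.missing_keys)
unexpected = list(result.unexpected_keys)

# With strict=True (the default), the same input raises a RuntimeError.
try:
    model.load_state_dict(partial_state)
    strict_failed = False
except RuntimeError:
    strict_failed = True
```

Inspecting `missing` and `unexpected` before deciding whether to proceed is a common way to load partially matching checkpoints.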
- property output_spec: TensorSpec¶
Output spec of the transformed environment.
- rand_action(tensordict: tensordict.base.TensorDictBase | None = None) TensorDict[source]¶
Performs a random action given the action_spec attribute.
- Parameters:
tensordict (TensorDictBase, optional) – tensordict where the resulting action should be written.
- Returns:
a tensordict object with the “action” entry updated with a random sample from the action-spec.
- set_missing_tolerance(mode=False)[source]¶
Indicates if a KeyError should be raised whenever an in_key is missing from the input tensordict.
- set_seed(seed: int | None = None, static_seed: bool = False) int | None[source]¶
Set the seeds of the environment.
- state_dict(*args, **kwargs) OrderedDict[source]¶
Return a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module’s parameters and buffers.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.
- Returns:
a dictionary containing a whole state of the module
- Return type:
dict
Example:
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
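The effect of `prefix` and `keep_vars`, and the shallow-copy note, can be illustrated on a plain `torch.nn.Linear`. This is a sketch of the generic `torch.nn.Module` behaviour described above, not of any TorchRL-specific state; the `"layer."` prefix is an arbitrary choice for the example.

```python
import torch
from torch import nn

linear = nn.Linear(2, 2)

detached = linear.state_dict()                # default: detached tensors
attached = linear.state_dict(keep_vars=True)  # parameters kept as-is
prefixed = linear.state_dict(prefix="layer.")

keys = list(detached.keys())
prefixed_keys = list(prefixed.keys())

# Shallow copy: the entries reference the module's own storage.
shares_memory = detached["weight"].data_ptr() == linear.weight.data_ptr()

# keep_vars controls whether entries stay attached to autograd.
detached_requires_grad = detached["weight"].requires_grad
attached_requires_grad = attached["weight"].requires_grad
```

Because the entries share storage with the module, an in-place modification of a returned tensor also changes the module's parameter; clone the tensors first if an independent snapshot is needed.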
- to(*args, **kwargs) TransformedEnv[source]¶
Move and/or cast the parameters and buffers.
This can be called as
- to(device=None, dtype=None, non_blocking=False)[source]
- to(dtype, non_blocking=False)[source]
- to(tensor, non_blocking=False)[source]
- to(memory_format=torch.channels_last)[source]
Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.
See below for examples.
Note
This method modifies the module in-place.
- Parameters:
device (torch.device) – the desired device of the parameters and buffers in this module
dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module
tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module
memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)
- Returns:
self
- Return type:
Module
Examples:
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)
>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j, 0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
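The statement that only floating point or complex parameters and buffers are cast, while integral ones keep their dtype, can be checked with a small module. `Counter` and its attribute names are hypothetical and exist only for this sketch; it exercises the generic `torch.nn.Module.to` behaviour on CPU.

```python
import torch
from torch import nn


class Counter(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        # Floating point parameter: affected by a dtype cast.
        self.scale = nn.Parameter(torch.ones(2))
        # Integral buffer: dtype is left unchanged by a dtype cast.
        self.register_buffer("steps", torch.zeros(1, dtype=torch.long))


module = Counter()
returned = module.to(torch.float64)  # modifies the module in-place

scale_dtype = module.scale.dtype
steps_dtype = module.steps.dtype
```

The call returns the same object it was invoked on, so `returned is module` holds and chaining such as `module.to(...).eval()` is possible.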
- train(mode: bool = True) TransformedEnv[source]¶
Set the module in training mode.
This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, i.e., whether they are affected, e.g. Dropout, BatchNorm, etc.
- Parameters:
mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.
- Returns:
self
- Return type:
Module
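The relationship between `train()`, `train(False)` and `eval()` can be sketched with `torch.nn.Dropout`, one of the modules named above as mode-dependent. This is a minimal illustration on a plain module; how individual TorchRL transforms react to the mode is not shown here.

```python
import torch
from torch import nn

dropout = nn.Dropout(p=0.5)
x = torch.ones(8)

dropout.eval()
eval_flag = dropout.training   # evaluation mode
eval_out = dropout(x)          # dropout is the identity in eval mode

dropout.train()
train_flag = dropout.training  # training mode

dropout.train(False)           # same effect as dropout.eval()
same_as_eval = dropout.training
```

In training mode the output of `dropout(x)` is random, so the sketch only inspects the output produced in evaluation mode.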