EGreedyWrapper

class torchrl.modules.tensordict_module.EGreedyWrapper(*args, **kwargs)
[Deprecated] Epsilon-Greedy policy wrapper.
Parameters:
    policy (TensorDictModule) – a deterministic policy.

Keyword Arguments:
    eps_init (scalar, optional) – initial epsilon value. Default: 1.0.
    eps_end (scalar, optional) – final epsilon value. Default: 0.1.
    annealing_num_steps (int, optional) – number of steps it will take for epsilon to reach the eps_end value.
    action_key (NestedKey, optional) – the key where the action can be found in the input tensordict. Default: "action".
    action_mask_key (NestedKey, optional) – the key where the action mask can be found in the input tensordict. Default: None (corresponding to no mask).
    spec (TensorSpec, optional) – if provided, the sampled action will be taken from this action space. If not provided, the exploration wrapper will attempt to recover it from the policy.
Note

Once a module has been wrapped in EGreedyWrapper, it is crucial to incorporate a call to step() in the training loop to update the exploration factor. Because this omission is hard to detect, no warning or exception will be raised if the call is missing!
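A minimal sketch of the intended pattern is shown below; the optimizer, the quadratic placeholder loss, the loop length and the annealing_num_steps value are illustrative assumptions, not part of the EGreedyWrapper API:

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules import Actor, EGreedyWrapper
>>> spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
>>> policy = Actor(spec=spec, module=torch.nn.Linear(4, 4, bias=False))
>>> explorative_policy = EGreedyWrapper(
...     policy, eps_init=1.0, eps_end=0.1, annealing_num_steps=100)
>>> optim = torch.optim.Adam(policy.parameters(), lr=1e-3)
>>> for _ in range(100):
...     td = TensorDict({"observation": torch.randn(10, 4)}, batch_size=[10])
...     loss = explorative_policy(td).get("action").pow(2).mean()  # placeholder loss
...     loss.backward()
...     optim.step()
...     optim.zero_grad()
...     explorative_policy.step()  # anneal epsilon; omitting this keeps eps at eps_init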
Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.modules import EGreedyWrapper, Actor
>>> from torchrl.data import BoundedTensorSpec
>>> torch.manual_seed(0)
>>> spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
>>> module = torch.nn.Linear(4, 4, bias=False)
>>> policy = Actor(spec=spec, module=module)
>>> explorative_policy = EGreedyWrapper(policy, eps_init=0.2)
>>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
>>> print(explorative_policy(td).get("action"))
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.9055, -0.9277, -0.6295, -0.2532],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], grad_fn=<AddBackward0>)
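Because this wrapper is deprecated, a hedged migration sketch follows. It assumes the replacement API available in recent TorchRL releases, where exploration is a separate EGreedyModule composed with the policy through tensordict.nn.TensorDictSequential rather than wrapping it; verify these names against your installed version.

>>> from tensordict.nn import TensorDictSequential
>>> from torchrl.modules import EGreedyModule
>>> # Reusing `policy` and `spec` from the Examples above.
>>> exploration_module = EGreedyModule(spec=spec, eps_init=0.2)
>>> explorative_policy = TensorDictSequential(policy, exploration_module)
>>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10])
>>> td = explorative_policy(td)
>>> exploration_module.step()  # epsilon annealing is driven the same way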