RunningMeanStd
- class torchrl.envs.transforms.RunningMeanStd(shape: tuple = (), epsilon: float = 0.0001)
Tracks running mean and variance using Welford’s parallel algorithm.
Buffers are registered so the statistics are included in
state_dict() and move correctly with .to(device).
- Parameters:
shape (tuple) – feature shape to track (e.g. (obs_dim,), or () for scalars).
epsilon (float, optional) – small initial count for numerical stability. Default: 1e-4.
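The parallel form of Welford's update combines the stored statistics (mean, variance, count) with the statistics of an entire new batch in a single step, rather than folding samples in one at a time. The pure-Python sketch below illustrates that combination rule for scalar data. It is an illustration under stated assumptions, not torchrl's code: the `Stats`, `batch_stats` and `merge` names are hypothetical, the variance is the population (biased) variance, and the starting values (mean 0, variance 1, count equal to `epsilon`) are assumed defaults that may differ from the library's.

```python
import math
from dataclasses import dataclass


@dataclass
class Stats:
    # Hypothetical container (not torchrl's): running mean, population
    # variance, and sample count. The tiny initial count plays the role of
    # `epsilon`; mean=0 and var=1 are assumed starting values.
    mean: float = 0.0
    var: float = 1.0
    count: float = 1e-4


def batch_stats(xs):
    """Return (mean, population variance, size) of one batch of scalars."""
    n = len(xs)
    mean = sum(xs) / n
    var = sum((x - mean) ** 2 for x in xs) / n
    return mean, var, n


def merge(s, batch_mean, batch_var, batch_count):
    """Combine running stats with one batch's stats (parallel Welford/Chan rule)."""
    delta = batch_mean - s.mean
    total = s.count + batch_count
    new_mean = s.mean + delta * batch_count / total
    # Squared deviations of each part around its own mean, plus a correction
    # term for the gap between the two means.
    m2 = (
        s.var * s.count
        + batch_var * batch_count
        + delta**2 * s.count * batch_count / total
    )
    return Stats(mean=new_mean, var=m2 / total, count=total)


data = [0.5, 1.5, -2.0, 3.0, 4.0, -1.0]
stats = Stats()
for chunk in (data[:2], data[2:]):  # feed the data as two separate batches
    stats = merge(stats, *batch_stats(chunk))
print(stats.mean, stats.var)  # close to the mean and variance of all of `data`
```

Because the merge is exact, feeding the data in two chunks gives the same result as one large batch, apart from the tiny pull of the `epsilon` prior.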
Examples
>>> import torch
>>> from torchrl.envs.transforms import RunningMeanStd
>>> rms = RunningMeanStd(shape=(4,))
>>> rms.update(torch.randn(32, 4))
>>> normed = rms.normalize(torch.randn(8, 4))
>>> normed.shape
torch.Size([8, 4])
- update(x: Tensor) → None
Update running statistics with a new batch.
- Parameters:
x (torch.Tensor) – batch of samples. All leading dimensions are treated as batch dimensions; the trailing dimensions must match self.mean.shape.
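The behaviour described for update(), together with the buffer registration mentioned for the class, can be sketched as a small torch.nn.Module. `SketchRunningMeanStd` is a hypothetical stand-in written for illustration, not torchrl's implementation. It assumes the state is stored in buffers named mean, var and count, applies the standard parallel mean/variance combination after flattening the leading dimensions, and assumes normalize() computes (x - mean) / sqrt(var + 1e-8); the buffer names, the initial variance of 1 and the 1e-8 constant are all assumptions.

```python
import torch
from torch import nn


class SketchRunningMeanStd(nn.Module):
    """Hypothetical re-implementation for illustration; not torchrl's source."""

    def __init__(self, shape=(), epsilon=1e-4):
        super().__init__()
        # Buffers (not Parameters): included in state_dict() and moved by
        # .to(device), but never updated by an optimizer.
        self.register_buffer("mean", torch.zeros(shape))
        self.register_buffer("var", torch.ones(shape))
        self.register_buffer("count", torch.tensor(float(epsilon)))

    @torch.no_grad()
    def update(self, x):
        # Flatten every leading dim into one batch dim; trailing dims must
        # match mean.shape, otherwise reshape raises an error.
        x = x.reshape(-1, *self.mean.shape)
        batch_mean = x.mean(dim=0)
        batch_var = ((x - batch_mean) ** 2).mean(dim=0)  # population variance
        batch_count = x.shape[0]

        delta = batch_mean - self.mean
        total = self.count + batch_count
        m2 = (
            self.var * self.count
            + batch_var * batch_count
            + delta**2 * self.count * batch_count / total
        )
        # In-place updates keep the registered buffer tensors and their device.
        self.mean.add_(delta * batch_count / total)
        self.var.copy_(m2 / total)
        self.count.copy_(total)

    def normalize(self, x):
        # Assumed formula; the 1e-8 guard against division by zero is a guess.
        return (x - self.mean) / torch.sqrt(self.var + 1e-8)


rms = SketchRunningMeanStd(shape=(4,))
rms.update(torch.randn(5, 6, 4))  # leading dims (5, 6) become 30 samples
print(rms.normalize(torch.randn(8, 4)).shape)  # torch.Size([8, 4])
print(sorted(rms.state_dict()))  # ['count', 'mean', 'var']
```

Registering the statistics as buffers rather than plain attributes is what lets them travel with the module through checkpointing and device moves.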