CrossGroupCritic#
- class torchrl.modules.CrossGroupCritic(*args, **kwargs)[source]#
Centralised critic that conditions on observations from multiple agent groups.
Standard
MultiAgentMLPcentralises only within a single group.CrossGroupCriticremoves that restriction: it reads observations from an arbitrary number of groups (each potentially with a different observation dimensionality), encodes them to a shared embedding space, processes the joint representation through a shared MLP trunk, and writes a per-group value estimate back to the tensordict.This enables two use-cases that single-group critics cannot handle:
Heterogeneous teams — agents in different groups have different observation / action specs. Each group gets its own encoder (
Linear(obs_dim_g → d_model)), so no padding or obs-dim alignment is required.Ad-hoc teamwork — one group follows a fixed (non-training) policy but its observations still inform the value baseline of the training group. Pass the fixed group’s name via
detach_groupsso its encoder output is detached before building the team state: the critic sees the full team state but gradients do not flow into the fixed group’s observations.
Because
CrossGroupCriticis a plainTensorDictModule, it plugs intoMAPPOLossandIPPOLosswithout any changes to those classes.- Parameters:
group_map (dict[str, CrossCriticGroupSpec]) – ordered mapping from a group name to a
CrossCriticGroupSpecthat describes the group’s observation dimensionality, agent count, and tensordict keys.- Keyword Arguments:
d_model (int) – common embedding dimension. All per-group encoders project to this size. Defaults to
64.trunk_depth (int) – number of hidden layers in the shared MLP trunk. Defaults to
2.trunk_cells (int) – width of each trunk hidden layer. Defaults to
256.activation_class (type[nn.Module]) – activation used in encoders and trunk. Defaults to
Tanh.share_params (bool) – if
Truea single value head is shared across all groups (useful when groups are homogeneous or have the same role). IfFalseeach group gets its own head. Encoders are always group-specific and the central trunk is always shared. Defaults toFalse.detach_groups (iterable of str, optional) – names of groups whose encoder outputs should be detached before the trunk. Use this to include fixed-policy agents in the centralised state without propagating gradients to their observations. Defaults to
None.device (DEVICE_TYPING, optional) – device on which to allocate parameters. Defaults to
None(CPU).
Note
The order of keys in
group_mapdetermines the order of positional inputs to the inner network. Pythondictpreserves insertion order (Python 3.7+), so the mapping is stable.Example
>>> import torch >>> from tensordict import TensorDict >>> from torchrl.modules.models.cross_group_critic import CrossGroupCritic, CrossCriticGroupSpec >>> group_map = { ... "soldiers": CrossCriticGroupSpec(obs_dim=12, n_agents=3, ... obs_key=("soldiers", "observation"), ... value_key=("soldiers", "state_value")), ... "medics": CrossCriticGroupSpec(obs_dim=8, n_agents=2, ... obs_key=("medics", "observation"), ... value_key=("medics", "state_value")), ... } >>> critic = CrossGroupCritic(group_map, d_model=32, trunk_depth=1, trunk_cells=64) >>> td = TensorDict( ... { ... "soldiers": {"observation": torch.zeros(4, 3, 12)}, ... "medics": {"observation": torch.zeros(4, 2, 8)}, ... }, ... batch_size=[4], ... ) >>> td = critic(td) >>> print(td["soldiers", "state_value"].shape) torch.Size([4, 3, 1]) >>> print(td["medics", "state_value"].shape) torch.Size([4, 2, 1])