# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import importlib.util
import torch
from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
from torchrl.data.tensor_specs import (
Bounded,
Categorical,
Composite,
DEVICE_TYPING,
MultiCategorical,
MultiOneHot,
OneHot,
StackedComposite,
TensorSpec,
Unbounded,
)
from torchrl.data.utils import numpy_to_torch_dtype_dict
from torchrl.envs.common import _EnvWrapper, EnvBase
from torchrl.envs.libs.gym import gym_backend, set_gym_backend
from torchrl.envs.utils import (
_classproperty,
_selective_unsqueeze,
check_marl_grouping,
MarlGroupMapType,
)
# True when the optional ``vmas`` package can be found; detected with
# ``find_spec`` so that vmas itself is not imported at module import time.
_has_vmas = importlib.util.find_spec("vmas") is not None
# Public names exported by this module.
__all__ = ["VmasWrapper", "VmasEnv"]
def _get_envs():
    """Return the names of all scenarios bundled with VMAS.

    The list concatenates ``vmas.scenarios``, ``vmas.mpe_scenarios`` and
    ``vmas.debug_scenarios``.

    Raises:
        ImportError: if the ``vmas`` package is not installed.
    """
    if _has_vmas:
        import vmas

        return vmas.scenarios + vmas.mpe_scenarios + vmas.debug_scenarios
    raise ImportError("VMAS is not installed in your virtual environment.")
@set_gym_backend("gym")
def _vmas_to_torchrl_spec_transform(
    spec,
    device,
    categorical_action_encoding,
) -> TensorSpec:
    """Convert a gym space exposed by VMAS into the matching TorchRL spec.

    Args:
        spec: a gym space. ``Discrete``, ``MultiDiscrete``, ``Box`` and
            ``Dict`` (converted recursively) are supported.
        device: device on which the resulting spec (and its bounds) live.
        categorical_action_encoding (bool): for discrete spaces, build a
            categorical spec keeping the space's dtype when ``True``, or a
            one-hot spec with ``torch.long`` dtype when ``False``.

    Returns:
        The converted :class:`TensorSpec`. ``Dict`` spaces become a
        :class:`Composite` whose batch size is left for the caller to set.

    Raises:
        NotImplementedError: for any other space type.
    """
    spaces = gym_backend("spaces")

    if isinstance(
        spec, (spaces.discrete.Discrete, spaces.multi_discrete.MultiDiscrete)
    ):
        # Both discrete flavours share the same dtype rule.
        if categorical_action_encoding:
            dtype = numpy_to_torch_dtype_dict[spec.dtype]
        else:
            dtype = torch.long
        if isinstance(spec, spaces.discrete.Discrete):
            discrete_cls = Categorical if categorical_action_encoding else OneHot
            return discrete_cls(spec.n, device=device, dtype=dtype)
        multi_cls = MultiCategorical if categorical_action_encoding else MultiOneHot
        return multi_cls(spec.nvec, device=device, dtype=dtype)

    if isinstance(spec, spaces.Box):
        # Scalar boxes are promoted to a 1-element shape.
        box_shape = spec.shape if len(spec.shape) else torch.Size([1])
        box_dtype = numpy_to_torch_dtype_dict[spec.dtype]
        lower = torch.tensor(spec.low, device=device, dtype=box_dtype)
        upper = torch.tensor(spec.high, device=device, dtype=box_dtype)
        if lower.isinf().all() and upper.isinf().all():
            return Unbounded(box_shape, device=device, dtype=box_dtype)
        return Bounded(
            lower,
            upper,
            box_shape,
            dtype=box_dtype,
            device=device,
        )

    if isinstance(spec, spaces.Dict):
        sub_specs = {
            key: _vmas_to_torchrl_spec_transform(
                spec[key],
                device=device,
                categorical_action_encoding=categorical_action_encoding,
            )
            for key in spec.keys()
        }
        # the batch-size must be set later
        return Composite(sub_specs, device=device)

    raise NotImplementedError(
        f"spec of type {type(spec).__name__} is currently unaccounted for vmas"
    )
class VmasWrapper(_EnvWrapper):
    """Vmas environment wrapper.

    GitHub: https://github.com/proroklab/VectorizedMultiAgentSimulator

    Paper: https://arxiv.org/abs/2207.03530

    Args:
        env (``vmas.simulator.environment.environment.Environment``): the vmas environment to wrap.

    Keyword Args:
        num_envs (int): Number of vectorized simulation environments. VMAS performs vectorized simulations using PyTorch.
            This argument indicates the number of vectorized environments that should be simulated in a batch. It will also
            determine the batch size of the environment.
        device (torch.device, optional): Device for simulation. Defaults to the default device. All the tensors created by VMAS
            will be placed on this device. When passed together with ``env``, it must match the device of ``env``.
        continuous_actions (bool, optional): Whether to use continuous actions. Defaults to ``True``. If ``False``, actions
            will be discrete. The number of actions and their size will depend on the chosen scenario.
            See the VMAS repository for more info.
        max_steps (int, optional): Horizon of the task. Defaults to ``None`` (infinite horizon). Each VMAS scenario can
            be terminating or not. If ``max_steps`` is specified,
            the scenario is also terminated (and the ``"terminated"`` flag is set) whenever this horizon is reached.
            Unlike gym's ``TimeLimit`` transform or torchrl's :class:`~torchrl.envs.transforms.StepCounter`,
            this argument will not set the ``"truncated"`` entry in the tensordict.
        categorical_actions (bool, optional): if the environment actions are discrete, whether to transform
            them to categorical or one-hot. Defaults to ``True``.
        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for
            input/output. By default, if the agent names follow the ``"<name>_<int>"``
            convention, they will be grouped by ``"<name>"``. If they do not follow this convention, they will be all put
            in one group named ``"agents"``.
            Otherwise, a group map can be specified or selected from some premade options.
            See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.

    Attributes:
        group_map (Dict[str, List[str]]): how to group agents in tensordicts for
            input/output. See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
        agent_names (list of str): names of the agent in the environment
        agent_names_to_indices_map (Dict[str, int]): dictionary mapping agent names to their index in the environment
        full_action_spec_unbatched (TensorSpec): version of the spec without the vectorized dimension
        full_observation_spec_unbatched (TensorSpec): version of the spec without the vectorized dimension
        full_reward_spec_unbatched (TensorSpec): version of the spec without the vectorized dimension
        full_done_spec_unbatched (TensorSpec): version of the spec without the vectorized dimension
        het_specs (bool): whether the environment has any lazy spec
        het_specs_map (Dict[str, bool]): dictionary mapping each group to a flag representing whether the group has lazy specs
        available_envs (List[str]): the list of the scenarios available to build.

    .. warning::
        VMAS returns a single ``done`` flag which does not distinguish between
        when the env reached ``max_steps`` and termination.
        If you deem the ``truncation`` signal necessary, set ``max_steps`` to
        ``None`` and use a :class:`~torchrl.envs.transforms.StepCounter` transform.

    Examples:
        >>> env = VmasWrapper(
        ...     vmas.make_env(
        ...         scenario="flocking",
        ...         num_envs=32,
        ...         continuous_actions=True,
        ...         max_steps=200,
        ...         device="cpu",
        ...         seed=None,
        ...         # Scenario kwargs
        ...         n_agents=5,
        ...     )
        ... )
        >>> print(env.rollout(10))
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([32, 10, 5, 2]), device=cpu, dtype=torch.float32, is_shared=False),
                        info: TensorDict(
                            fields={
                                agent_collision_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                                agent_distance_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([32, 10, 5]),
                            device=cpu,
                            is_shared=False),
                        observation: Tensor(shape=torch.Size([32, 10, 5, 18]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([32, 10, 5]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        agents: TensorDict(
                            fields={
                                info: TensorDict(
                                    fields={
                                        agent_collision_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                                        agent_distance_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                                    batch_size=torch.Size([32, 10, 5]),
                                    device=cpu,
                                    is_shared=False),
                                observation: Tensor(shape=torch.Size([32, 10, 5, 18]), device=cpu, dtype=torch.float32, is_shared=False),
                                reward: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([32, 10, 5]),
                            device=cpu,
                            is_shared=False),
                        done: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        terminated: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([32, 10]),
                    device=cpu,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([32, 10]),
            device=cpu,
            is_shared=False)
    """

    git_url = "https://github.com/proroklab/VectorizedMultiAgentSimulator"
    libname = "vmas"

    @property
    def lib(self):
        """The ``vmas`` module, imported lazily."""
        import vmas

        return vmas

    @_classproperty
    def available_envs(cls):
        """Names of the buildable VMAS scenarios (empty if vmas is missing)."""
        if not _has_vmas:
            return []
        return list(_get_envs())

    def __init__(
        self,
        env: vmas.simulator.environment.environment.Environment = None,  # noqa
        categorical_actions: bool = True,
        group_map: MarlGroupMapType | dict[str, list[str]] | None = None,
        **kwargs,
    ):
        if env is not None:
            kwargs["env"] = env
            # An explicitly requested device must match the one VMAS simulates
            # on. Compare as ``torch.device`` so that e.g. ``"cpu"`` and
            # ``torch.device("cpu")`` are equal; ``None`` means "use env's".
            requested_device = kwargs.get("device")
            if requested_device is not None and torch.device(
                requested_device
            ) != torch.device(env.device):
                raise TypeError("Env device is different from vmas device")
            kwargs["device"] = str(env.device)
        self.group_map = group_map
        self.categorical_actions = categorical_actions
        super().__init__(**kwargs, allow_done_after_reset=True)

    def _build_env(
        self,
        env: vmas.simulator.environment.environment.Environment,  # noqa
        from_pixels: bool = False,
        pixels_only: bool = False,
    ):
        """Validate ``env`` against this wrapper's batch size and return it.

        Raises:
            NotImplementedError: if ``from_pixels`` is requested.
            TypeError: if a batch size was set that does not match ``env.num_envs``
                or has more than one dimension.
        """
        self.from_pixels = from_pixels
        self.pixels_only = pixels_only

        # TODO pixels
        if self.from_pixels:
            raise NotImplementedError("vmas rendering not yet implemented")

        # Adjust batch size
        if len(self.batch_size) == 0:
            # Batch size not set
            self.batch_size = torch.Size((env.num_envs,))
        elif len(self.batch_size) == 1:
            # Batch size is set
            if not self.batch_size[0] == env.num_envs:
                raise TypeError(
                    "Batch size used in constructor does not match vmas batch size."
                )
        else:
            raise TypeError(
                "Batch size used in constructor is not compatible with vmas."
            )

        return env

    def _get_default_group_map(self, agent_names: list[str]):
        """Group agents named ``"<name>_<int>"`` under ``"<name>"``.

        If any agent does not follow that convention, all agents are put in a
        single group. A lone ``"agent"`` group is renamed ``"agents"`` for
        backward compatibility.
        """
        group_map = {}
        follows_convention = True
        for agent_name in agent_names:
            # See if the agent follows the convention "<name>_<int>"
            agent_name_split = agent_name.split("_")
            if len(agent_name_split) == 1:
                follows_convention = False
            follows_convention = follows_convention and agent_name_split[-1].isdigit()

            if not follows_convention:
                break

            # Group it with other agents that follow the same convention
            group_name = "_".join(agent_name_split[:-1])
            if group_name in group_map:
                group_map[group_name].append(agent_name)
            else:
                group_map[group_name] = [agent_name]

        if not follows_convention:
            group_map = MarlGroupMapType.ALL_IN_ONE_GROUP.get_group_map(agent_names)

        # For BC-compatibility rename the "agent" group to "agents"
        if "agent" in group_map and len(group_map) == 1:
            agent_group = group_map["agent"]
            group_map["agents"] = agent_group
            del group_map["agent"]
        return group_map

    def _make_specs(
        self,
        env: vmas.simulator.environment.environment.Environment,  # noqa
    ) -> None:
        """Resolve the group map and build the unbatched action/observation/reward/done specs."""
        # Create and check group map
        self.agent_names = [agent.name for agent in self.agents]
        self.agent_names_to_indices_map = {
            agent.name: i for i, agent in enumerate(self.agents)
        }
        if self.group_map is None:
            self.group_map = self._get_default_group_map(self.agent_names)
        elif isinstance(self.group_map, MarlGroupMapType):
            self.group_map = self.group_map.get_group_map(self.agent_names)
        check_marl_grouping(self.group_map, self.agent_names)

        full_action_spec_unbatched = Composite(device=self.device)
        full_observation_spec_unbatched = Composite(device=self.device)
        full_reward_spec_unbatched = Composite(device=self.device)

        self.het_specs = False
        self.het_specs_map = {}
        for group in self.group_map.keys():
            (
                group_observation_spec,
                group_action_spec,
                group_reward_spec,
                group_info_spec,
            ) = self._make_unbatched_group_specs(group)
            full_action_spec_unbatched[group] = group_action_spec
            full_observation_spec_unbatched[group] = group_observation_spec
            full_reward_spec_unbatched[group] = group_reward_spec
            if group_info_spec is not None:
                full_observation_spec_unbatched[(group, "info")] = group_info_spec
            # A group is heterogeneous if any of its stacked specs is lazy,
            # including the info spec (agents may report different info keys).
            # Such groups must not be densified in _reset/_step.
            group_het_specs = (
                isinstance(group_observation_spec, StackedComposite)
                or isinstance(group_action_spec, StackedComposite)
                or isinstance(group_info_spec, StackedComposite)
            )
            self.het_specs_map[group] = group_het_specs
            self.het_specs = self.het_specs or group_het_specs

        full_done_spec_unbatched = Composite(
            {
                "done": Categorical(
                    n=2,
                    shape=torch.Size((1,)),
                    dtype=torch.bool,
                    device=self.device,
                ),
            },
            device=self.device,
        )

        self.full_action_spec_unbatched = full_action_spec_unbatched
        self.full_observation_spec_unbatched = full_observation_spec_unbatched
        self.full_reward_spec_unbatched = full_reward_spec_unbatched
        self.full_done_spec_unbatched = full_done_spec_unbatched

    def _make_unbatched_group_specs(self, group: str):
        """Build the per-agent specs of ``group`` and stack them along a new leading dim.

        Returns:
            ``(observation_spec, action_spec, reward_spec, info_spec)``; ``info_spec``
            is ``None`` when no agent of the group reports info.
        """
        # Agent specs
        action_specs = []
        observation_specs = []
        reward_specs = []
        info_specs = []
        for agent_name in self.group_map[group]:
            agent_index = self.agent_names_to_indices_map[agent_name]
            agent = self.agents[agent_index]
            action_specs.append(
                Composite(
                    {
                        "action": _vmas_to_torchrl_spec_transform(
                            self.action_space[agent_index],
                            categorical_action_encoding=self.categorical_actions,
                            device=self.device,
                        )  # shape = (n_actions_per_agent,)
                    },
                )
            )
            observation_specs.append(
                Composite(
                    {
                        "observation": _vmas_to_torchrl_spec_transform(
                            self.observation_space[agent_index],
                            device=self.device,
                            categorical_action_encoding=self.categorical_actions,
                        )  # shape = (n_obs_per_agent,)
                    },
                )
            )
            reward_specs.append(
                Composite(
                    {
                        "reward": Unbounded(
                            shape=torch.Size((1,)),
                            device=self.device,
                        )  # shape = (1,)
                    }
                )
            )
            agent_info = self.scenario.info(agent)
            if len(agent_info):
                info_specs.append(
                    Composite(
                        {
                            key: Unbounded(
                                shape=_selective_unsqueeze(
                                    value, batch_size=self.batch_size
                                ).shape[1:],
                                device=self.device,
                                dtype=torch.float32,
                            )
                            for key, value in agent_info.items()
                        },
                    ).to(self.device)
                )

        # Create multi-agent specs
        group_action_spec = torch.stack(
            action_specs, dim=0
        )  # shape = (n_agents, n_actions_per_agent)
        group_observation_spec = torch.stack(
            observation_specs, dim=0
        )  # shape = (n_agents, n_obs_per_agent)
        group_reward_spec = torch.stack(reward_specs, dim=0)  # shape = (n_agents, 1)
        group_info_spec = None
        if len(info_specs):
            group_info_spec = torch.stack(info_specs, dim=0)

        return (
            group_observation_spec,
            group_action_spec,
            group_reward_spec,
            group_info_spec,
        )

    def _check_kwargs(self, kwargs: dict):
        """Ensure ``kwargs`` carries a vmas ``Environment`` under ``"env"``.

        Raises:
            TypeError: if ``"env"`` is missing or has the wrong type.
        """
        vmas = self.lib

        if "env" not in kwargs:
            raise TypeError("Could not find environment key 'env' in kwargs.")
        env = kwargs["env"]
        if not isinstance(env, vmas.simulator.environment.Environment):
            raise TypeError(
                "env is not of type 'vmas.simulator.environment.Environment'."
            )

    def _init_env(self) -> int | None:
        pass

    def _set_seed(self, seed: int | None) -> None:
        self._env.seed(seed)

    def _reset(
        self, tensordict: TensorDictBase | None = None, **kwargs
    ) -> TensorDictBase:
        """Reset all sub-environments, or only those flagged in ``tensordict["_reset"]``."""
        if tensordict is not None and "_reset" in tensordict.keys():
            envs_to_reset = tensordict.get("_reset").squeeze(-1)
            if envs_to_reset.all():
                self._env.reset(return_observations=False)
            else:
                # Gather the selected indices with a single ``nonzero`` call
                # instead of evaluating each mask element in Python (which
                # forces one host synchronization per element on GPU).
                indices = envs_to_reset.nonzero(as_tuple=True)[0].tolist()
                for env_index in indices:
                    self._env.reset_at(env_index, return_observations=False)
        else:
            self._env.reset(return_observations=False)

        obs, dones, infos = self._env.get_from_scenario(
            get_observations=True,
            get_infos=True,
            get_rewards=False,
            get_dones=True,
        )
        return self._make_output_tensordict(obs, dones, infos)

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        """Step the env with the per-group actions found in ``tensordict``."""
        # Actions are collected group by group; agent_indices maps each agent's
        # env index to its position in action_list so they can be reordered
        # into the env's agent order.
        agent_indices = {}
        action_list = []
        n_agents = 0
        for group, agent_names in self.group_map.items():
            group_action = tensordict.get((group, "action"))
            group_action_list = list(self.read_action(group_action, group=group))
            agent_indices.update(
                {
                    self.agent_names_to_indices_map[agent_name]: i + n_agents
                    for i, agent_name in enumerate(agent_names)
                }
            )
            n_agents += len(agent_names)
            action_list += group_action_list
        action = [action_list[agent_indices[i]] for i in range(self.n_agents)]

        obs, rews, dones, infos = self._env.step(action)
        return self._make_output_tensordict(obs, dones, infos, rews=rews)

    def _make_output_tensordict(
        self, obs, dones, infos, rews=None
    ) -> TensorDictBase:
        """Assemble the output tensordict shared by ``_reset`` and ``_step``.

        Args:
            obs: per-agent observations, indexed by the agent's env index.
            dones: done flags returned by VMAS.
            infos: per-agent info dictionaries, indexed like ``obs``.
            rews: per-agent rewards, indexed like ``obs``. When ``None`` (as in
                ``_reset``) no ``"reward"`` entry is written.

        Returns:
            A tensordict holding ``"done"``, ``"terminated"`` and, for each group,
            its agents' tensordicts stacked along dim 1.
        """
        dones = self.read_done(dones)
        source = {"done": dones, "terminated": dones.clone()}
        for group, agent_names in self.group_map.items():
            agent_tds = []
            for agent_name in agent_names:
                i = self.agent_names_to_indices_map[agent_name]
                agent_source = {"observation": self.read_obs(obs[i])}
                if rews is not None:
                    agent_source["reward"] = self.read_reward(rews[i])
                agent_td = TensorDict(
                    source=agent_source,
                    batch_size=self.batch_size,
                    device=self.device,
                )
                agent_info = self.read_info(infos[i])
                if agent_info is not None:
                    agent_td.set("info", agent_info)
                agent_tds.append(agent_td)

            agent_tds = LazyStackedTensorDict.maybe_dense_stack(agent_tds, dim=1)
            if not self.het_specs_map[group]:
                # Homogeneous groups are converted to a dense tensordict.
                agent_tds = agent_tds.to_tensordict()
            source[group] = agent_tds

        return TensorDict(
            source=source,
            batch_size=self.batch_size,
            device=self.device,
        )

    def read_obs(self, observations: dict | torch.Tensor) -> dict | torch.Tensor:
        """Unsqueeze tensor observations to the batch size; recurse into dicts."""
        if isinstance(observations, torch.Tensor):
            return _selective_unsqueeze(observations, batch_size=self.batch_size)
        return TensorDict(
            source={key: self.read_obs(value) for key, value in observations.items()},
            batch_size=self.batch_size,
        )

    def read_info(self, infos: dict[str, torch.Tensor]) -> TensorDictBase | None:
        """Convert an agent's info dict to a float32 tensordict, or ``None`` if empty."""
        if len(infos) == 0:
            return None
        infos = TensorDict(
            source={
                key: _selective_unsqueeze(
                    value.to(torch.float32), batch_size=self.batch_size
                )
                for key, value in infos.items()
            },
            batch_size=self.batch_size,
            device=self.device,
        )

        return infos

    def read_done(self, done):
        done = _selective_unsqueeze(done, batch_size=self.batch_size)
        return done

    def read_reward(self, rewards):
        rewards = _selective_unsqueeze(rewards, batch_size=self.batch_size)
        return rewards

    def read_action(self, action, group: str = "agents"):
        """Split a group action along the agent dim (1), converting one-hot to categorical if needed."""
        if not self.continuous_actions and not self.categorical_actions:
            action = self.full_action_spec_unbatched[group, "action"].to_categorical(
                action
            )
        agent_actions = action.unbind(dim=1)
        return agent_actions

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(num_envs={self.num_envs}, n_agents={self.n_agents},"
            f" batch_size={self.batch_size}, device={self.device})"
        )

    def to(self, device: DEVICE_TYPING) -> EnvBase:
        self._env.to(device)
        return super().to(device)
class VmasEnv(VmasWrapper):
    """Vmas environment wrapper.

    GitHub: https://github.com/proroklab/VectorizedMultiAgentSimulator

    Paper: https://arxiv.org/abs/2207.03530

    Args:
        scenario (str or vmas.simulator.scenario.BaseScenario): the vmas scenario to build.
            Must be one of :attr:`~.available_envs`. For a description and rendering of available scenarios see
            `the README <https://github.com/proroklab/VectorizedMultiAgentSimulator/tree/VMAS-1.3.3?tab=readme-ov-file#main-scenarios>`__.

    Keyword Args:
        num_envs (int): Number of vectorized simulation environments. VMAS performs vectorized simulations using PyTorch.
            This argument indicates the number of vectorized environments that should be simulated in a batch. It will also
            determine the batch size of the environment.
        device (torch.device, optional): Device for simulation. Defaults to the default device. All the tensors created by VMAS
            will be placed on this device.
        continuous_actions (bool, optional): Whether to use continuous actions. Defaults to ``True``. If ``False``, actions
            will be discrete. The number of actions and their size will depend on the chosen scenario.
            See the VMAS repository for more info.
        max_steps (int, optional): Horizon of the task. Defaults to ``None`` (infinite horizon). Each VMAS scenario can
            be terminating or not. If ``max_steps`` is specified,
            the scenario is also terminated (and the ``"terminated"`` flag is set) whenever this horizon is reached.
            Unlike gym's ``TimeLimit`` transform or torchrl's :class:`~torchrl.envs.transforms.StepCounter`,
            this argument will not set the ``"truncated"`` entry in the tensordict.
        categorical_actions (bool, optional): if the environment actions are discrete, whether to transform
            them to categorical or one-hot. Defaults to ``True``.
        group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for
            input/output. By default, if the agent names follow the ``"<name>_<int>"``
            convention, they will be grouped by ``"<name>"``. If they do not follow this convention, they will be all put
            in one group named ``"agents"``.
            Otherwise, a group map can be specified or selected from some premade options.
            See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
        scenario_kwargs (Dict, optional): dictionary of additional arguments passed to the VMAS
            scenario constructor (e.g., number of agents, reward sparsity).
            This is convenient when scenario parameters are stored under a dedicated config field.
        **kwargs (Dict, optional): Additional arguments passed to the VMAS scenario constructor.
            This allows passing scenario arguments directly as keyword arguments.
            If the same key is provided in both ``scenario_kwargs`` and ``kwargs``, the value in
            ``kwargs`` takes precedence.
            The available arguments will vary based on the chosen scenario.
            To see the available arguments for a specific scenario, see the constructor in its file from
            `the scenario folder <https://github.com/proroklab/VectorizedMultiAgentSimulator/tree/VMAS-1.3.3/vmas/scenarios>`__.

    Attributes:
        group_map (Dict[str, List[str]]): how to group agents in tensordicts for
            input/output. See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info.
        agent_names (list of str): names of the agent in the environment
        agent_names_to_indices_map (Dict[str, int]): dictionary mapping agent names to their index in the environment
        full_action_spec_unbatched (TensorSpec): version of the spec without the vectorized dimension
        full_observation_spec_unbatched (TensorSpec): version of the spec without the vectorized dimension
        full_reward_spec_unbatched (TensorSpec): version of the spec without the vectorized dimension
        full_done_spec_unbatched (TensorSpec): version of the spec without the vectorized dimension
        het_specs (bool): whether the environment has any lazy spec
        het_specs_map (Dict[str, bool]): dictionary mapping each group to a flag representing whether the group has lazy specs
        available_envs (List[str]): the list of the scenarios available to build.

    .. warning::
        VMAS returns a single ``done`` flag which does not distinguish between
        when the env reached ``max_steps`` and termination.
        If you deem the ``truncation`` signal necessary, set ``max_steps`` to
        ``None`` and use a :class:`~torchrl.envs.transforms.StepCounter` transform.

    Examples:
        >>> env = VmasEnv(
        ...     scenario="flocking",
        ...     num_envs=32,
        ...     continuous_actions=True,
        ...     max_steps=200,
        ...     device="cpu",
        ...     seed=None,
        ...     # Scenario kwargs
        ...     n_agents=5,
        ... )
        >>> print(env.rollout(10))
        TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        action: Tensor(shape=torch.Size([32, 10, 5, 2]), device=cpu, dtype=torch.float32, is_shared=False),
                        info: TensorDict(
                            fields={
                                agent_collision_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                                agent_distance_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([32, 10, 5]),
                            device=cpu,
                            is_shared=False),
                        observation: Tensor(shape=torch.Size([32, 10, 5, 18]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([32, 10, 5]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        agents: TensorDict(
                            fields={
                                info: TensorDict(
                                    fields={
                                        agent_collision_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                                        agent_distance_rew: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                                    batch_size=torch.Size([32, 10, 5]),
                                    device=cpu,
                                    is_shared=False),
                                observation: Tensor(shape=torch.Size([32, 10, 5, 18]), device=cpu, dtype=torch.float32, is_shared=False),
                                reward: Tensor(shape=torch.Size([32, 10, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([32, 10, 5]),
                            device=cpu,
                            is_shared=False),
                        done: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        terminated: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([32, 10]),
                    device=cpu,
                    is_shared=False),
                terminated: Tensor(shape=torch.Size([32, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([32, 10]),
            device=cpu,
            is_shared=False)
    """

    def __init__(
        self,
        scenario: str | vmas.simulator.scenario.BaseScenario,  # noqa
        *,
        num_envs: int,
        continuous_actions: bool = True,
        max_steps: int | None = None,
        categorical_actions: bool = True,
        seed: int | None = None,
        group_map: MarlGroupMapType | dict[str, list[str]] | None = None,
        scenario_kwargs: dict | None = None,
        **kwargs,
    ):
        if not _has_vmas:
            raise ImportError(
                "vmas python package was not found. Please install this dependency. "
                f"More info: {self.git_url}."
            )
        # Direct keyword arguments take precedence over ``scenario_kwargs``.
        scenario_kwargs = {**(scenario_kwargs or {}), **kwargs}
        super().__init__(
            scenario=scenario,
            num_envs=num_envs,
            continuous_actions=continuous_actions,
            max_steps=max_steps,
            seed=seed,
            categorical_actions=categorical_actions,
            group_map=group_map,
            **scenario_kwargs,
        )

    def _check_kwargs(self, kwargs: dict):
        """Ensure the scenario and number of envs were provided.

        Raises:
            TypeError: if ``"scenario"`` or ``"num_envs"`` is missing.
        """
        if "scenario" not in kwargs:
            raise TypeError("Could not find environment key 'scenario' in kwargs.")
        if "num_envs" not in kwargs:
            raise TypeError("Could not find environment key 'num_envs' in kwargs.")

    def _build_env(
        self,
        scenario: str | vmas.simulator.scenario.BaseScenario,  # noqa
        num_envs: int,
        continuous_actions: bool,
        max_steps: int | None,
        seed: int | None,
        **scenario_kwargs,
    ) -> vmas.simulator.environment.environment.Environment:  # noqa
        """Create the VMAS env via ``vmas.make_env`` and validate it with the parent wrapper."""
        vmas = self.lib

        self.scenario_name = scenario
        # Pixel options belong to the wrapper, not to the scenario constructor.
        from_pixels = scenario_kwargs.pop("from_pixels", False)
        pixels_only = scenario_kwargs.pop("pixels_only", False)

        if self.device is not None:
            device = self.device
        else:
            # ``torch.get_default_device`` does not exist in older torch
            # versions; fall back to CPU there.
            device = getattr(
                torch, "get_default_device", lambda: torch.device("cpu")
            )()

        return super()._build_env(
            env=vmas.make_env(
                scenario=scenario,
                num_envs=num_envs,
                device=device,
                continuous_actions=continuous_actions,
                max_steps=max_steps,
                seed=seed,
                wrapper=None,
                **scenario_kwargs,
            ),
            pixels_only=pixels_only,
            from_pixels=from_pixels,
        )

    def __repr__(self):
        return f"{super().__repr__()} (scenario={self.scenario_name})"