Source code for shimmy.dm_control_multiagent_compatibility

"""Wrapper to convert a dm_env multiagent environment into a pettingzoo compatible environment."""
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Any

import dm_control.composer
import dm_env
import gymnasium
import numpy as np
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
from gymnasium.utils import EzPickle
from pettingzoo.utils.env import ActionDict, AgentID, ObsDict, ParallelEnv

from shimmy.utils.dm_control_multiagent import load_dm_control_soccer
from shimmy.utils.dm_env import dm_obs2gym_obs, dm_spec2gym_space

if TYPE_CHECKING:
    from dm_control.locomotion import soccer as dm_soccer


def _unravel_ma_timestep(
    timestep: dm_env.TimeStep, agents: list[AgentID]
) -> tuple[
    dict[AgentID, Any],
    dict[AgentID, float],
    dict[AgentID, bool],
    dict[AgentID, bool],
    dict[AgentID, Any],
]:
    """Opens up the timestep to return obs, reward, terminated, truncated, info."""
    # set terminated and truncated
    term, trunc = False, False
    if timestep.last():
        if timestep.discount == 0:
            trunc = True
        else:
            term = True

    # expand the observations
    list_observations = [dm_obs2gym_obs(obs) for obs in timestep.observation]
    observations: dict[AgentID, Any] = dict(zip(agents, list_observations))

    # the reward can be None (e.g. on the first timestep), so default to 0.0
    rewards: dict[AgentID, float] = {agent: 0.0 for agent in agents}

    if timestep.reward:
        rewards = dict(zip(agents, timestep.reward))

    # expand everything else
    terminations: dict[AgentID, bool] = {agent: term for agent in agents}
    truncations: dict[AgentID, bool] = {agent: trunc for agent in agents}

    # duplicate infos across agents
    infos = {
        agent: {
            "timestep.discount": timestep.discount,
            "timestep.step_type": timestep.step_type,
        }
        for agent in agents
    }

    return (
        observations,
        rewards,
        terminations,
        truncations,
        infos,
    )
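
# Example (a minimal sketch, not part of the original module): unpacking a
# hypothetical final timestep for two agents. `dm_env.TimeStep` has the fields
# step_type, reward, discount, and observation; the observation payloads below
# are placeholders.
#
#   ts = dm_env.TimeStep(
#       step_type=dm_env.StepType.LAST,
#       reward=[1.0, -1.0],
#       discount=0.0,
#       observation=[{"pos": np.zeros(3)}, {"pos": np.ones(3)}],
#   )
#   obs, rew, term, trunc, info = _unravel_ma_timestep(ts, ["player_0", "player_1"])
#   # discount == 0 on a LAST step -> truncations are True, terminations are False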


class DmControlMultiAgentCompatibilityV0(ParallelEnv, EzPickle):
    """This compatibility wrapper converts multi-agent dm-control environments, primarily soccer, into a PettingZoo environment.

    dm-control is DeepMind's software stack for physics-based simulation and
    Reinforcement Learning environments, using MuJoCo physics. This compatibility
    wrapper converts a multi-agent dm-control environment into a PettingZoo environment.
    """

    metadata = {"render_modes": ["human"], "name": "DmControlMultiAgentCompatibilityV0"}

    def __init__(
        self,
        env: dm_control.composer.Environment | None = None,
        team_size: int | None = None,
        time_limit: float | None = None,
        disable_walker_contacts: bool | None = None,
        enable_field_box: bool | None = None,
        terminate_on_goal: bool | None = None,
        walker_type: dm_soccer.WalkerType | None = None,
        render_mode: str | None = None,
    ):
        """Wrapper to convert a dm_control multi-agent environment into a PettingZoo environment.

        Due to how the underlying environment is set up, this environment is
        nondeterministic, so seeding does not work.

        Note: to wrap an existing environment, only the `env` and `render_mode`
        arguments can be specified. All other arguments (marked [DM CONTROL ARG])
        are specific to dm_control soccer and are used to load a new environment.

        Args:
            env (Optional[dm_control.composer.Environment]): existing dm_control multi-agent environment to wrap
            team_size (Optional[int]): number of players for each team [DM CONTROL ARG]
            time_limit (Optional[float]): time limit for the game [DM CONTROL ARG]
            disable_walker_contacts (Optional[bool]): flag to disable walker contacts [DM CONTROL ARG]
            enable_field_box (Optional[bool]): flag to enable the field box [DM CONTROL ARG]
            terminate_on_goal (Optional[bool]): flag to terminate the environment on goal [DM CONTROL ARG]
            walker_type (Optional[dm_soccer.WalkerType]): walker type (BOXHEAD, ANT, or HUMANOID) [DM CONTROL ARG]
            render_mode (Optional[str]): rendering mode
        """
        EzPickle.__init__(self, env=env, render_mode=render_mode)
        ParallelEnv.__init__(self)

        DM_CONTROL_ARGS = [
            team_size,
            time_limit,
            disable_walker_contacts,
            enable_field_box,
            terminate_on_goal,
            walker_type,
        ]

        # Exactly one of `env` and the DM_CONTROL_ARGS may be provided; the other must be None.
        if env is None and all(arg is None for arg in DM_CONTROL_ARGS):
            raise ValueError(
                "No environment provided. Use `env` to specify an existing environment, "
                "or load an environment by specifying at least one of `team_size`, "
                "`time_limit`, `disable_walker_contacts`, `enable_field_box`, "
                "`terminate_on_goal`, or `walker_type`."
            )
        elif env is not None and any(arg is not None for arg in DM_CONTROL_ARGS):
            raise ValueError(
                "Two environments provided. Use only `env` to wrap an existing environment, "
                "or only `team_size`, `time_limit`, `disable_walker_contacts`, "
                "`enable_field_box`, `terminate_on_goal`, and `walker_type` to load a new one."
            )
        elif any(arg is not None for arg in DM_CONTROL_ARGS):
            self._env = load_dm_control_soccer(
                team_size,
                time_limit,
                disable_walker_contacts,
                enable_field_box,
                terminate_on_goal,
                walker_type,
            )
        elif env is not None:
            self._env = env

        self.render_mode = render_mode

        # get action and observation specs first
        all_obs_spaces = [
            dm_spec2gym_space(spec) for spec in self._env.observation_spec()
        ]
        all_act_spaces = [dm_spec2gym_space(spec) for spec in self._env.action_spec()]
        num_players = len(all_obs_spaces)

        # agent definitions
        self.possible_agents = ["player_" + str(r) for r in range(num_players)]
        self.agent_id_name_mapping = dict(zip(range(num_players), self.possible_agents))
        self.agent_name_id_mapping = dict(zip(self.possible_agents, range(num_players)))

        # the official spaces
        self.obs_spaces = dict(zip(self.possible_agents, all_obs_spaces))
        self.act_spaces = dict(zip(self.possible_agents, all_act_spaces))

        if self.render_mode == "human":
            assert self._env.physics is not None
            self.viewer = MujocoRenderer(
                self._env.physics.model.ptr, self._env.physics.data.ptr
            )
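
    # Example usage (a sketch, assuming the `dm_soccer.load` loader from
    # dm_control; not part of the original class): either wrap an existing
    # soccer environment or let the wrapper load one from the
    # [DM CONTROL ARG] options.
    #
    #   from dm_control.locomotion import soccer as dm_soccer
    #   env = DmControlMultiAgentCompatibilityV0(
    #       env=dm_soccer.load(team_size=2, time_limit=10.0)
    #   )
    #   # or
    #   env = DmControlMultiAgentCompatibilityV0(team_size=2, time_limit=10.0)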

    @functools.lru_cache(maxsize=None)
    def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
        """observation_space.

        Get the observation space from the underlying dm-control environment.

        Args:
            agent (AgentID): agent

        Returns:
            observation_space: spaces.Space
        """
        return self.obs_spaces[agent]

    @functools.lru_cache(maxsize=None)
    def action_space(self, agent: AgentID) -> gymnasium.spaces.Space:
        """action_space.

        Get the action space from the underlying dm-control env.

        Args:
            agent (AgentID): agent

        Returns:
            action_space: spaces.Space
        """
        return self.act_spaces[agent]
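
    # Example (a sketch, not part of the original class): both lookups return
    # regular Gymnasium spaces keyed by agent name, e.g.
    #
    #   env.observation_space("player_0")          # per-agent observation space
    #   env.action_space("player_0").sample()      # random action for one agent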

    def render(self) -> np.ndarray | None:
        """render.

        Renders the environment.

        Returns:
            The rendering of the environment, depending on the render mode
        """
        if self.render_mode is None:
            gymnasium.logger.warn(
                "You are calling render method without specifying any render mode."
            )
            return

    def close(self):
        """close.

        Closes the environment.
        """
        assert self._env.physics is not None
        self._env.physics.free()
        self._env.close()

        if hasattr(self, "viewer"):
            self.viewer.close()

    def reset(
        self, seed: int | None = None, options: dict[AgentID, Any] | None = None
    ) -> tuple[ObsDict, dict[str, Any]]:
        """reset.

        Resets the dm-control environment.

        Args:
            seed: the seed to reset the environment with
            options: the options to reset the environment with (unused)

        Returns:
            observations, infos
        """
        self.agents = self.possible_agents[:]
        self.num_moves = 0

        self._env._random_state = np.random.RandomState(seed)
        timestep = self._env.reset()

        observations, _, _, _, info = _unravel_ma_timestep(timestep, self.agents)

        if self.render_mode == "human":
            self.viewer.close()
            self.viewer = MujocoRenderer(
                self._env.physics.model.ptr, self._env.physics.data.ptr
            )

        return observations, info

    def step(
        self, actions: ActionDict
    ) -> tuple[
        ObsDict,
        dict[AgentID, float],
        dict[AgentID, bool],
        dict[AgentID, bool],
        dict[AgentID, Any],
    ]:
        """step.

        Steps through all agents with the actions.

        Args:
            actions: dict of actions to step through the environment with

        Returns:
            (observations, rewards, terminations, truncations, infos)
        """
        # the actions dict must contain an action for every agent
        assert len(actions) == len(
            self.agents
        ), f"Must have actions for all {len(self.agents)} agents, currently only found {len(actions)}."

        timestep = self._env.step(actions.values())

        obs, rewards, terminations, truncations, infos = _unravel_ma_timestep(
            timestep, self.agents
        )

        if self.render_mode == "human":
            self.viewer.render(self.render_mode)

        if any(terminations.values()) or any(truncations.values()):
            self.agents = []

        return obs, rewards, terminations, truncations, infos
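

# Example rollout (a minimal sketch, not part of the original module): the
# wrapper follows the standard PettingZoo ParallelEnv API, so a random-action
# loop looks like this.
#
#   env = DmControlMultiAgentCompatibilityV0(team_size=2, time_limit=10.0)
#   observations, infos = env.reset()
#   while env.agents:
#       actions = {agent: env.action_space(agent).sample() for agent in env.agents}
#       observations, rewards, terminations, truncations, infos = env.step(actions)
#   env.close()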