Source code for shimmy.meltingpot_compatibility

"""Wrapper to convert a Melting Pot substrate into a PettingZoo compatible environment.

Taken from
https://github.com/deepmind/meltingpot/blob/main/examples/pettingzoo/utils.py
and modified to the modern PettingZoo API.
"""
# pyright: reportOptionalSubscript=false
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Any

import dm_env
import gymnasium
import numpy as np
import pygame
from gymnasium.utils.ezpickle import EzPickle
from pettingzoo.utils.env import ActionDict, AgentID, ObsDict, ParallelEnv

import shimmy.utils.meltingpot as utils

if TYPE_CHECKING:
    from meltingpot.utils.substrates import substrate


class MeltingPotCompatibilityV0(ParallelEnv, EzPickle):
    """This compatibility wrapper converts a Melting Pot substrate into a PettingZoo environment.

    Due to how the underlying environment is set up, this environment is nondeterministic, so seeding doesn't work.

    Melting Pot is a research tool developed to facilitate work on multi-agent artificial intelligence.
    It assesses generalization to novel social situations involving both familiar and unfamiliar individuals,
    and has been designed to test a broad range of social interactions such as: cooperation, competition,
    deception, reciprocation, trust, stubbornness and so on.
    Melting Pot offers researchers a set of over 50 multi-agent reinforcement learning substrates
    (multi-agent games) on which to train agents, and over 256 unique test scenarios on which to
    evaluate these trained agents.
    """

    metadata = {
        "render_modes": ["human", "rgb_array"],
        "name": "MeltingPotCompatibilityV0",
    }

    PLAYER_STR_FORMAT = "player_{index}"
    MAX_CYCLES = 1000

    def __init__(
        self,
        env: substrate.Substrate | None = None,
        substrate_name: str | None = None,
        max_cycles: int = MAX_CYCLES,
        render_mode: str | None = None,
    ):
        """Wrapper that converts a Melting Pot environment into a PettingZoo environment.

        Args:
            env (Optional[substrate.Substrate]): existing Melting Pot environment to wrap
            substrate_name (Optional[str]): name of Melting Pot substrate to load (instead of an existing environment)
            max_cycles (Optional[int]): maximum number of cycles before truncation
            render_mode (Optional[str]): rendering mode
        """
        EzPickle.__init__(
            self,
            env,
            substrate_name,
            max_cycles,
            render_mode,
        )

        # Only one of substrate_name and env can be provided; the other should be None.
        if env is None and substrate_name is None:
            raise ValueError(
                "No environment provided. Use `env` to specify an existing environment, or load an environment with `substrate_name`."
            )
        elif env is not None and substrate_name is not None:
            raise ValueError(
                "Two environments provided. Use `env` to specify an existing environment, or load an environment with `substrate_name`."
            )
        elif substrate_name is not None:
            self._env = utils.load_meltingpot(substrate_name)
        elif env is not None:
            self._env = env

        self.max_cycles = max_cycles

        # Set up PettingZoo variables
        self.render_mode = render_mode
        self.state_space = utils.dm_spec2gym_space(
            self._env.observation_spec()[0]["WORLD.RGB"]
        )
        self._num_players = len(self._env.observation_spec())
        self.possible_agents = [
            self.PLAYER_STR_FORMAT.format(index=index)
            for index in range(self._num_players)
        ]
        self.agents = [agent for agent in self.possible_agents]
        self.num_cycles = 0

        # Set up pygame rendering
        if self.render_mode == "human":
            self.display_scale = 4
            self.display_fps = 5

            pygame.init()
            self.clock = pygame.time.Clock()
            pygame.display.set_caption("Melting Pot")
            shape = self.state_space.shape
            self.game_display = pygame.display.set_mode(
                (
                    int(shape[1] * self.display_scale),
                    int(shape[0] * self.display_scale),
                )
            )
    @functools.lru_cache(maxsize=None)
    def observation_space(self, agent: AgentID) -> gymnasium.spaces.Space:
        """observation_space.

        Get the observation space from the underlying Melting Pot substrate.

        Args:
            agent (AgentID): agent

        Returns:
            observation_space: spaces.Space
        """
        observation_space = utils.remove_world_observations_from_space(
            utils.dm_spec2gym_space(self._env.observation_spec()[0])  # type: ignore
        )
        return observation_space
    @functools.lru_cache(maxsize=None)
    def action_space(self, agent: AgentID) -> gymnasium.spaces.Space:
        """action_space.

        Get the action space from the underlying Melting Pot substrate.

        Args:
            agent (AgentID): agent

        Returns:
            action_space: spaces.Space
        """
        action_space = utils.dm_spec2gym_space(self._env.action_spec()[0])
        return action_space
    def state(self) -> np.ndarray:
        """State.

        Get an observation of the current environment's state. Used in rendering.

        Returns:
            observation
        """
        return self._env.observation()
    def reset(
        self,
        seed: int | None = None,
        options: dict | None = None,
    ) -> tuple[ObsDict, dict[AgentID, Any]]:
        """reset.

        Resets the environment.

        Args:
            seed: the seed to reset the environment with (not used, due to nondeterministic underlying environment)
            options: the options to reset the environment with

        Returns:
            observations and infos for each agent
        """
        timestep: dm_env.TimeStep = self._env.reset()
        self.agents = self.possible_agents[:]
        self.num_cycles = 0

        observations = utils.timestep_to_observations(timestep)

        # Duplicate infos across agents
        infos = {
            agent: {
                "timestep.discount": timestep.discount,
                "timestep.step_type": timestep.step_type,
            }
            for agent in self.agents
        }

        return observations, infos
    def step(
        self, actions: ActionDict
    ) -> tuple[
        ObsDict, dict[str, float], dict[str, bool], dict[str, bool], dict[str, dict]
    ]:
        """step.

        Steps the environment with one action for each agent.

        Args:
            actions: actions to step through the environment with

        Returns:
            (observations, rewards, terminations, truncations, infos)
        """
        timestep = self._env.step([actions[agent] for agent in self.agents])
        rewards = {
            agent: timestep.reward[index] for index, agent in enumerate(self.agents)
        }
        self.num_cycles += 1

        termination = timestep.last()
        terminations = {agent: termination for agent in self.agents}

        truncation = self.num_cycles >= self.max_cycles
        truncations = {agent: truncation for agent in self.agents}

        infos = {agent: {} for agent in self.agents}

        if termination or truncation:
            self.agents = []

        observations = utils.timestep_to_observations(timestep)

        if self.render_mode == "human":
            self.render()
        return observations, rewards, terminations, truncations, infos
    def close(self):
        """close.

        Closes the environment.
        """
        self._env.close()
    def render(self) -> None | np.ndarray:
        """render.

        Renders the environment.

        Returns:
            The rendering of the environment, depending on the render mode
        """
        rgb_arr = self.state()[0]["WORLD.RGB"]
        if self.render_mode is None:
            gymnasium.logger.warn(
                "You are calling render method without specifying any render mode."
            )
            return
        elif self.render_mode == "human":
            rgb_arr = np.transpose(rgb_arr, (1, 0, 2))
            surface = pygame.surfarray.make_surface(rgb_arr)
            rect = surface.get_rect()
            surf = pygame.transform.scale(
                surface,
                (int(rect[2] * self.display_scale), int(rect[3] * self.display_scale)),
            )
            self.game_display.blit(surf, dest=(0, 0))
            pygame.display.update()
            self.clock.tick(self.display_fps)
            return None
        elif self.render_mode == "rgb_array":
            return rgb_arr
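

# Usage sketch (illustrative; not part of the wrapper above). It loads a
# substrate by name and runs a short random rollout. The substrate name
# "clean_up" and the max_cycles value are assumptions chosen for the example;
# any substrate that `utils.load_meltingpot` can resolve works the same way.
if __name__ == "__main__":
    env = MeltingPotCompatibilityV0(
        substrate_name="clean_up", max_cycles=100, render_mode="human"
    )
    observations, infos = env.reset()
    while env.agents:
        # Sample one random action per agent; a trained policy would go here.
        actions = {agent: env.action_space(agent).sample() for agent in env.agents}
        observations, rewards, terminations, truncations, infos = env.step(actions)
    env.close()

    # With render_mode="rgb_array", render() instead returns the substrate's
    # global WORLD.RGB view as a numpy array, e.g. for recording videos.
    env = MeltingPotCompatibilityV0(substrate_name="clean_up", render_mode="rgb_array")
    env.reset()
    frame = env.render()  # np.ndarray of shape (height, width, 3)
    env.close()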