# Source code for shimmy.utils.meltingpot
"""Utility functions for Melting Pot."""
# pyright: reportGeneralTypeIssues=false
# flake8: noqa F821
import dm_env
from gymnasium import spaces
from pettingzoo.utils.env import ObsDict
from shimmy.utils.dm_env import dm_spec2gym_space
# Template for per-player keys; in this module `index` is filled with the
# position of each entry in a timestep's observation sequence.
PLAYER_STR_FORMAT = "player_{index}"
# Marker for world-level observation keys; any key containing this text is
# filtered out by the helpers in this module.
_WORLD_PREFIX = "WORLD."
[docs]
def load_meltingpot(substrate_name: str):
    """Helper function to build a Melting Pot substrate with its default roles.

    ``meltingpot`` and ``ml_collections`` are imported inside the function, so
    they are only needed when this helper is actually called.

    Args:
        substrate_name: Name of the substrate; it is passed to both
            ``meltingpot.substrate.get_config`` and ``meltingpot.substrate.build``.

    Returns:
        env: meltingpot.utils.substrates.substrate.Substrate
    """
    import meltingpot
    from ml_collections import config_dict

    # The substrate's own config supplies the roles used to build it.
    player_roles = meltingpot.substrate.get_config(substrate_name).default_player_roles

    # Create env config
    env_config = config_dict.ConfigDict(
        {
            "substrate": substrate_name,
            "roles": player_roles,
        }
    )

    # Build the substrate from the config entries.
    return meltingpot.substrate.build(
        env_config["substrate"], roles=env_config["roles"]
    )
[docs]
def timestep_to_observations(timestep: dm_env.TimeStep) -> ObsDict:
    """Extracts Gymnasium-compatible observations from a Melting Pot timestep.

    Every entry of ``timestep.observation`` is handled as one player's
    observation mapping. Entries whose key contains the world marker
    (``_WORLD_PREFIX``) are dropped.

    Args:
        timestep: The dm_env timestep; ``timestep.observation`` is iterated in
            order and each element must provide ``.items()``.

    Returns:
        A dict keyed by ``PLAYER_STR_FORMAT`` (``"player_0"``, ``"player_1"``,
        ...) whose values are the filtered per-player observation dicts.
    """
    return {
        PLAYER_STR_FORMAT.format(index=index): {
            key: value
            for key, value in observation.items()
            # Substring test (not startswith): the marker is matched anywhere in the key.
            if _WORLD_PREFIX not in key
        }
        for index, observation in enumerate(timestep.observation)
    }
[docs]
def remove_world_observations_from_space(observation: spaces.Dict) -> spaces.Dict:
    """Builds a copy of an observation space without its world-level entries.

    This is used to limit the information an individual agent has access to
    (it cannot see the entire world).

    Args:
        observation: The Melting Pot observation space (a ``spaces.Dict``).

    Returns:
        A new ``spaces.Dict`` holding every sub-space whose key does not
        contain the world marker (``_WORLD_PREFIX``).
    """
    agent_visible = {}
    for name in observation:
        # Skip anything tagged as a world observation.
        if _WORLD_PREFIX in name:
            continue
        agent_visible[name] = observation[name]
    return spaces.Dict(agent_visible)