"""ALE-py interface for atari.
This file was originally copied from https://github.com/mgbellemare/Arcade-Learning-Environment/blob/master/src/python/env/gym.py
Under the GNU General Public License v2.0
Copyright is held by the authors
Changes
* Added `self.render_mode` which is identical to `self._render_mode`
"""
from __future__ import annotations
import sys
from typing import Any, Sequence
import ale_py
import ale_py.roms as roms
import ale_py.roms.utils as rom_utils
import gymnasium
import gymnasium.logger as logger
import numpy as np
from gymnasium.error import Error
from gymnasium.spaces import Box, Discrete
from gymnasium.utils.ezpickle import EzPickle
if sys.version_info < (3, 11):
from typing_extensions import NotRequired, TypedDict
else:
from typing import NotRequired, TypedDict
class AtariEnvStepMetadata(TypedDict):
    """Atari Environment Step Metadata."""

    # Lives remaining, as reported by `ale.lives()`.
    lives: int
    # Frame index within the current episode (`ale.getEpisodeFrameNumber()`).
    episode_frame_number: int
    # Frame index since emulator creation (`ale.getFrameNumber()`).
    frame_number: int
    # Seeds actually used, present only after a seeded `reset()`.
    seeds: NotRequired[Sequence[int]]
[docs]
class AtariEnv(gymnasium.Env[np.ndarray, np.int64], EzPickle):
    """(A)rcade (L)earning (Gymnasium) (Env)ironment.
    A Gymnasium wrapper around the Arcade Learning Environment (ALE).
    """

    # Supported render modes and observation types; both are validated
    # against these sets in `__init__`.
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "obs_types": {"rgb", "grayscale", "ram"},
    }
def __init__(
self,
game: str = "pong",
mode: int | None = None,
difficulty: int | None = None,
obs_type: str = "rgb",
frameskip: tuple[int, int] | int = 4,
repeat_action_probability: float = 0.25,
full_action_space: bool = False,
max_num_frames_per_episode: int | None = None,
render_mode: str | None = None,
):
"""Initialize the ALE interface for Gymnasium.
Default parameters are taken from Machado et al., 2018.
Args:
game: str => Game to initialize env with.
mode: Optional[int] => Game mode, see Machado et al., 2018
difficulty: Optional[int] => Game difficulty,see Machado et al., 2018
obs_type: str => Observation type in { 'rgb', 'grayscale', 'ram' }
frameskip: Union[Tuple[int, int], int] =>
Stochastic frameskip as tuple or fixed.
repeat_action_probability: int =>
Probability to repeat actions, see Machado et al., 2018
full_action_space: bool => Use full action space?
max_num_frames_per_episode: int => Max number of frame per episode.
Once `max_num_frames_per_episode` is reached the episode is
truncated.
render_mode: str => One of { 'human', 'rgb_array' }.
If `human` we'll interactively display the screen and enable
game sounds. This will lock emulation to the ROMs specified FPS
If `rgb_array` we'll return the `rgb` key in step metadata with
the current environment RGB frame.
Note:
- The game must be installed, see ale-import-roms, or ale-py-roms.
- Frameskip values of (low, high) will enable stochastic frame skip
which will sample a random frameskip uniformly each action.
- It is recommended to enable full action space.
See Machado et al., 2018 for more details.
References:
`Revisiting the Arcade Learning Environment: Evaluation Protocols
and Open Problems for General Agents`, Machado et al., 2018, JAIR
URL: https://jair.org/index.php/jair/article/view/11182
"""
if obs_type == "image":
logger.warn(
'obs_type "image" should be replaced with the image type, one of: rgb, grayscale'
)
obs_type = "rgb"
if obs_type not in self.metadata["obs_types"]:
raise Error(
f"Invalid observation type: {obs_type}. Expecting: rgb, grayscale, ram."
)
if type(frameskip) not in (int, tuple):
raise Error(f"Invalid frameskip type: {type(frameskip)}.")
if isinstance(frameskip, int) and frameskip <= 0:
raise Error(
f"Invalid frameskip of {frameskip}, frameskip must be positive."
)
elif isinstance(frameskip, tuple) and len(frameskip) != 2:
raise Error(
f"Invalid stochastic frameskip length of {len(frameskip)}, expected length 2."
)
elif isinstance(frameskip, tuple) and frameskip[0] > frameskip[1]:
raise Error(
"Invalid stochastic frameskip, lower bound is greater than upper bound."
)
elif isinstance(frameskip, tuple) and frameskip[0] <= 0:
raise Error(
"Invalid stochastic frameskip lower bound is greater than upper bound."
)
if render_mode is not None and render_mode not in self.metadata["render_modes"]:
raise Error(f"Render mode {render_mode} not supported (rgb_array, human).")
EzPickle.__init__(
self,
game,
mode,
difficulty,
obs_type,
frameskip,
repeat_action_probability,
full_action_space,
max_num_frames_per_episode,
render_mode,
)
# Initialize ALE
self.ale = ale_py.ALEInterface()
self._game = rom_utils.rom_id_to_name(game)
self._game_mode = mode
self._game_difficulty = difficulty
self._frameskip = frameskip
self._obs_type = obs_type
self._render_mode = self.render_mode = render_mode
# Set logger mode to error only
self.ale.setLoggerMode(ale_py.LoggerMode.Error)
# Config sticky action prob.
self.ale.setFloat("repeat_action_probability", repeat_action_probability)
if max_num_frames_per_episode is not None:
self.ale.setInt("max_num_frames_per_episode", max_num_frames_per_episode)
# If render mode is human we can display screen and sound
if render_mode == "human":
self.ale.setBool("display_screen", True)
self.ale.setBool("sound", True)
# Seed + Load
self.seed()
if full_action_space:
self._action_set = self.ale.getLegalActionSet()
else:
self._action_set = self.ale.getMinimalActionSet()
self.action_space = Discrete(len(self._action_set))
# Initialize observation type
if self._obs_type == "ram":
self.observation_space = Box(
low=0, high=255, dtype=np.uint8, shape=(self.ale.getRAMSize(),)
)
elif self._obs_type == "rgb" or self._obs_type == "grayscale":
image_shape = self.ale.getScreenDims()
if self._obs_type == "rgb":
image_shape += (3,)
self.observation_space = Box(
low=0, high=255, dtype=np.uint8, shape=image_shape
)
else:
raise Error(f"Unrecognized observation type: {self._obs_type}")
[docs]
def seed(self, seed: int | None = None) -> tuple[int, int]:
"""Seeds both the internal numpy rng for stochastic frame skip and the ALE RNG.
This function must also initialize the ROM and set the corresponding
mode and difficulty. `seed` may be called to initialize the environment
during deserialization by Gymnasium so these side-effects must reside here.
Args:
seed: int => Manually set the seed for RNG.
Returns:
tuple[int, int] => (np seed, ALE seed)
"""
ss = np.random.SeedSequence(seed)
seed1, seed2 = ss.generate_state(n_words=2)
self.np_random = np.random.default_rng(seed1)
# ALE only takes signed integers for `setInt`, it'll get converted back
# to unsigned in StellaEnvironment.
self.ale.setInt("random_seed", seed2.astype(np.int32))
if not hasattr(roms, self._game):
raise Error(
f'We\'re Unable to find the game "{self._game}". Note: Gymnasium no longer distributes ROMs. '
f"If you own a license to use the necessary ROMs for research purposes you can download them "
f'via `pip install gymnasium[accept-rom-license]`. Otherwise, you should try importing "{self._game}" '
f'via the command `ale-import-roms`. If you believe this is a mistake perhaps your copy of "{self._game}" '
"is unsupported. To check if this is the case try providing the environment variable "
"`PYTHONWARNINGS=default::ImportWarning:ale_py.roms`. For more information see: "
"https://github.com/mgbellemare/Arcade-Learning-Environment#rom-management"
)
self.ale.loadROM(getattr(roms, self._game))
if self._game_mode is not None:
self.ale.setMode(self._game_mode)
if self._game_difficulty is not None:
self.ale.setDifficulty(self._game_difficulty)
return seed1, seed2
[docs]
def reset(
self,
*,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[np.ndarray, AtariEnvStepMetadata]:
"""Resets environment and returns initial observation.
Args:
seed: The reset seed
options: The reset options
Returns:
The reset observation and info
"""
super().reset(seed=seed, options=options)
del options
# Gymnasium's new seeding API seeds on reset.
# This will cause the console to be recreated
# and loose all previous state, e.g., statistics, etc.
seeded_with = None
if seed is not None:
seeded_with = self.seed(seed)
self.ale.reset_game()
obs = self._get_obs()
info = self._get_info()
if seeded_with is not None:
info["seeds"] = seeded_with
return obs, info
[docs]
def step(
self,
action_ind: int,
) -> tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]:
"""Perform one agent step, i.e., repeats `action` frameskip # of steps.
Args:
action_ind: int => Action index to execute
Returns:
Tuple[np.ndarray, float, bool, Dict[str, Any]] => observation, reward, terminal, metadata
Note: `metadata` contains the keys "lives" and "rgb" if render_mode == 'rgb_array'.
"""
# Get action enum, terminal bool, metadata
action = self._action_set[action_ind]
# If frameskip is a length 2 tuple then it's stochastic
# frameskip between [frameskip[0], frameskip[1]] uniformly.
if isinstance(self._frameskip, int):
frameskip = self._frameskip
elif isinstance(self._frameskip, tuple):
frameskip = self.np_random.integers(*self._frameskip)
else:
raise Error(f"Invalid frameskip type: {self._frameskip}")
# Frameskip
reward = 0.0
for _ in range(frameskip):
reward += self.ale.act(action)
is_terminal = self.ale.game_over(with_truncation=False)
is_truncated = self.ale.game_truncated()
return self._get_obs(), reward, is_terminal, is_truncated, self._get_info()
[docs]
def render(self) -> Any:
"""Renders the ALE environment.
Returns:
If render_mode is "rgb_array", returns the screen RGB view.
"""
if self.render_mode == "rgb_array":
return self.ale.getScreenRGB()
elif self.render_mode == "human":
pass
else:
raise Error(
f"Invalid render mode `{self.render_mode}`. Supported modes: `human`, `rgb_array`."
)
def _get_obs(self) -> np.ndarray:
"""Retrieves the current observation, dependent on `self._obs_type`.
Returns:
The current observation
"""
if self._obs_type == "ram":
return self.ale.getRAM()
elif self._obs_type == "rgb":
return self.ale.getScreenRGB()
elif self._obs_type == "grayscale":
return self.ale.getScreenGrayscale()
else:
raise Error(f"Unrecognized observation type: {self._obs_type}")
def _get_info(self) -> AtariEnvStepMetadata:
return {
"lives": self.ale.lives(),
"episode_frame_number": self.ale.getEpisodeFrameNumber(),
"frame_number": self.ale.getFrameNumber(),
}
[docs]
def get_keys_to_action(self) -> dict[tuple[int], ale_py.Action]:
"""Return keymapping -> actions for human play.
Returns:
A dictionary of keys to actions.
"""
UP = ord("w")
LEFT = ord("a")
RIGHT = ord("d")
DOWN = ord("s")
FIRE = ord(" ")
mapping = {
ale_py.Action.NOOP: (None,),
ale_py.Action.UP: (UP,),
ale_py.Action.FIRE: (FIRE,),
ale_py.Action.DOWN: (DOWN,),
ale_py.Action.LEFT: (LEFT,),
ale_py.Action.RIGHT: (RIGHT,),
ale_py.Action.UPFIRE: (UP, FIRE),
ale_py.Action.DOWNFIRE: (DOWN, FIRE),
ale_py.Action.LEFTFIRE: (LEFT, FIRE),
ale_py.Action.RIGHTFIRE: (RIGHT, FIRE),
ale_py.Action.UPLEFT: (UP, LEFT),
ale_py.Action.UPRIGHT: (UP, RIGHT),
ale_py.Action.DOWNLEFT: (DOWN, LEFT),
ale_py.Action.DOWNRIGHT: (DOWN, RIGHT),
ale_py.Action.UPLEFTFIRE: (UP, LEFT, FIRE),
ale_py.Action.UPRIGHTFIRE: (UP, RIGHT, FIRE),
ale_py.Action.DOWNLEFTFIRE: (DOWN, LEFT, FIRE),
ale_py.Action.DOWNRIGHTFIRE: (DOWN, RIGHT, FIRE),
}
# Map
# (key, key, ...) -> action_idx
# where action_idx is the integer value of the action enum
#
actions = self._action_set
return dict(
zip(
map(lambda action: tuple(sorted(mapping[action])), actions),
range(len(actions)),
)
)
[docs]
def get_action_meanings(self) -> list[str]:
"""Return the meaning of each integer action.
Returns:
A list of action meaning
"""
keys = ale_py.Action.__members__.values()
values = ale_py.Action.__members__.keys()
mapping = dict(zip(keys, values))
return [mapping[action] for action in self._action_set]
[docs]
def clone_state(self, include_rng: bool = False) -> ale_py.ALEState:
"""Clone emulator state w/o system state.
Restoring this state will *not* give an identical environment.
For complete cloning and restoring of the full state, see `{clone,restore}_full_state()`.
Args:
include_rng: If to include the rng in the cloned state
Returns:
The cloned state
"""
return self.ale.cloneState(include_rng=include_rng)
[docs]
def restore_state(self, state: ale_py.ALEState):
"""Restore emulator state w/o system state.
Args:
state: The state to restore
"""
self.ale.restoreState(state)
[docs]
def clone_full_state(self) -> ale_py.ALEState:
"""Deprecated method which would clone the emulator and system state."""
logger.warn(
"`clone_full_state()` is deprecated and will be removed in a future release of `ale-py`. "
"Please use `clone_state(include_rng=True)` which is equivalent to `clone_full_state`. "
)
return self.ale.cloneSystemState()
[docs]
def restore_full_state(self, state: ale_py.ALEState):
"""Restore emulator state w/ system state including pseudo-randomness."""
logger.warn(
"restore_full_state() is deprecated and will be removed in a future release of `ale-py`. "
"Please use `restore_state(state)` which will restore the state regardless of being a full or partial state. "
)
self.ale.restoreSystemState(state)