Source code for shimmy.bsuite_compatibility

"""Wrapper to convert a BSuite environment into a gymnasium compatible environment."""
from __future__ import annotations

from typing import Any

import gymnasium
import numpy as np
from bsuite.environments import Environment
from gymnasium.core import ObsType
from gymnasium.error import UnsupportedMode
from gymnasium.utils import EzPickle

from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space

# Compatibility shim: expose the builtin `int` under the name `np.int`.
# NOTE(review): presumably bsuite still refers to `np.int`, which the installed
# NumPy may no longer provide -- confirm against the PR below.
# This must stay until the BSuite authors merge
# https://github.com/deepmind/bsuite/pull/48
np.int = int  # pyright: ignore[reportGeneralTypeIssues]


class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray], EzPickle):
    """A compatibility wrapper that converts a BSuite environment into a gymnasium environment.

    The wrapped environment's observation and action specs are converted to
    gymnasium spaces, and its timesteps are converted to gymnasium
    ``reset``/``step`` return tuples. Attributes that are not defined on the
    wrapper are looked up on the wrapped BSuite environment.

    Note:
        Bsuite uses `np.random.RandomState`, a legacy random number generator while gymnasium
        uses `np.random.Generator`, therefore the return type of `np_random` is different from expected.
    """

    # BSuite has no rendering support, so no render modes are advertised.
    metadata = {"render_modes": []}

    def __init__(
        self,
        env: Environment,
        render_mode: str | None = None,
    ):
        """Wraps a BSuite environment and derives its gymnasium spaces.

        Args:
            env: The BSuite environment to wrap.
            render_mode: Must be ``None``; BSuite provides no render modes.

        Raises:
            AssertionError: If ``render_mode`` is not ``None``.
        """
        # Record the constructor arguments so the wrapper can be pickled/copied.
        EzPickle.__init__(self, env, render_mode)
        self._env: Any = env

        self.observation_space = dm_spec2gym_space(env.observation_spec())
        self.action_space = dm_spec2gym_space(env.action_spec())

        assert render_mode is None, "No render modes available in BSuite."

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the bsuite environment.

        Args:
            seed: If given, the environment's random state is replaced by a new
                ``np.random.RandomState`` seeded with this value.
            options: Accepted for gymnasium API compatibility; not used here.

        Returns:
            The initial observation and the info dictionary.
        """
        super().reset(seed=seed)

        if seed is not None:
            rng = np.random.RandomState(seed=seed)
            # The `np_random` setter stores the generator on `self._env._rng`.
            self.np_random = rng
            # When the env exposes a `raw_env`, it gets the same generator so
            # both share one random stream.
            if hasattr(self._env, "raw_env"):
                self._env.raw_env._rng = rng

        timestep = self._env.reset()
        # Only the observation and info are part of gymnasium's reset() return.
        obs, _, _, _, info = dm_env_step2gym_step(timestep)

        return obs, info

    def step(self, action: int) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
        """Steps through the bsuite environment.

        Args:
            action: The action passed on to the wrapped environment's ``step``.

        Returns:
            A ``(obs, reward, terminated, truncated, info)`` tuple converted
            from the resulting timestep.
        """
        timestep = self._env.step(action)

        obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep)

        return (  # pyright: ignore[reportGeneralTypeIssues]
            obs,
            reward,
            terminated,
            truncated,
            info,
        )

    def render(self) -> np.ndarray | None:
        """Renders the bsuite env.

        Raises:
            UnsupportedMode: Always, since BSuite has no built-in rendering.
        """
        raise UnsupportedMode(
            "Rendering is not built into BSuite, print the observation instead."
        )

    def close(self):
        """Closes the environment."""
        self._env.close()

    @property
    def np_random(self) -> np.random.RandomState:
        """This should be np.random.Generator but bsuite uses np.random.RandomState."""
        return self._env._rng  # pyright: ignore[reportGeneralTypeIssues]

    @np_random.setter
    def np_random(self, value: np.random.RandomState):
        self._env._rng = value  # pyright: ignore[reportGeneralTypeIssues]

    def __getattr__(self, item: str):
        """If the attribute is missing, try getting the attribute from bsuite env.

        Args:
            item: Name of the attribute that normal lookup failed to find.

        Returns:
            The attribute of the same name on the wrapped BSuite environment.

        Raises:
            AttributeError: If the wrapped environment is not set yet, or it
                has no such attribute either.
        """
        # `__getattr__` only runs after normal lookup has failed. If `_env`
        # itself is the missing name (e.g. the instance was created without
        # running `__init__`), reading `self._env` below would re-enter this
        # method without end, so fail with a regular AttributeError instead.
        if item == "_env":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return getattr(self._env, item)