# aima_gridworld_env.py
import gymnasium as gym
from gymnasium import spaces
from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Goal, Lava, Wall
from minigrid.minigrid_env import MiniGridEnv
class AIMAGridworldEnv(MiniGridEnv):
    """
    4×3 Gridworld from AIMA Ch.17, Fig 17.1, built on a MiniGrid grid.

    Coordinates are MiniGrid's (x, y): the origin is the top-left cell and
    y grows downward, so the layout is the textbook figure mirrored
    vertically.
      - Grid size: width=4, height=3 (no border walls; all 12 cells are used,
        the grid edges themselves block movement)
      - Wall at (1,1)
      - Terminal +1 at (3,2); Terminal –1 at (3,1)
      - Agent starts at (0,0)
      - Step cost for non-terminal moves (default 0.0); the value is returned
        as the reward as-is, so pass a negative number to penalise steps
      - Deterministic actions: 0=right,1=down,2=left,3=up

    Args:
        step_cost: Reward returned for every non-terminal move.
        **kwargs: Forwarded to ``MiniGridEnv`` (e.g. ``render_mode``).
            ``max_steps`` defaults to 100 and ``see_through_walls`` to True.
    """

    WALL_POS = (1, 1)
    GOAL_POS = (3, 2)
    PIT_POS = (3, 1)

    # action id -> (dx, dy) in MiniGrid coordinates (y grows downward)
    _MOVES = {
        0: (1, 0),   # right
        1: (0, 1),   # down
        2: (-1, 0),  # left
        3: (0, -1),  # up
    }

    def __init__(self, step_cost: float = 0.0, **kwargs):
        # Own attributes are assigned before the base constructor runs so
        # they exist no matter what the base class does during construction.
        self.step_cost = step_cost
        self.goal_reward = 1.0
        self.pit_reward = -1.0
        # start state: cell (0, 0), i.e. the top-left cell in MiniGrid terms
        self.start_pos = (0, 0)
        self.start_dir = 0

        kwargs.setdefault("max_steps", 100)
        kwargs.setdefault("see_through_walls", True)
        super().__init__(
            mission_space=MissionSpace(mission_func=self._gen_mission),
            # grid_size only describes square grids; a 4×3 world needs
            # explicit width and height.
            width=4,
            height=3,
            **kwargs,
        )
        # override action space to 4 direct moves (must come after the base
        # constructor, which installs MiniGrid's default action space)
        self.action_space = spaces.Discrete(4)

    @staticmethod
    def _gen_mission():
        """Return the fixed mission string of this environment."""
        return "reach +1 and avoid –1"

    def _gen_grid(self, width, height):
        """
        Build the grid and place the agent; called by ``reset()``.

        No border walls are added: on a 4×3 grid they would cover every
        cell except (1,1) and (2,1). The bounds check in ``step`` keeps
        the agent on the grid instead.
        """
        self.grid = Grid(width, height)
        # inner wall
        self.grid.set(*self.WALL_POS, Wall())
        # goal (+1) and pit (-1, represented by Lava)
        self.grid.set(*self.GOAL_POS, Goal())
        self.grid.set(*self.PIT_POS, Lava())
        # reset() requires _gen_grid to define the agent's pose
        self.agent_pos = self.start_pos
        self.agent_dir = self.start_dir
        self.mission = self._gen_mission()

    def step(self, action):
        """
        Interpret action ∈ {0,1,2,3} as a move in (→,↓,←,↑).

        Grid boundaries and non-overlappable cells (the wall) block movement;
        a blocked move leaves the agent in place but still counts as a step.

        Returns:
            (obs, reward, terminated, truncated, info). ``terminated`` is
            True on entering the +1 or –1 cell; ``truncated`` is True once
            ``max_steps`` steps have been taken.

        Raises:
            KeyError: if ``action`` is not one of 0, 1, 2, 3.
        """
        dx, dy = self._MOVES[int(action)]
        self.step_count += 1

        x, y = self.agent_pos
        nx, ny = x + dx, y + dy

        blocked = not (0 <= nx < self.width and 0 <= ny < self.height)
        if not blocked:
            # Grid.get returns None for an empty cell, which is always free.
            cell = self.grid.get(nx, ny)
            blocked = cell is not None and not cell.can_overlap()
        if blocked:
            nx, ny = x, y
        self.agent_pos = (nx, ny)

        # terminal checks, otherwise the non-terminal step cost
        if (nx, ny) == self.GOAL_POS:
            reward, terminated = self.goal_reward, True
        elif (nx, ny) == self.PIT_POS:
            reward, terminated = self.pit_reward, True
        else:
            reward, terminated = self.step_cost, False

        truncated = self.step_count >= self.max_steps

        if self.render_mode == "human":
            self.render()

        obs = self.gen_obs()
        return obs, reward, terminated, truncated, {}