import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
Policy Iteration Gridworld
This notebook implements policy iteration for the classic 4×3 grid world from Artificial Intelligence: A Modern Approach, Figure 17.2.
- Terminal states: +1 at (3, 0), -1 at (3, 1) (coordinates are (column, row), with row 0 the top row)
- Wall: (1,1)
- Step cost: -0.04
- Discount factor γ = 1.0
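
Policy iteration alternates two steps until the policy stops changing. With the reward collected on arrival (the terminal payoff if s' is terminal, the step cost of -0.04 otherwise), the updates implemented below are:

$$V(s) \leftarrow \sum_{s'} P(s' \mid s, \pi(s)) \left[ R(s') + \gamma V(s') \right]$$

$$\pi(s) \leftarrow \operatorname*{arg\,max}_{a} \sum_{s'} P(s' \mid s, a) \left[ R(s') + \gamma V(s') \right]$$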
# Define state positions and mappings
state_coords = {
    0: (0, 0), 1: (1, 0), 2: (2, 0), 3: (3, 0),
    4: (0, 1), 5: (1, 1), 6: (2, 1), 7: (3, 1),
    8: (0, 2), 9: (1, 2), 10: (2, 2), 11: (3, 2)
}
terminal_states = {3: 1.0, 7: -1.0}  # state: terminal reward
wall = {5}
states = list(range(12))
actions = [0, 1, 2, 3]  # east, north, west, south (cyclic order, so a±1 mod 4 are the perpendicular moves)
step_cost = -0.04
gamma = 1.0
theta = 1e-4  # convergence threshold for policy evaluation

action_delta = {0: (1, 0), 1: (0, -1), 2: (-1, 0), 3: (0, 1)}
state_pos = {s: (x, y) for s, (x, y) in state_coords.items()}  # state -> (x, y)
pos_state = {v: k for k, v in state_pos.items()}               # (x, y) -> state
def get_transitions(s, a):
    """Return a list of (probability, next_state) pairs for taking action a in state s."""
    if s in terminal_states or s in wall:
        return [(1.0, s)]
    x, y = state_pos[s]
    results = []
    # Intended direction with prob 0.8; each perpendicular direction with prob 0.1.
    for prob, direction in zip([0.8, 0.1, 0.1], [a, (a + 1) % 4, (a + 3) % 4]):
        dx, dy = action_delta[direction]
        nx, ny = x + dx, y + dy
        ns = pos_state.get((nx, ny), s)  # off-grid moves bounce back
        if ns in wall:                   # so do moves into the wall
            ns = s
        results.append((prob, ns))
    return results
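A quick sanity check on the transition model (not in the original notebook, just an illustration): the probabilities for every state-action pair should sum to 1, and moves into the wall or off the grid should leave the agent in place.
# Illustrative check: outgoing probabilities sum to 1 for every (state, action).
for s_check in states:
    for a_check in actions:
        assert abs(sum(p for p, _ in get_transitions(s_check, a_check)) - 1.0) < 1e-12

# Example: state 4 sits on the left edge next to the wall (state 5). Moving east
# bounces off the wall with prob 0.8 and slips north/south with prob 0.1 each.
print(get_transitions(4, 0))  # [(0.8, 4), (0.1, 0), (0.1, 8)]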
# Initial values are all 0; the initial policy moves east (action 0) everywhere.
v = [0.0 for _ in states]
policy = [0 if s not in terminal_states and s not in wall else None for s in states]
is_policy_stable = False
while not is_policy_stable:
    # Policy evaluation: iterate the Bellman expectation update until convergence
    while True:
        delta = 0
        for s in states:
            if s in terminal_states or s in wall:
                continue
            v_old = v[s]
            v[s] = sum(prob * (terminal_states.get(s_next, step_cost) + gamma * v[s_next])
                       for prob, s_next in get_transitions(s, policy[s]))
            delta = max(delta, abs(v_old - v[s]))
        if delta < theta:
            break
    # Policy improvement: act greedily with respect to the current value estimates
    is_policy_stable = True
    for s in states:
        if s in terminal_states or s in wall:
            continue
        old_action = policy[s]
        q_vals = []
        for a in actions:
            q = sum(prob * (terminal_states.get(s_next, step_cost) + gamma * v[s_next])
                    for prob, s_next in get_transitions(s, a))
            q_vals.append(q)
        best_action = int(np.argmax(q_vals))
        policy[s] = best_action
        if best_action != old_action:
            is_policy_stable = False
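Before plotting, a minimal text rendering of the result can help verify it (again illustrative, not part of the original notebook): the value for each non-terminal cell, an arrow for its greedy action, ## for the wall.
# Illustrative text view of the solution; arrows follow the cyclic action order above.
arrows = {0: '→', 1: '↑', 2: '←', 3: '↓'}
for row in range(3):
    cells = []
    for col in range(4):
        s_id = pos_state[(col, row)]
        if s_id in wall:
            cells.append('  ##  ')
        elif s_id in terminal_states:
            cells.append(f'  {terminal_states[s_id]:+.0f}  ')
        else:
            cells.append(f'{v[s_id]:5.2f}{arrows[policy[s_id]]}')
    print(''.join(cells))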
# Arrow offsets for plotting, scaled down from the action deltas
arrow_map = {a: (0.4 * dx, 0.4 * dy) for a, (dx, dy) in action_delta.items()}
fig, ax = plt.subplots(figsize=(6, 5))
for s, (x, y) in state_coords.items():
    if s in wall:
        color = 'gray'
    elif s in terminal_states:
        color = 'lightgreen' if terminal_states[s] > 0 else 'salmon'
    else:
        color = 'white'
    ax.add_patch(patches.Rectangle((x, y), 1, 1, edgecolor='black', facecolor=color))
    if s not in terminal_states and s not in wall:
        dx, dy = arrow_map[policy[s]]
        ax.arrow(x + 0.5, y + 0.5, dx, dy, head_width=0.15, head_length=0.1, fc='blue', ec='blue')
        ax.text(x + 0.05, y + 0.05, f"{v[s]:.2f}", fontsize=8, color='black')

for s, r in terminal_states.items():
    x, y = state_coords[s]
    ax.text(x + 0.35, y + 0.35, f"{r:+.0f}", fontsize=14, color='black')

ax.set_xlim(0, 4)
ax.set_ylim(0, 3)
ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect('equal')
plt.title("Policy Iteration: Step Cost = -0.04, γ = 1.0")
plt.gca().invert_yaxis()
plt.show()