import numpy as np
import random
from collections import defaultdict
import sys
import os
os.environ['PYTHONWARNINGS'] = 'ignore'  # Suppress warnings
sys.path.insert(0, './environment')
from environment import Env
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
# SARSA agent learns every time step from the sample <s, a, r, s', a'>
class SARSAgent:
    def __init__(self, actions):
        self.actions = actions
        self.learning_rate = 0.01
        self.discount_factor = 0.9
        self.epsilon = 0.1
        self.q_table = defaultdict(lambda: [0.0, 0.0, 0.0, 0.0])
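        # q_table maps a state string to one Q-value per action; the 4-entry
        # default assumes the environment exposes 4 actions (env.n_actions below).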
    # with sample <s, a, r, s', a'>, learns new q function
    def learn(self, state, action, reward, next_state, next_action):
        current_q = self.q_table[state][action]
        next_state_q = self.q_table[next_state][next_action]
        new_q = (current_q + self.learning_rate *
                 (reward + self.discount_factor * next_state_q - current_q))
        self.q_table[state][action] = new_q
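        # This is the standard SARSA (on-policy TD) update:
        #   Q(s, a) <- Q(s, a) + alpha * [r + gamma * Q(s', a') - Q(s, a)]
        # where alpha is the learning rate and gamma the discount factor.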
    # get action for the state according to the q function table
    # agent picks an action using the epsilon-greedy policy
    def get_action(self, state):
        if np.random.rand() < self.epsilon:
            # take random action
            action = np.random.choice(self.actions)
        else:
            # take action according to the q function table
            state_action = self.q_table[state]
            action = self.arg_max(state_action)
        return action
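    # arg_max returns the index of the largest Q-value, breaking ties
    # uniformly at random among all maximizing actions.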
    @staticmethod
    def arg_max(state_action):
        max_index_list = []
        max_value = state_action[0]
        for index, value in enumerate(state_action):
            if value > max_value:
                max_index_list.clear()
                max_value = value
                max_index_list.append(index)
            elif value == max_value:
                max_index_list.append(index)
        return random.choice(max_index_list)
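As a quick sanity check of the agent in isolation, a single SARSA step can be run by hand. This is only a sketch: the state strings, action index, and reward below are made up for illustration and do not come from the gridworld environment.
demo = SARSAgent(actions=[0, 1, 2, 3])
s, a, r = '[0, 0]', 1, 1.0            # illustrative state string, action index, reward
s_next = '[0, 1]'
a_next = demo.get_action(s_next)      # epsilon-greedy choice in the next state
demo.learn(s, a, r, s_next, a_next)
print(demo.q_table[s])                # Q(s, a) moves from 0.0 to learning_rate * r = 0.01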
SARSA Gridworld Example
if __name__ == "__main__":
    env = Env()
    agent = SARSAgent(actions=list(range(env.n_actions)))

    for episode in tqdm(range(10000), desc="Training Episodes"):
        # print("Episode: ", episode)
        # reset environment and initialize state
        state = env.reset()
        # get action of state from agent
        action = agent.get_action(str(state))

        while True:
            # env.render()  # Commented out to prevent matplotlib figure output
            # take action and proceed one step in the environment
            next_state, reward, done = env.step(action)
            next_action = agent.get_action(str(next_state))

            # with sample <s, a, r, s', a'>, agent learns new q function
            agent.learn(str(state), action, reward, str(next_state), next_action)

            state = next_state
            action = next_action

            # if episode ends, then break
            if done:
                break
Training Episodes: 100%|██████████| 10000/10000 [03:33<00:00, 46.89it/s]
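After training, the learned values can also be inspected directly from the Q-table before plotting. The coordinates below only illustrate the '[x, y]' state-string format used as keys; they are not guaranteed to correspond to frequently visited states.
for s in ['[0, 0]', '[1, 0]', '[2, 2]']:
    q = agent.q_table[s]
    print(s, [round(v, 3) for v in q], '-> greedy action:', SARSAgent.arg_max(q))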
def plot_q_values(q_table):
    # Create a grid to store the maximum Q-value for each state
    q_values_grid = np.zeros((5, 5))

    for state, actions in q_table.items():
        try:
            state_coords = eval(state)  # Convert string back to list
            q_values_grid[state_coords[0], state_coords[1]] = max(actions)
        except (SyntaxError, TypeError, IndexError):
            # Skip invalid states
            continue

    # Plot the heatmap
    plt.figure(figsize=(8, 6))
    sns.heatmap(q_values_grid, annot=True, fmt='.2f', cmap='coolwarm', cbar=True)
    plt.title('Heatmap of Maximum Q-Values for Each State')
    plt.xlabel('X Coordinate')
    plt.ylabel('Y Coordinate')
    plt.show()
plot_q_values(agent.q_table)