SARSA Gridworld Example

SARSA Gridworld

import numpy as np
import random
from collections import defaultdict
import sys
sys.path.insert(0, './environment')
from environment import Env

# SARSA agent learns every time step from the sample <s, a, r, s', a'>
class SARSAgent:
    """Tabular SARSA agent with an epsilon-greedy policy.

    The agent learns on every time step from the sample ``<s, a, r, s', a'>``,
    where ``a'`` is chosen by the same epsilon-greedy policy that is being
    improved (on-policy TD control).

    Args:
        actions: The available actions. They are used directly as indices
            into each state's Q-value list, so they are expected to be the
            integers ``0 .. len(actions) - 1`` (e.g.
            ``list(range(env.n_actions))``).
        learning_rate: Step size (alpha) of the TD update.
        discount_factor: Discount (gamma) applied to ``Q(s', a')``.
        epsilon: Probability of taking a uniformly random action.
    """

    def __init__(self, actions, *, learning_rate=0.01, discount_factor=0.9,
                 epsilon=0.1):
        self.actions = actions
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        # One Q-value per action, created lazily the first time a state key is
        # looked up. Sized from ``actions`` (not a hard-coded 4) so that
        # environments with a different action count do not IndexError.
        n_actions = len(actions)
        self.q_table = defaultdict(lambda: [0.0] * n_actions)

    def learn(self, state, action, reward, next_state, next_action):
        """Apply one SARSA update to ``Q(state, action)``.

        ``Q(s, a) <- Q(s, a) + alpha * (r + gamma * Q(s', a') - Q(s, a))``

        There is no terminal-state special case: ``Q(next_state, next_action)``
        is read as-is. A terminal state contributes 0 only as long as its
        Q-values are never updated by the caller.

        Args:
            state: Hashable key of the current state ``s``.
            action: Index of the action ``a`` taken in ``s``.
            reward: Reward ``r`` received after taking ``a``.
            next_state: Hashable key of the successor state ``s'``.
            next_action: Index of the action ``a'`` chosen in ``s'``.
        """
        current_q = self.q_table[state][action]
        next_state_q = self.q_table[next_state][next_action]
        td_target = reward + self.discount_factor * next_state_q
        self.q_table[state][action] = (
            current_q + self.learning_rate * (td_target - current_q))

    def get_action(self, state):
        """Choose an action for ``state`` with the epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random element of
        ``self.actions`` is returned. Otherwise the index of the highest
        Q-value for ``state`` is returned, with ties broken at random.
        """
        if np.random.rand() < self.epsilon:
            # Explore.
            return np.random.choice(self.actions)
        # Exploit the current Q-table.
        return self.arg_max(self.q_table[state])

    @staticmethod
    def arg_max(state_action):
        """Return an index of a maximal value, breaking ties uniformly at random.

        Raises:
            ValueError: if ``state_action`` is empty.
        """
        max_value = max(state_action)
        max_index_list = [index for index, value in enumerate(state_action)
                          if value == max_value]
        return random.choice(max_index_list)
if __name__ == "__main__":
    env = Env()
    agent = SARSAgent(actions=list(range(env.n_actions)))

    for episode in range(1000):
        # Start a fresh episode and pick the first action on-policy.
        state = env.reset()
        action = agent.get_action(str(state))

        done = False
        while not done:
            env.render()

            # Act once, then choose a' from s' with the same epsilon-greedy
            # policy that produced a.
            next_state, reward, done = env.step(action)
            next_key = str(next_state)
            next_action = agent.get_action(next_key)

            # SARSA update from the tuple (s, a, r, s', a').
            agent.learn(str(state), action, reward, next_key, next_action)

            # Show the current Q-values of every state on screen.
            env.print_value_all(agent.q_table)

            state, action = next_state, next_action
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/workspaces/artificial_intelligence/artificial_intelligence/aiml-common/lectures/reinforcement-learning/model-free-control/sarsa/gridworld/sarsa_gridworld.ipynb Cell 3 line 1
     <a href='vscode-notebook-cell://dev-container%2B7b22686f737450617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e6365222c226c6f63616c446f636b6572223a66616c73652c2273657474696e6773223a7b22686f7374223a22756e69783a2f2f2f7661722f72756e2f646f636b65722e736f636b227d2c22636f6e66696746696c65223a7b22246d6964223a312c22667350617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c2265787465726e616c223a2266696c653a2f2f2f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c2270617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c22736368656d65223a2266696c65227d7d/workspaces/artificial_intelligence/artificial_intelligence/aiml-common/lectures/reinforcement-learning/model-free-control/sarsa/gridworld/sarsa_gridworld.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=9'>10</a> action = agent.get_action(str(state))
     <a href='vscode-notebook-cell://dev-container%2B7b22686f737450617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e6365222c226c6f63616c446f636b6572223a66616c73652c2273657474696e6773223a7b22686f7374223a22756e69783a2f2f2f7661722f72756e2f646f636b65722e736f636b227d2c22636f6e66696746696c65223a7b22246d6964223a312c22667350617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c2265787465726e616c223a2266696c653a2f2f2f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c2270617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c22736368656d65223a2266696c65227d7d/workspaces/artificial_intelligence/artificial_intelligence/aiml-common/lectures/reinforcement-learning/model-free-control/sarsa/gridworld/sarsa_gridworld.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=11'>12</a> while True:
---> <a href='vscode-notebook-cell://dev-container%2B7b22686f737450617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e6365222c226c6f63616c446f636b6572223a66616c73652c2273657474696e6773223a7b22686f7374223a22756e69783a2f2f2f7661722f72756e2f646f636b65722e736f636b227d2c22636f6e66696746696c65223a7b22246d6964223a312c22667350617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c2265787465726e616c223a2266696c653a2f2f2f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c2270617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c22736368656d65223a2266696c65227d7d/workspaces/artificial_intelligence/artificial_intelligence/aiml-common/lectures/reinforcement-learning/model-free-control/sarsa/gridworld/sarsa_gridworld.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=12'>13</a>     env.render()
     <a href='vscode-notebook-cell://dev-container%2B7b22686f737450617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e6365222c226c6f63616c446f636b6572223a66616c73652c2273657474696e6773223a7b22686f7374223a22756e69783a2f2f2f7661722f72756e2f646f636b65722e736f636b227d2c22636f6e66696746696c65223a7b22246d6964223a312c22667350617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c2265787465726e616c223a2266696c653a2f2f2f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c2270617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c22736368656d65223a2266696c65227d7d/workspaces/artificial_intelligence/artificial_intelligence/aiml-common/lectures/reinforcement-learning/model-free-control/sarsa/gridworld/sarsa_gridworld.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=14'>15</a>     # take action and proceed one step in the environment
     <a href='vscode-notebook-cell://dev-container%2B7b22686f737450617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e6365222c226c6f63616c446f636b6572223a66616c73652c2273657474696e6773223a7b22686f7374223a22756e69783a2f2f2f7661722f72756e2f646f636b65722e736f636b227d2c22636f6e66696746696c65223a7b22246d6964223a312c22667350617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c2265787465726e616c223a2266696c653a2f2f2f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c2270617468223a222f686f6d652f70616e74656c69732e6d6f6e6f67696f756469732f6c6f63616c2f7765622f73697465732f636f75727365732f6172746966696369616c5f696e74656c6c6967656e63652f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c22736368656d65223a2266696c65227d7d/workspaces/artificial_intelligence/artificial_intelligence/aiml-common/lectures/reinforcement-learning/model-free-control/sarsa/gridworld/sarsa_gridworld.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=15'>16</a>     next_state, reward, done = env.step(action)

File /workspaces/artificial_intelligence/artificial_intelligence/aiml-common/lectures/reinforcement-learning/model-free-control/sarsa/gridworld/./environment/environment.py:141, in Env.render(self)
    140 def render(self):
--> 141     time.sleep(0.03)
    142     self.update()

KeyboardInterrupt: 
Back to top