{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": { "provenance": [] },
    "kernelspec": { "name": "python3", "display_name": "Python 3" },
    "language_info": { "name": "python" }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": { "id": "XroKDzg8fFAs" },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import matplotlib\n",
        "matplotlib.use('Agg')\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# World dimensions\n",
        "WORLD_HEIGHT = 7\n",
        "WORLD_WIDTH = 10\n",
        "\n",
        "# Wind strength for each column\n",
        "WIND = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]\n",
        "\n",
        "# Possible actions (including King's moves)\n",
        "ACTION_UP = 0\n",
        "ACTION_DOWN = 1\n",
        "ACTION_LEFT = 2\n",
        "ACTION_RIGHT = 3\n",
        "ACTION_UP_LEFT = 4\n",
        "ACTION_UP_RIGHT = 5\n",
        "ACTION_DOWN_LEFT = 6\n",
        "ACTION_DOWN_RIGHT = 7\n",
        "\n",
        "# Probability for exploration\n",
        "EPSILON = 0.1\n",
        "\n",
        "# Learning rate (step size)\n",
        "ALPHA = 0.5\n",
        "\n",
        "# Reward for each step\n",
        "REWARD = -1.0\n",
        "\n",
        "# Start and goal positions\n",
        "START = [3, 0]\n",
        "GOAL = [3, 7]\n",
        "\n",
        "# All possible actions\n",
        "ACTIONS = [\n",
        "    ACTION_UP, ACTION_DOWN, ACTION_LEFT, ACTION_RIGHT,\n",
        "    ACTION_UP_LEFT, ACTION_UP_RIGHT, ACTION_DOWN_LEFT,\n",
        "    ACTION_DOWN_RIGHT\n",
        "]\n",
        "\n",
        "\n",
        "def step(state, action):\n",
        "    \"\"\"Apply an action and the wind, clipping the result to the grid.\"\"\"\n",
        "    i, j = state\n",
        "    if action == ACTION_UP:\n",
        "        return [max(i - 1 - WIND[j], 0), j]\n",
        "    elif action == ACTION_DOWN:\n",
        "        return [max(min(i + 1 - WIND[j], WORLD_HEIGHT - 1), 0), j]\n",
        "    elif action == ACTION_LEFT:\n",
        "        return [max(i - WIND[j], 0), max(j - 1, 0)]\n",
        "    elif action == ACTION_RIGHT:\n",
        "        return [max(i - WIND[j], 0), min(j + 1, WORLD_WIDTH - 1)]\n",
        "    elif action == ACTION_UP_LEFT:\n",
        "        return [max(i - 1 - WIND[max(j - 1, 0)], 0), max(j - 1, 0)]\n",
        "    elif action == ACTION_UP_RIGHT:\n",
        "        return [max(i - 1 - WIND[min(j + 1, WORLD_WIDTH - 1)], 0), min(j + 1, WORLD_WIDTH - 1)]\n",
        "    elif action == ACTION_DOWN_LEFT:\n",
        "        return [max(min(i + 1 - WIND[max(j - 1, 0)], WORLD_HEIGHT - 1), 0), max(j - 1, 0)]\n",
        "    elif action == ACTION_DOWN_RIGHT:\n",
        "        return [max(min(i + 1 - WIND[min(j + 1, WORLD_WIDTH - 1)], WORLD_HEIGHT - 1), 0), min(j + 1, WORLD_WIDTH - 1)]\n",
        "    else:\n",
        "        assert False\n",
        "\n",
        "\n",
        "def episode(q_value):\n",
        "    \"\"\"Run one episode, updating q_value in place; return its length in steps.\"\"\"\n",
        "    time = 0\n",
        "    state = START\n",
        "    # Choose the initial action epsilon-greedily\n",
        "    if np.random.binomial(1, EPSILON) == 1:\n",
        "        action = np.random.choice(ACTIONS)\n",
        "    else:\n",
        "        values_ = q_value[state[0], state[1], :]\n",
        "        action = np.random.choice([action_ for action_, value_ in enumerate(values_) if value_ == np.max(values_)])\n",
        "\n",
        "    while state != GOAL:\n",
        "        next_state = step(state, action)\n",
        "        if np.random.binomial(1, EPSILON) == 1:\n",
        "            next_action = np.random.choice(ACTIONS)\n",
        "        else:\n",
        "            values_ = q_value[next_state[0], next_state[1], :]\n",
        "            next_action = np.random.choice([action_ for action_, value_ in enumerate(values_) if value_ == np.max(values_)])\n",
        "\n",
        "        # Q-learning update rule\n",
        "        q_value[state[0], state[1], action] += ALPHA * (\n",
        "            REWARD + np.max(q_value[next_state[0], next_state[1], :]) - q_value[state[0], state[1], action]\n",
        "        )\n",
        "        state = next_state\n",
        "        action = next_action\n",
        "        time += 1\n",
        "\n",
        "    return time"
      ]
    },
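    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check of the dynamics (an illustrative addition, not part of the original experiment), the cell below verifies a few transitions by hand against the `WIND` array. Note that in this implementation the straight moves apply the wind of the current column, while the diagonal (King's) moves apply the wind of the destination column."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hand-checked transitions (illustrative only).\n",
        "# From [3, 3], WIND[3] == 1, so moving right also drifts the agent up one row.\n",
        "assert step([3, 3], ACTION_RIGHT) == [2, 4]\n",
        "# From [3, 6], WIND[6] == 2, so moving down still results in a net drift up one row.\n",
        "assert step([3, 6], ACTION_DOWN) == [2, 6]\n",
        "# Diagonal moves use the destination column's wind: here WIND[6] == 2.\n",
        "assert step([3, 5], ACTION_UP_RIGHT) == [0, 6]\n",
        "print('step() sanity checks passed')"
      ]
    },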
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def figure_6_3():\n",
        "    q_value = np.zeros((WORLD_HEIGHT, WORLD_WIDTH, len(ACTIONS)))\n",
        "    episode_limit = 500\n",
        "    steps = []\n",
        "    ep = 0\n",
        "\n",
        "    while ep < episode_limit:\n",
        "        steps.append(episode(q_value))\n",
        "        ep += 1\n",
        "\n",
        "    # Plot cumulative time steps against completed episodes\n",
        "    steps = np.add.accumulate(steps)\n",
        "    plt.plot(steps, np.arange(1, len(steps) + 1))\n",
        "    plt.xlabel('Time steps')\n",
        "    plt.ylabel('Episodes')\n",
        "    plt.savefig('figure_6_3.png')\n",
        "    plt.close()\n",
        "\n",
        "    # Print the greedy policy derived from the learned action values\n",
        "    optimal_policy = []\n",
        "    for i in range(WORLD_HEIGHT):\n",
        "        optimal_policy.append([])\n",
        "        for j in range(WORLD_WIDTH):\n",
        "            if [i, j] == GOAL:\n",
        "                optimal_policy[-1].append('G')\n",
        "                continue\n",
        "\n",
        "            best_action = np.argmax(q_value[i, j, :])\n",
        "            if best_action == ACTION_UP:\n",
        "                optimal_policy[-1].append('U')\n",
        "            elif best_action == ACTION_DOWN:\n",
        "                optimal_policy[-1].append('D')\n",
        "            elif best_action == ACTION_LEFT:\n",
        "                optimal_policy[-1].append('L')\n",
        "            elif best_action == ACTION_RIGHT:\n",
        "                optimal_policy[-1].append('R')\n",
        "            elif best_action == ACTION_UP_LEFT:\n",
        "                optimal_policy[-1].append('UL')\n",
        "            elif best_action == ACTION_UP_RIGHT:\n",
        "                optimal_policy[-1].append('UR')\n",
        "            elif best_action == ACTION_DOWN_LEFT:\n",
        "                optimal_policy[-1].append('DL')\n",
        "            elif best_action == ACTION_DOWN_RIGHT:\n",
        "                optimal_policy[-1].append('DR')\n",
        "\n",
        "    print('Optimal policy is:')\n",
        "    for row in optimal_policy:\n",
        "        print(row)\n",
        "    print('Wind strength for each column:\\n{}'.format([str(w) for w in WIND]))\n",
        "\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    figure_6_3()"
      ]
    },
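    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The update in `episode()` is Q-learning: it bootstraps from the maximum action value at the next state, even though the loop also selects `next_action` epsilon-greedily and follows it. For comparison, a Sarsa update needs only a one-line change: bootstrap from the value of the action actually taken next. The `episode_sarsa` function below is an illustrative sketch added for that comparison; it is not used by `figure_6_3()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def episode_sarsa(q_value):\n",
        "    \"\"\"Illustrative Sarsa variant of episode(): same epsilon-greedy behaviour,\n",
        "    but the update bootstraps from the action actually taken next.\"\"\"\n",
        "    time = 0\n",
        "    state = START\n",
        "    if np.random.binomial(1, EPSILON) == 1:\n",
        "        action = np.random.choice(ACTIONS)\n",
        "    else:\n",
        "        values_ = q_value[state[0], state[1], :]\n",
        "        action = np.random.choice([a for a, v in enumerate(values_) if v == np.max(values_)])\n",
        "\n",
        "    while state != GOAL:\n",
        "        next_state = step(state, action)\n",
        "        if np.random.binomial(1, EPSILON) == 1:\n",
        "            next_action = np.random.choice(ACTIONS)\n",
        "        else:\n",
        "            values_ = q_value[next_state[0], next_state[1], :]\n",
        "            next_action = np.random.choice([a for a, v in enumerate(values_) if v == np.max(values_)])\n",
        "\n",
        "        # Sarsa update: bootstrap from Q(next_state, next_action) instead of the max\n",
        "        q_value[state[0], state[1], action] += ALPHA * (\n",
        "            REWARD + q_value[next_state[0], next_state[1], next_action] - q_value[state[0], state[1], action]\n",
        "        )\n",
        "        state = next_state\n",
        "        action = next_action\n",
        "        time += 1\n",
        "\n",
        "    return time"
      ]
    }
  ]
}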