Int-HRL/RAMStateLabeling.ipynb

262 lines
8.1 KiB
Text

{
"cells": [
{
"cell_type": "markdown",
"id": "e5e28b09",
"metadata": {},
"source": [
"# Get RAM state of Montezuma's Revenge"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9660bf17",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import random\n",
"import cv2\n",
"\n",
"import numpy as np\n",
"import pandas as pd \n",
"import matplotlib.pyplot as plt \n",
"import gym\n",
"\n",
"from atariari.benchmark.wrapper import AtariARIWrapper\n",
"from utils import visualize_sample\n",
"\n",
"\n",
"# Folder holding the Atari-HEAD Montezuma's Revenge trials.\n",
"DATA_PATH = 'montezuma_revenge'\n",
"TRIALS_FILE = os.path.join(DATA_PATH, \"all_trials.pkl\")\n",
"\n",
"# NOTE: read_pickle can execute arbitrary code - only load trusted, locally produced files.\n",
"df = pd.read_pickle(TRIALS_FILE)\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"id": "4417d418",
"metadata": {},
"source": [
"## Use the AtariARI wrapper to extract the RAM state\n",
"The wrapper adds a `labels` dictionary to the `info` returned by `env.step`, e.g.:\n",
"```\n",
"labels: {'room_number': 15,\n",
"'player_x': 46,\n",
"'player_y': 235,\n",
"'player_direction': 76,\n",
"'enemy_skull_x': 58,\n",
"'enemy_skull_y': 240,\n",
"'key_monster_x': 132,\n",
"'key_monster_y': 254,\n",
"'level': 0,\n",
"'num_lives': 1,\n",
"'items_in_inventory_count': 0,\n",
"'room_state': 10,\n",
"'score_0': 1,\n",
"'score_1': 8,\n",
"'score_2': 0}\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19450d3e",
"metadata": {},
"outputs": [],
"source": [
"# frameskip=1 advances one emulator frame per env.step, and\n",
"# repeat_action_probability=0.0 disables sticky actions, so replayed\n",
"# actions are not randomly repeated.\n",
"base_env = gym.make(\n",
"    'MontezumaRevenge-v4',\n",
"    frameskip=1,\n",
"    render_mode='rgb_array',\n",
"    repeat_action_probability=0.0,\n",
")\n",
"env = AtariARIWrapper(base_env)\n",
"\n",
"# Seeded reset plus a single step (action 1) to inspect `info['labels']`.\n",
"# The raw RAM is also reachable via env.unwrapped.ale.getRAM().\n",
"obs = env.reset(seed=42)\n",
"obs, reward, done, info = env.step(1)"
]
},
{
"cell_type": "markdown",
"id": "cc13394b",
"metadata": {},
"source": [
"## Visualize Atari-HEAD data and RAM state labels\n",
"> The pixel offsets applied to the player and skull coordinates were determined manually."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3882fa3f",
"metadata": {},
"outputs": [],
"source": [
"from IPython import display\n",
"\n",
"# Episode replayed for the visual check of the RAM labels.\n",
"EPISODE_ID = '285_RZ_5619207_E00'\n",
"\n",
"# Reset with the same seed the verified action counts refer to (seed = 42).\n",
"obs = env.reset(seed=42)\n",
"\n",
"screen = plt.imshow(env.render(mode='rgb_array'), aspect='auto')\n",
"plt.axis('off')\n",
"\n",
"agent_locations = []\n",
"skull_locations = []\n",
"room_ids = []\n",
"\n",
"for i, action in enumerate(df.loc[df.ID == EPISODE_ID].action.values):\n",
"    n_state, reward, done, info = env.step(action)\n",
"    img = info['rgb']\n",
"    labels = info['labels']\n",
"    # The label values presumably are raw RAM bytes (numpy uint8); cast to int so the\n",
"    # offset arithmetic below cannot wrap around (or raise OverflowError under NumPy 2).\n",
"    room_ids.append(int(labels['room_number']))\n",
"\n",
"    # agent: y is taken as 320 - RAM value; box size chosen manually\n",
"    mean_x, mean_y = int(labels['player_x']), 320 - int(labels['player_y'])\n",
"    agent_locations.append([mean_x, mean_y])\n",
"    x1, x2, y1, y2 = mean_x - 5, mean_x + 10, mean_y - 15, mean_y + 10\n",
"    img = cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)\n",
"\n",
"    # skull: manually determined pixel offsets (+35, -65)\n",
"    mean_x, mean_y = int(labels['enemy_skull_x']) + 35, int(labels['enemy_skull_y']) - 65\n",
"    skull_locations.append([mean_x, mean_y])\n",
"    x1, x2, y1, y2 = mean_x - 5, mean_x + 5, mean_y - 10, mean_y + 5\n",
"    img = cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)\n",
"\n",
"    img = cv2.putText(img=img, text=f'Room ID: {room_ids[-1]} index: {i}', org=(5, 205),\n",
"                      fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.3,\n",
"                      color=(255, 255, 255), thickness=1)\n",
"\n",
"    screen.set_data(img)  # update the existing image instead of re-plotting\n",
"    display.display(plt.gcf())\n",
"    display.clear_output(wait=True)"
]
},
{
"cell_type": "markdown",
"id": "1031fee6",
"metadata": {},
"source": [
"### Number of correctly labeled actions per episode (environment seed = 42)\n",
"> Determined manually using the visualization above.\n",
"\n",
"Legend: `-1` = all actions valid, `0` = no actions valid; any other value is the number of leading actions whose labels are correct."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "868493b1",
"metadata": {},
"outputs": [],
"source": [
"# Manually verified number of correctly labeled leading actions per episode\n",
"# (-1: all actions valid, 0: no actions valid), valid for env seed 42.\n",
"test = {'284_RZ_5540489_E00': 11900, '285_RZ_5619207_E00': 2940, '285_RZ_5619207_E01': -1,\n",
"        '287_RZ_7172481_E00': 0, '291_RZ_7364933_E00': 12000, '324_RZ_452975_E00': 0,\n",
"        '333_RZ_900705_E00': 3000, '340_RZ_1323550_E00': 5950, '359_RZ_1993616_E00': 9000,\n",
"        '365_RZ_2079996_E00': 9000, '371_RZ_2173469_E00': -1, '385_RZ_2344725_E00': 3500,\n",
"        '398_RZ_2530473_E00': 1200, '402_RZ_2603283_E00': -1, '416_RZ_2788252_E00': -1,\n",
"        '429_RZ_2945490_E00': 4500, '436_RZ_3131841_E00': 10500, '459_RZ_3291266_E00': 5400,\n",
"        '469_RZ_3390904_E00': 14500, '480_RZ_3470098_E00': 8000, '493_RZ_3557734_E00': 10500,\n",
"        '523_RZ_4091327_E00': 0, '536_RZ_4420664_E00': 0, '548_RZ_4509746_E00': 0,\n",
"        '561_RZ_4598680_E00': 0, '573_RZ_4680777_E00': 0, '584_RZ_4772014_E00': 0,\n",
"        '588_RZ_5032278_E00': 0}\n",
"\n",
"num_frames = 0\n",
"num_labeled_frames = 0\n",
"for episode, num_samples in test.items():\n",
"    episode_len = int((df.ID == episode).sum())\n",
"    num_frames += episode_len\n",
"    # -1 marks a fully valid episode; otherwise cap at the episode length so a\n",
"    # count larger than the recorded trial cannot inflate the percentage.\n",
"    if num_samples == -1:\n",
"        num_labeled_frames += episode_len\n",
"    else:\n",
"        num_labeled_frames += min(num_samples, episode_len)\n",
"\n",
"print(f'Overall percentage {num_labeled_frames / num_frames:%} for {len(test)} episodes')"
]
},
{
"cell_type": "markdown",
"id": "7f3c0026",
"metadata": {},
"source": [
"## Label Atari-HEAD data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe0dcd7a",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"df['level'] = None\n",
"df['room_id'] = None\n",
"df['player_location'] = None\n",
"df['skull_location'] = None\n",
"\n",
"for episode in df.ID.unique():\n",
"    # Number of leading actions whose labels were verified (see `test`).\n",
"    # -1 means all of them: slicing with [:-1] would silently drop the last frame,\n",
"    # so map it to None (no limit). Episodes missing from `test` (None) are\n",
"    # replayed in full, as before.\n",
"    num_valid_actions = test.get(episode)\n",
"    if num_valid_actions == -1:\n",
"        num_valid_actions = None\n",
"    if num_valid_actions == 0:\n",
"        continue  # nothing verified: keep this episode's labels as None\n",
"\n",
"    episode_rows = df.loc[df.ID == episode].iloc[:num_valid_actions]\n",
"\n",
"    # Seeded reset: the verified counts in `test` refer to seed 42.\n",
"    obs = env.reset(seed=42)\n",
"    room_ids = []\n",
"    level = []\n",
"    agent_locations = []\n",
"    skull_locations = []\n",
"\n",
"    for action in episode_rows.action.values:\n",
"        n_state, reward, done, info = env.step(action)\n",
"        labels = info['labels']\n",
"        # The label values presumably are raw RAM bytes (numpy uint8); cast to int so\n",
"        # the offset arithmetic below cannot wrap around or overflow.\n",
"        room_ids.append(int(labels['room_number']))\n",
"        level.append(int(labels['level']))\n",
"\n",
"        # agent: y is taken as 320 - RAM value\n",
"        agent_locations.append([int(labels['player_x']), 320 - int(labels['player_y'])])\n",
"\n",
"        # skull: manually determined pixel offsets (+35, -65)\n",
"        skull_locations.append([int(labels['enemy_skull_x']) + 35, int(labels['enemy_skull_y']) - 65])\n",
"\n",
"    index = episode_rows.index\n",
"    df.loc[index, 'level'] = level\n",
"    df.loc[index, 'room_id'] = room_ids\n",
"    # Wrap the [x, y] pairs in an object Series so pandas stores one list per row\n",
"    # instead of trying to broadcast a 2-D array into a single column.\n",
"    df.loc[index, 'player_location'] = pd.Series(agent_locations, index=index, dtype=object)\n",
"    df.loc[index, 'skull_location'] = pd.Series(skull_locations, index=index, dtype=object)\n",
"\n",
"print(f'Percentage of labeled data {len(df[df.room_id.notnull()]) / len(df):%}')\n",
"df.to_pickle(os.path.join(DATA_PATH, \"all_trials_labeled.pkl\"))\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96ff9136",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "msc_env",
"language": "python",
"name": "msc_env"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}