262 lines
8.1 KiB
Text
262 lines
8.1 KiB
Text
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e5e28b09",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Get RAM state of Montezuma's Revenge"
|
|
]
|
|
},
|
|
{
"cell_type": "code",
"execution_count": null,
"id": "9660bf17",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import random\n",
"import cv2\n",
"\n",
"import numpy as np\n",
"import pandas as pd \n",
"import matplotlib.pyplot as plt \n",
"import gym\n",
"\n",
"from atariari.benchmark.wrapper import AtariARIWrapper\n",
"from utils import visualize_sample\n",
"\n",
"\n",
"# Directory containing the pickled trial data used throughout the notebook.\n",
"DATA_PATH = 'montezuma_revenge'\n",
"\n",
"# Read the trials pickle into a DataFrame and preview its first rows.\n",
"trials_file = os.path.join(DATA_PATH, \"all_trials.pkl\")\n",
"df = pd.read_pickle(trials_file)\n",
"df.head()"
]
},
|
|
{
"cell_type": "markdown",
"id": "4417d418",
"metadata": {},
"source": [
"## Use AtariARI Wrapper to extract RAM state \n",
"\n",
"Example of the `labels` entry found in the `info` dict:\n",
"\n",
"```\n",
"labels: {'room_number': 15,\n",
"'player_x': 46,\n",
"'player_y': 235,\n",
"'player_direction': 76,\n",
"'enemy_skull_x': 58,\n",
"'enemy_skull_y': 240,\n",
"'key_monster_x': 132,\n",
"'key_monster_y': 254,\n",
"'level': 0,\n",
"'num_lives': 1,\n",
"'items_in_inventory_count': 0,\n",
"'room_state': 10,\n",
"'score_0': 1,\n",
"'score_1': 8,\n",
"'score_2': 0}\n",
"```"
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"id": "19450d3e",
"metadata": {},
"outputs": [],
"source": [
"# Collect the environment options first, then wrap the created environment\n",
"# in AtariARIWrapper; later cells read `info['labels']` from its steps.\n",
"montezuma_kwargs = dict(frameskip=1,\n",
"                        render_mode='rgb_array',\n",
"                        repeat_action_probability=0.0)\n",
"env = AtariARIWrapper(gym.make('MontezumaRevenge-v4', **montezuma_kwargs))\n",
"\n",
"#env.unwrapped.ale.getRAM()\n",
"# Seeded reset followed by a single step with action 1.\n",
"obs = env.reset(seed=42)\n",
"obs, reward, done, info = env.step(1)"
]
},
|
|
{
"cell_type": "markdown",
"id": "cc13394b",
"metadata": {},
"source": [
"## Visualize Atari-HEAD data and RAM state labels \n",
"> The pixel offsets applied to the player and skull locations were determined manually."
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"id": "3882fa3f",
"metadata": {},
"outputs": [],
"source": [
"from IPython import display\n",
"obs = env.reset()\n",
"\n",
"# Episode whose recorded actions are replayed and drawn below.\n",
"episode_id = '285_RZ_5619207_E00'\n",
"\n",
"screen = plt.imshow(env.render(mode='rgb_array'), aspect='auto')\n",
"plt.axis('off')\n",
"\n",
"all_images = []\n",
"agent_locations = []\n",
"skull_locations = []\n",
"room_ids = []\n",
"\n",
"for i, action in enumerate(df.loc[df.ID == episode_id].action.values):\n",
"\n",
"    n_state, reward, done, info = env.step(action)\n",
"    img = info['rgb']\n",
"    labels = info['labels']\n",
"    room_ids.append(labels['room_number'])\n",
"\n",
"    # agent\n",
"    # Cast to plain Python ints before applying the offsets.\n",
"    # NOTE(review): presumably the labels are fixed-width NumPy integers read\n",
"    # from RAM; if so, `320 - y` / `x + 35` / `y - 65` can overflow or raise\n",
"    # under NumPy >= 2 promotion rules - confirm against the wrapper.\n",
"    mean_x = int(labels['player_x'])\n",
"    mean_y = 320 - int(labels['player_y'])\n",
"    agent_locations.append([mean_x, mean_y])\n",
"\n",
"    # Green box around the player (box extents chosen by hand).\n",
"    x1, x2, y1, y2 = mean_x - 5, mean_x + 10, mean_y - 15, mean_y + 10\n",
"    img = cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)\n",
"\n",
"    # skull\n",
"    mean_x = int(labels['enemy_skull_x']) + 35\n",
"    mean_y = int(labels['enemy_skull_y']) - 65\n",
"    skull_locations.append([mean_x, mean_y])\n",
"\n",
"    # Red box around the skull.\n",
"    x1, x2, y1, y2 = mean_x - 5, mean_x + 5, mean_y - 10, mean_y + 5\n",
"    img = cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)\n",
"\n",
"    # Overlay the room ID and the action index; the index is what is read off\n",
"    # when deciding how many actions of an episode are labeled correctly.\n",
"    caption = 'Room ID: ' + str(labels['room_number']) + ' index: ' + str(i)\n",
"    img = cv2.putText(img=img, text=caption, org=(5, 205), fontFace=cv2.FONT_HERSHEY_SIMPLEX,\n",
"                      fontScale=0.3, color=(255, 255, 255), thickness=1)\n",
"\n",
"    screen.set_data(img) # just update the data\n",
"    display.display(plt.gcf())\n",
"    display.clear_output(wait=True)"
]
},
|
|
{
"cell_type": "markdown",
"id": "1031fee6",
"metadata": {},
"source": [
"### Number of actions with correct labeling for environment with random seed = 42 \n",
"> Discovered manually through the visualization above.\n",
"\n",
"Values: `-1` means all actions of the episode are valid, `0` means no actions are valid; any other value is the number of leading actions with correct labels."
]
},
|
|
{
"cell_type": "code",
"execution_count": null,
"id": "868493b1",
"metadata": {},
"outputs": [],
"source": [
"# Episode ID -> number of leading actions whose RAM labels are correct\n",
"# (-1: every action of the episode, 0: none).\n",
"test = {'284_RZ_5540489_E00': 11900, '285_RZ_5619207_E00': 2940, '285_RZ_5619207_E01': -1,\n",
"        '287_RZ_7172481_E00': 0, '291_RZ_7364933_E00': 12000, '324_RZ_452975_E00':0,\n",
"        '333_RZ_900705_E00': 3000, '340_RZ_1323550_E00': 5950, '359_RZ_1993616_E00': 9000,\n",
"        '365_RZ_2079996_E00': 9000, '371_RZ_2173469_E00': -1, '385_RZ_2344725_E00': 3500,\n",
"        '398_RZ_2530473_E00': 1200, '402_RZ_2603283_E00': -1, '416_RZ_2788252_E00': -1,\n",
"        '429_RZ_2945490_E00': 4500, '436_RZ_3131841_E00': 10500, '459_RZ_3291266_E00': 5400,\n",
"        '469_RZ_3390904_E00': 14500, '480_RZ_3470098_E00': 8000, '493_RZ_3557734_E00': 10500, \n",
"        '523_RZ_4091327_E00': 0, '536_RZ_4420664_E00': 0, '548_RZ_4509746_E00': 0,\n",
"        '561_RZ_4598680_E00': 0, '573_RZ_4680777_E00': 0, '584_RZ_4772014_E00': 0,\n",
"        '588_RZ_5032278_E00': 0}\n",
"\n",
"num_frames = 0\n",
"num_labeled_frames = 0\n",
"counter = 0\n",
"for episode, num_samples in test.items():\n",
"    counter += 1\n",
"    # Look the episode length up once; it is needed for both totals.\n",
"    episode_length = len(df.loc[df.ID == episode])\n",
"    num_frames += episode_length\n",
"    if num_samples == -1:\n",
"        # -1 stands for the whole episode.\n",
"        num_labeled_frames += episode_length\n",
"    else:\n",
"        num_labeled_frames += num_samples\n",
"\n",
"# Report the number of episodes actually iterated instead of a fixed figure.\n",
"print(f'Overall percentage {num_labeled_frames / num_frames:%} for {counter} episodes')"
]
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7f3c0026",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Label Atari-HEAD data"
|
|
]
|
|
},
|
|
{
"cell_type": "code",
"execution_count": null,
"id": "fe0dcd7a",
"metadata": {},
"outputs": [],
"source": [
"%%time \n",
"# Replay each episode's recorded actions in the wrapped environment and store\n",
"# the resulting RAM labels next to the corresponding rows of `df`.\n",
"df['level'] = None \n",
"df['room_id'] = None\n",
"df['player_location'] = None\n",
"df['skull_location'] = None \n",
"\n",
"for episode in df.ID.unique():\n",
"\n",
"    episode_rows = df.loc[df.ID == episode]\n",
"\n",
"    # `test` gives the number of leading actions with correct labels.\n",
"    # -1 means the whole episode: used directly as a slice bound, [:-1] would\n",
"    # drop the last frame, so it is translated to None (no upper bound).\n",
"    # NOTE(review): episodes missing from `test` also yield None and are\n",
"    # therefore replayed in full, as before - confirm that this is intended.\n",
"    num_valid_actions = test.get(episode)\n",
"    if num_valid_actions == -1:\n",
"        num_valid_actions = None\n",
"\n",
"    actions = episode_rows.action.values[:num_valid_actions]\n",
"    index = episode_rows.index[:num_valid_actions]\n",
"    if len(index) == 0:\n",
"        # Nothing to label; the new columns keep their None default.\n",
"        continue\n",
"\n",
"    obs = env.reset()\n",
"    room_ids = []\n",
"    agent_locations = []\n",
"    skull_locations = []\n",
"    level = []\n",
"\n",
"    for action in actions:\n",
"\n",
"        n_state, reward, done, info = env.step(action)\n",
"        labels = info['labels']\n",
"        room_ids.append(labels['room_number'])\n",
"        level.append(labels['level'])\n",
"\n",
"        # Cast to plain Python ints before applying the hand-tuned offsets.\n",
"        # NOTE(review): presumably the labels are fixed-width NumPy integers;\n",
"        # if so, this arithmetic can overflow or raise under NumPy >= 2.\n",
"        # agent\n",
"        agent_locations.append([int(labels['player_x']), 320 - int(labels['player_y'])])\n",
"\n",
"        # skull\n",
"        skull_locations.append([int(labels['enemy_skull_x']) + 35, int(labels['enemy_skull_y']) - 65])\n",
"\n",
"    df.loc[index, 'level'] = level\n",
"    df.loc[index, 'room_id'] = room_ids\n",
"    # Wrap the [x, y] pairs in an index-aligned Series so each pair is stored as\n",
"    # one cell value rather than being interpreted as a 2-D block of values.\n",
"    df.loc[index, 'player_location'] = pd.Series(agent_locations, index=index)\n",
"    df.loc[index, 'skull_location'] = pd.Series(skull_locations, index=index)\n",
"\n",
"print(f'Percentage of labeled data {len(df[df.room_id.notnull()]) / len(df):%}')\n",
"print()\n",
"df.to_pickle(os.path.join(DATA_PATH, \"all_trials_labeled.pkl\")) \n",
"df.head()"
]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "96ff9136",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "msc_env",
|
|
"language": "python",
|
|
"name": "msc_env"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|