Int-HRL/Preprocess_AtariHEAD.ipynb

201 lines
5.6 KiB
Text
Raw Normal View History

2025-03-12 18:20:56 +01:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "d094257c",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import numpy as np \n",
"import cv2 \n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt \n",
"import subprocess\n",
"import warnings\n",
"\n",
"import dataset_utils as utils"
]
},
{
"cell_type": "markdown",
"id": "dbabd791",
"metadata": {},
"source": [
"\n",
"### Read all Montezuma's Revenge trials"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e42d862",
"metadata": {},
"outputs": [],
"source": [
"DATA_PATH = 'montezuma_revenge/'\n",
"\n",
"df = pd.read_csv(os.path.join(DATA_PATH, 'meta_data.csv'))\n",
"df = df.loc[df.GameName.str.contains('montezuma_revenge')]\n",
"df.sort_values(by=['trial_id'], inplace=True, ascending=True)\n",
"df.reset_index(drop=True, inplace=True)\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"id": "e42160ea",
"metadata": {},
"source": [
"#### Get folder names of each trial"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9175b8f",
"metadata": {},
"outputs": [],
"source": [
"file_lst = [os.path.join(root, name) for root, dirs, files in os.walk(DATA_PATH) for name in files if 'tar.bz2' in name]\n",
"\n",
"folder_lst = [f.split('.')[0].split('/')[-1] for f in file_lst]\n",
"folder_lst.sort(key=lambda x: int(str(x).split('_')[0]), reverse=False)\n",
"\n",
"df['trial_folder'] = folder_lst\n",
"\n",
"df.to_pickle(os.path.join(DATA_PATH, \"all_trials_summary.pkl\")) \n",
"df.head()"
]
},
{
"cell_type": "markdown",
"id": "28f38343",
"metadata": {},
"source": [
"#### Unpack all folders"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c83443d7",
"metadata": {},
"outputs": [],
"source": [
"for i in df.trial_id:\n",
" file = [f for f in file_lst if str(i) + '_' in f]\n",
" print(i, *file)\n",
" cmd = f'tar -jxf {file[0]} --directory {DATA_PATH}'\n",
" subprocess.call(cmd, shell=True)"
]
},
{
"cell_type": "markdown",
"id": "b65c1efb",
"metadata": {},
"source": [
"## Genarate Dataframe with all Trials"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24ec2a88",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%%time\n",
"def write_unique_id(frame_id, episode_id, trial_id):\n",
" if not pd.isna(episode_id) and not pd.isna(frame_id):\n",
" unique_id = str(trial_id)+ '_' + '_'.join(frame_id.split('_')[:2]) + '_E{:01d}'.format(int(episode_id))\n",
" elif not pd.isna(frame_id):\n",
" unique_id = str(trial_id)+ '_' + '_'.join(frame_id.split('_')[:2]) + '_E0'\n",
" else: \n",
" unique_id = None\n",
" return unique_id\n",
"\n",
"\n",
"path = os.path.join(DATA_PATH, df.iloc[0].trial_folder)\n",
"print(path) \n",
" \n",
"# Read Annotations\n",
"trial_df = utils.txt_to_dataframe(path + '.txt') \n",
"\n",
"# Write unique ID\n",
"trial_df['ID'] = trial_df.apply(lambda x: write_unique_id(x['frame_id'], x['episode_id'], df.iloc[0].trial_id), axis=1)\n",
"\n",
"# Write image paths\n",
"trial_df['img_path'] = trial_df.apply(lambda x: os.path.join(path, str(x['frame_id']) + '.png'), axis=1)\n",
"\n",
"# Reorder columns\n",
"cols = ['ID'] + [c for c in trial_df.columns.tolist() if not c=='ID'] \n",
"trial_df = trial_df[cols]\n",
"\n",
"# Cut frames without annotations\n",
"trial_df = trial_df[trial_df.ID.notnull()] \n",
"\n",
"print(f'Episodes: {trial_df.ID.unique()}\\n')\n",
"\n",
"full_df = trial_df.copy()\n",
"\n",
"for idx in df.index[1:]:\n",
" row = df.iloc[idx]\n",
" if row.GameName == 'montezuma_revenge':\n",
" path = os.path.join(DATA_PATH, row.trial_folder)\n",
" elif row.GameName == 'montezuma_revenge_highscore':\n",
" path = os.path.join(DATA_PATH, 'highscore', row.trial_folder)\n",
" else: \n",
" path = ''\n",
" warnings.warn(f\"GameName of row {idx} not recognised! Returning empty path.\")\n",
" print(f'Reading {path}')\n",
" \n",
" # Read Annotations\n",
" trial_df = utils.txt_to_dataframe(path + '.txt') \n",
" \n",
" # Write unique ID \n",
" trial_df['ID'] = trial_df.apply(lambda x: write_unique_id(x['frame_id'], x['episode_id'], row.trial_id), axis=1)\n",
" \n",
" # Write image paths\n",
" trial_df['img_path'] = trial_df.apply(lambda x: os.path.join(path, str(x['frame_id']) + '.png'), axis=1)\n",
"\n",
" # Cut frames without annotations\n",
" trial_df = trial_df[trial_df.ID.notnull()] \n",
"\n",
" print(f'Episodes: {trial_df.ID.unique()}\\n')\n",
" full_df = pd.concat([full_df, trial_df], join='inner', ignore_index=True)\n",
"\n",
"outpath = os.path.join(DATA_PATH, \"all_trials.pkl\")\n",
"print(f'Saving dataframe to {outpath}\\n')\n",
"\n",
"full_df.to_pickle(outpath)\n",
"full_df.head()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}