201 lines
5.6 KiB
Text
201 lines
5.6 KiB
Text
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "d094257c",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import os\n",
|
||
|
"import numpy as np \n",
|
||
|
"import cv2 \n",
|
||
|
"import pandas as pd\n",
|
||
|
"import matplotlib.pyplot as plt \n",
|
||
|
"import subprocess\n",
|
||
|
"import warnings\n",
|
||
|
"\n",
|
||
|
"import dataset_utils as utils"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "dbabd791",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"\n",
|
||
|
"### Read all Montezuma's Revenge trials"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "9e42d862",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"DATA_PATH = 'montezuma_revenge/'\n",
"\n",
"# Keep every trial whose GameName contains 'montezuma_revenge' (this includes the\n",
"# 'montezuma_revenge_highscore' variant), ordered by trial_id.\n",
"# na=False treats a missing GameName as a non-match instead of breaking the boolean mask;\n",
"# the chain builds a fresh frame rather than mutating a filtered slice in place.\n",
"df = (\n",
"    pd.read_csv(os.path.join(DATA_PATH, 'meta_data.csv'))\n",
"    .loc[lambda d: d.GameName.str.contains('montezuma_revenge', regex=False, na=False)]\n",
"    .sort_values(by=['trial_id'], ascending=True)\n",
"    .reset_index(drop=True)\n",
")\n",
"df.head()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "e42160ea",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Get folder names of each trial"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "a9175b8f",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# All trial archives below DATA_PATH (os.walk also searches subdirectories).\n",
"file_lst = [os.path.join(root, name) for root, dirs, files in os.walk(DATA_PATH) for name in files if name.endswith('.tar.bz2')]\n",
"\n",
"# The folder name is the archive's basename up to the first '.', and starts with '<trial_id>_'.\n",
"# os.path.basename (instead of splitting on '/') also handles Windows separators and a\n",
"# DATA_PATH that itself contains a '.'.\n",
"folder_by_trial = {}\n",
"for f in file_lst:\n",
"    folder = os.path.basename(f).split('.')[0]\n",
"    folder_by_trial[int(folder.split('_')[0])] = folder\n",
"folder_lst = [folder_by_trial[k] for k in sorted(folder_by_trial)]\n",
"\n",
"# Assign folders by trial_id, not by list position: a missing or extra archive would\n",
"# otherwise shift folders onto the wrong trials or raise a length-mismatch error.\n",
"df['trial_folder'] = df.trial_id.astype(int).map(folder_by_trial)\n",
"missing = df.loc[df.trial_folder.isna(), 'trial_id'].tolist()\n",
"if missing:\n",
"    warnings.warn(f'No .tar.bz2 archive found for trial(s) {missing}')\n",
"\n",
"df.to_pickle(os.path.join(DATA_PATH, \"all_trials_summary.pkl\")) \n",
"df.head()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "28f38343",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Unpack all trial archives"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "c83443d7",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"for i in df.trial_id:\n",
"    # Match the archive basename prefix exactly: a substring test on the full path\n",
"    # (str(i) + '_' in f) also matches other trials, e.g. '11_...' when i == 1.\n",
"    file = [f for f in file_lst if os.path.basename(f).startswith(f'{i}_')]\n",
"    print(i, *file)\n",
"    if not file:\n",
"        warnings.warn(f'No archive found for trial {i}; skipping.')\n",
"        continue\n",
"    # NOTE(review): everything is extracted into DATA_PATH, while the next section reads\n",
"    # highscore trials from DATA_PATH/highscore -- confirm where those folders should land.\n",
"    # An argument list (no shell) keeps paths with spaces or shell metacharacters safe.\n",
"    result = subprocess.run(['tar', '-jxf', file[0], '--directory', DATA_PATH])\n",
"    if result.returncode != 0:\n",
"        warnings.warn(f'tar exited with code {result.returncode} for {file[0]}')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "b65c1efb",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Generate a DataFrame with all trials"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "24ec2a88",
|
||
|
"metadata": {
|
||
|
"scrolled": false
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"%%time\n",
"def write_unique_id(frame_id, episode_id, trial_id):\n",
"    \"\"\"Return '<trial_id>_<first two '_'-separated parts of frame_id>_E<episode_id>'.\n",
"\n",
"    A missing episode_id is written as episode 0; a missing frame_id yields None.\n",
"    \"\"\"\n",
"    if pd.isna(frame_id):\n",
"        return None\n",
"    frame_prefix = '_'.join(frame_id.split('_')[:2])\n",
"    episode = 0 if pd.isna(episode_id) else int(episode_id)\n",
"    return f'{trial_id}_{frame_prefix}_E{episode}'\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def trial_path(row):\n",
"    \"\"\"Return a trial's base path (frame folder; annotations live at path + '.txt').\n",
"\n",
"    Returns None when row.GameName is neither 'montezuma_revenge' nor\n",
"    'montezuma_revenge_highscore'.\n",
"    \"\"\"\n",
"    if row.GameName == 'montezuma_revenge':\n",
"        return os.path.join(DATA_PATH, row.trial_folder)\n",
"    if row.GameName == 'montezuma_revenge_highscore':\n",
"        return os.path.join(DATA_PATH, 'highscore', row.trial_folder)\n",
"    return None\n",
"\n",
"\n",
"def load_trial(path, trial_id):\n",
"    \"\"\"Read one trial's annotations and add 'ID' and 'img_path' columns.\n",
"\n",
"    Frames without an ID (write_unique_id returned None) are dropped.\n",
"    \"\"\"\n",
"    print(f'Reading {path}')\n",
"    trial_df = utils.txt_to_dataframe(path + '.txt')\n",
"    # Comprehensions instead of DataFrame.apply(axis=1), which returns a DataFrame\n",
"    # rather than a Series when trial_df has no rows.\n",
"    trial_df['ID'] = [write_unique_id(f, e, trial_id)\n",
"                      for f, e in zip(trial_df['frame_id'], trial_df['episode_id'])]\n",
"    trial_df['img_path'] = [os.path.join(path, str(f) + '.png') for f in trial_df['frame_id']]\n",
"    # Cut frames without annotations\n",
"    trial_df = trial_df[trial_df.ID.notnull()]\n",
"    print(f'Episodes: {trial_df.ID.unique()}\\n')\n",
"    return trial_df\n",
"\n",
"\n",
"# Every trial, including the first, resolves its path through trial_path().\n",
"frames = []\n",
"for idx, row in df.iterrows():\n",
"    path = trial_path(row)\n",
"    if path is None:\n",
"        warnings.warn(f\"GameName of row {idx} not recognised! Skipping trial.\")\n",
"        continue\n",
"    frames.append(load_trial(path, row.trial_id))\n",
"\n",
"if not frames:\n",
"    raise ValueError('No trials could be loaded; check df.GameName and df.trial_folder.')\n",
"\n",
"# Concatenate once (a concat per loop iteration is quadratic); join='inner' keeps only\n",
"# the columns present in every trial. Then move 'ID' to the front.\n",
"full_df = pd.concat(frames, join='inner', ignore_index=True)\n",
"full_df = full_df[['ID'] + [c for c in full_df.columns if c != 'ID']]\n",
"\n",
"outpath = os.path.join(DATA_PATH, \"all_trials.pkl\")\n",
"print(f'Saving dataframe to {outpath}\\n')\n",
"\n",
"full_df.to_pickle(outpath)\n",
"full_df.head()"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.9.13"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|