limits-of-tom/compare_tom_cpa.ipynb

481 lines
230 KiB
Text
Raw Normal View History

2024-06-11 15:36:55 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pickle \n",
"import matplotlib.pyplot as plt\n",
"from sklearn.metrics import f1_score\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from scipy.stats import pearsonr\n",
"import numpy as np "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Placeholder paths (\"XXX\"); set each to the corresponding pickle file before running.\n",
"# ToM feature predictions, one file per feature\n",
"tom6_fname = \"XXX\"  # task status\n",
"tom7_fname = \"XXX\"  # task knowledge\n",
"tom8_fname = \"XXX\"  # task intention\n",
"\n",
"# own plan (exp2) results\n",
"exp2_5_fname = \"XXX\"\n",
"exp2_6_fname = \"XXX\"\n",
"exp2_0_fname = \"XXX\"\n",
"\n",
"# partner plan (exp3) results\n",
"exp3_7_fname = \"XXX\"\n",
"exp3_0_fname = \"XXX\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Load the pickled ToM prediction results (task status, knowledge, intention).\n",
"# NOTE: pickle.load can execute arbitrary code; only load trusted files.\n",
"tom_loaded = []\n",
"for tom_path in (tom6_fname, tom7_fname, tom8_fname):\n",
"    with open(tom_path, 'rb') as fh:\n",
"        tom_loaded.append(pickle.load(fh))\n",
"tom6, tom7, tom8 = tom_loaded"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Load the pickled exp2 (own plan) and exp3 (partner plan) results, in a fixed order.\n",
"# NOTE: pickle.load can execute arbitrary code; only load trusted files.\n",
"exp_loaded = []\n",
"for exp_path in (exp2_5_fname, exp2_6_fname, exp2_0_fname, exp3_0_fname, exp3_7_fname):\n",
"    with open(exp_path, 'rb') as fh:\n",
"        exp_loaded.append(pickle.load(fh))\n",
"exp2_5, exp2_6, exp2_0, exp3_0, exp3_7 = exp_loaded"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Drop empty records and keep only the first element of each remaining one.\n",
"# NOTE: tom6/tom7/tom8 are rebound to their unwrapped form, so re-running this\n",
"# cell without reloading the pickles unwraps them a second time.\n",
"def _unwrap_nonempty(records):\n",
"    return [rec[0] for rec in records if rec and rec[0]]\n",
"\n",
"tom6 = _unwrap_nonempty(tom6)\n",
"tom7 = _unwrap_nonempty(tom7)\n",
"tom8 = _unwrap_nonempty(tom8)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# compute per-game f1 scores\n",
"def per_game_f1(games):\n",
"    \"\"\"Score each game's predictions with weighted F1.\n",
"\n",
"    Each record is read as game[0] = predicted labels, game[1] = ground-truth\n",
"    labels, game[2] = game path.\n",
"\n",
"    Returns (per-game weighted F1 scores, game paths), both in input order.\n",
"    \"\"\"\n",
"    scores, paths = [], []\n",
"    for game in games:\n",
"        pred, gt, game_path = game[0], game[1], game[2]\n",
"        scores.append(f1_score(gt, pred, average=\"weighted\"))\n",
"        paths.append(game_path)\n",
"    return scores, paths\n",
"\n",
"tom6_f1, tom6_game_paths = per_game_f1(tom6)\n",
"tom7_f1, tom7_game_paths = per_game_f1(tom7)\n",
"tom8_f1, tom8_game_paths = per_game_f1(tom8)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Later cells index the tom6/tom7/tom8 scores with one shared list of game ids\n",
"# (derived from tom6), so the three files must list the same games in the same order.\n",
"tom_paths_aligned = tom6_game_paths == tom7_game_paths == tom8_game_paths\n",
"assert tom_paths_aligned, \"tom6/tom7/tom8 results are not aligned game-by-game\"\n",
"tom_paths_aligned"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'141_212_108_99_20210325_121618'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Game id = last '/'-separated component of each game path.\n",
"game_ids = [p.rsplit(\"/\", 1)[-1] for p in tom6_game_paths]\n",
"game_ids[0]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Per-game F1 is the last field of each record; the game path is third from last.\n",
"exp2_5_f1 = [item[-1] for item in exp2_5]\n",
"exp2_6_f1 = [item[-1] for item in exp2_6]\n",
"exp2_0_f1 = [item[-1] for item in exp2_0]\n",
"exp3_0_f1 = [item[-1] for item in exp3_0]\n",
"exp3_7_f1 = [item[-1] for item in exp3_7]\n",
"exp2_game_ids = [item[-3].split(\"/\")[-1] for item in exp2_5]\n",
"\n",
"# The F1 differences are paired by list position (zip), and the exp3 differences\n",
"# are filtered with exp2_5's game ids, so every results file must list the same\n",
"# games in the same order. zip would silently truncate on a length mismatch.\n",
"# Assumes every file uses exp2_5's record layout (game path at [-3]).\n",
"for results_name, results in [(\"exp2_6\", exp2_6), (\"exp2_0\", exp2_0), (\"exp3_0\", exp3_0), (\"exp3_7\", exp3_7)]:\n",
"    results_ids = [item[-3].split(\"/\")[-1] for item in results]\n",
"    assert results_ids == exp2_game_ids, f\"{results_name} is not aligned game-by-game with exp2_5\""
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def _paired_diff(minuend, subtrahend):\n",
"    \"\"\"Element-wise minuend[i] - subtrahend[i]; like zip, stops at the shorter list.\"\"\"\n",
"    return [m - s for m, s in zip(minuend, subtrahend)]\n",
"\n",
"exp2_f1_diff_5_0 = _paired_diff(exp2_5_f1, exp2_0_f1)\n",
"exp2_f1_diff_6_0 = _paired_diff(exp2_6_f1, exp2_0_f1)\n",
"exp3_f1_diff_7_0 = _paired_diff(exp3_7_f1, exp3_0_f1)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(60, 64)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# number of games in the exp2/exp3 results vs. the ToM results\n",
"tuple(len(ids) for ids in (exp2_game_ids, game_ids))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"common_elements = set(exp2_game_ids) & set(game_ids)\n",
"\n",
"# Positions (in each list's own order) of the games present in both the\n",
"# exp2/exp3 results and the ToM results.\n",
"exp2_keep = [i for i, gid in enumerate(exp2_game_ids) if gid in common_elements]\n",
"tom_keep = [i for i, gid in enumerate(game_ids) if gid in common_elements]\n",
"\n",
"exp2_common_game_ids = [exp2_game_ids[i] for i in exp2_keep]\n",
"exp2_common_f1_diff_5_0 = [exp2_f1_diff_5_0[i] for i in exp2_keep]\n",
"exp2_common_f1_diff_6_0 = [exp2_f1_diff_6_0[i] for i in exp2_keep]\n",
"exp3_common_f1_diff_7_0 = [exp3_f1_diff_7_0[i] for i in exp2_keep]\n",
"\n",
"common_game_ids = [game_ids[i] for i in tom_keep]\n",
"tom6_common_f1 = [tom6_f1[i] for i in tom_keep]\n",
"tom7_common_f1 = [tom7_f1[i] for i in tom_keep]\n",
"tom8_common_f1 = [tom8_f1[i] for i in tom_keep]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The correlations below pair tom*_common_f1 with exp*_common_f1_diff_* by\n",
"# position, so both filtered lists must name the same games in the same order.\n",
"common_ids_aligned = exp2_common_game_ids == common_game_ids\n",
"assert common_ids_aligned, \"exp and ToM results are not aligned on the common games\"\n",
"common_ids_aligned"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Own missing knowledge (OMK, exp2)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Pearson correlation between each ToM feature's per-game F1 and the per-game\n",
"# exp2 F1 difference, over the common games.\n",
"# NOTE(review): Task Intention (tom8) is paired with exp2_common_f1_diff_6_0, the same\n",
"# difference used for Task Knowledge (tom7); confirm this pairing is intended.\n",
"exp2_pairs = [\n",
"    (tom6_common_f1, exp2_common_f1_diff_5_0),  # Task Status\n",
"    (tom7_common_f1, exp2_common_f1_diff_6_0),  # Task Knowledge\n",
"    (tom8_common_f1, exp2_common_f1_diff_6_0),  # Task Intention\n",
"]\n",
"exp2_corr = [pearsonr(x, y) for x, y in exp2_pairs]\n",
"tom6correlation_coefficient, tom6p_value = exp2_corr[0]\n",
"tom7correlation_coefficient, tom7p_value = exp2_corr[1]\n",
"tom8correlation_coefficient, tom8p_value = exp2_corr[2]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.21866640044266308 0.09913002973227936\n",
"0.08225606020725357 0.5393156230207309\n",
"0.15718692888976768 0.23864397213593697\n"
]
}
],
"source": [
"# r and p for Task Status, Task Knowledge, Task Intention (one line each)\n",
"for r_value, p_value in [(tom6correlation_coefficient, tom6p_value),\n",
"                         (tom7correlation_coefficient, tom7p_value),\n",
"                         (tom8correlation_coefficient, tom8p_value)]:\n",
"    print(r_value, p_value)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB1gAAAHzCAYAAACNAs8uAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzde1xUdf4/8NeZAQSDQQXMC4JCRXhpvVTeUtHKJe+pP7da87K6ldVqrZu1F7esdv1m2+ali2XrJbN2raULSqi7opjXQjFNNFMUxQvghQEDZGbO7w92JpCLM8yZeZ9hXs8ePBrOnDnn83l9PjN8PJ855yiqqqogIiIiIiIiIiIiIiIiIqLrMkgXgIiIiIiIiIiIiIiIiIjIV3CClYiIiIiIiIiIiIiIiIjISZxgJSIiIiIiIiIiIiIiIiJyUoAWGzl58iS2bt2KrVu34uDBgygsLERhYSEAoHXr1oiKikKXLl0waNAgDBo0CLGxsVrsloiIiIiIiIiIiIiIiIjIqxRVVdXGvPDKlSv48MMP8e6772Lv3r2O5fVtTlEUx+NevXrhkUcewYMPPogbbrihMbsnIiIiIiIiIiIiIiIiIvI6lydYS0tL8corr2Dx4sUoLS11TKgajUZ07doV8fHxaNWqFVq1agUAuHjxIi5evIhjx47h4MGDsFqtVTtWFNxwww2YNWsWnn32WYSGhmpcNSIiIiIiIiIiIiIiIiIibbk0wfrmm2/ipZdeQmFhIVRVRWxsLB566CHcd9996NWrF0JCQhp8fVlZGbKysrBhwwZ8+OGHyM3NhaIoiIyMxJ///Gc88cQTbleIiIiIiIiIiIiIiIiIiMhTXJpgNRgMUBQF48ePx8yZM9G/f3+3dr5jxw4sWrQI//73v6GqquPsViIiIiIiIiIiIiIiIiIiPXJpgvXBBx/E3Llz0blzZ00L8d133+Hll1/GRx99pOl2iYiIiIiIiIiIiIiIiIi05PI9WImIiIiIiIiIiIiIiIiI/JVBugBERERERERERERERERERL6CE6xE1KStW7cOI0aMQNu2bREcHIzY2FhMmzYNBw4caNT2ysvL8dlnn+Hxxx/H7bffjpYtWyIwMBCRkZEYNGgQXnvtNZjNZo9vg2rTuq3tVFVFTk4OVq1ahSeeeAJ33HEHmjVrBkVRoCgKTpw4oYtyEhER+Qpv/C10dx8//PADnn76afzsZz+DyWRCYGAgoqKikJSUhL///e8oLS3VrKxNnafa+8SJE47x2PV+ioqKXNp2YWEhIiMjHa+fMmWKW2UlIiLyFb5ybAUAMjIyMHnyZMTHx6N58+Zo2bIlOnfujMmTJ+Nf//qXW+X1F55ob08e9+QYjfTGpUsEZ2ZmarrzgQMHaro9IqLqZsyYgaVLl9b5XLNmzfDuu+9i0qRJLm3TZDKhpKSkwXWio6Px73//G3feeafHtkE1eaKt7U6cOIFOnTrV+3xubi46duzo1LY8WU4iIiJf4I2/he7uY9WqVXjsscdQXl5e7zoxMTH48ssv0blzZ7fK2tRJjtGqsx+Mc9bEiROxZs0ax++TJ0/GypUrXS0iERGRT/GVYytlZWWYNm0aPvroo3rXiY2NbdSkrT/xVHt78rgnx2ikO6oLFEVRDQaDJj9Go9GVXRMRueSVV15RAagA1DFjxqhZWVlqQUGBumHDBrVr164qADUgIED96quvXNouADUoKEidMGGC+uGHH6pHjx5VL168qB48eFB97rnn1ICAABWA2rJlS/X06dMe2wb9xFNtbZebm+vYfnR0tHr//ferAwYMcCzLzc3VRTmJiIj0zht/C93dx549e1SDwaACUKOiotQ33nhDzcnJUQsKCtTdu3erU6ZMcWw/Pj5eLS8vb3RZmzpvjtHS0tLUkpKSen9csWHDBhWAGhcX59j+5MmTG1VGIiIiX+Erx1YqKyvVoUOHqgDUwMBA9amnnlJ37dqlFhQUqOfOnVO3bt2qzp
49W+3bt2+jyukvPNnenjruyTEa6VGjJlgVRXH7x2AweKpOROTnCgoK1NDQUBWAOnToUNVms9V4vqioSL3xxhtVAGrv3r1d2vbjjz+unj17tt7n16xZ4/gjP2PGDI9tg6p4sq3tzGaz+tlnn9Vos+eff96lfwR4o5xERER65o2/hVrs45e//KUKQDUYDOru3bvrXOexxx5zjAM+++yzRpW1qfNGe1c/UJuRkaFBqVX1ypUrjoN2aWlpPHhHRER+wVeOraiqqs6fP18FoAYHB2v299/feLq9PXHck2M00qtG3YM1PDwc06dPx/r165GRkdGon82bNzdm10RE17Vq1SrHfbHmz58PRVFqPB8REYE5c+YAAHbv3o29e/c6ve0333wTbdq0qff5hx56CN26dQMAfPnllx7bBlXxZFvbhYWFYfTo0Q22mR7KSUREpGfe+FuoxT6ys7MBADfffHO9lyx7+OGHHY8PHz7scjn9ga+OfV544QUcP34c48ePx3333SddHCIiIq/wlWMrly5dwosvvggA+OMf/4ikpKRGb8ufebq9PXHck2M00iuXJlh79+4NVVVRXFyMf/zjH3jggQewatUqWK1WDBo0yOUfIvIt//znP6EoCoxGI3788UecOHECv/3tb3HLLbcgODgYiqLg1KlT0sVEamoqACA+Ph49e/asc50JEyY4Hn/xxRea7r9Lly4AgDNnzohuwxnr1q1z3Bj+woUL2Lp1Kx566CF07NgRISEhCA8Px7333outW7d6tByNJd3WzvKVchIRke/h+EzbfQQHBwNArQNN1RmNRsfj1q1bu1xOZ3CM5n3Z2dl4/fXXERYWhkWLFkkXh4iImgCO07T1wQcfoKysDEFBQXjiiSdEygBwnKYFV457coxGeubSBOvOnTtx9OhR/OlPf0LHjh1RUlKClStX4t5770VMTAx+//vf49ChQ54qKxEJ27dvH4CqP8BffPEFunbtitdffx1Hjx5FRUUFWrZsiQ4dOgiXEo5vVvXp06fedaKjo9G+fXsAQFZWlqb7P3/+PICqs/0lt+GM/fv3AwA6dOiAV155BUlJSfjoo49w8uRJlJeXw2w24z//+Q/uvvtufPbZZx4tS2NIt7WzfKWcRETkezg+03Yf9oNM33//vWOcdK1//etfAIBmzZphyJAhLpfTGRyjNc7Vq1cb9TqbzYZHHnkEFosFL730Etq1a6dJeYiIyL9xnKattLQ0AMDtt9+Oli1bOpZbrVbYbDavlYPjNPc5e9yTYzTSO5cvERwfH48XX3wRx44dQ2ZmJqZPn47w8HCcPn0aCxYsQLdu3dCrVy8sWrQIBQUFnigzEQmxXzKtuLgYEydOxK233opPP/0U586dw6lTp/Dxxx/X+Tqr1YrS0lK3fqxWq1NlzM/Pd1zmIi4ursF1O3XqBAA4cuSIkwlc3/nz5/HVV18BAPr16ye2DWfZB4Vnz57Fq6++ittuuw2ffPIJzpw5g0OHDmHevHkICAiA1WrFo48+ioqKiutu01vtLd3WzvKVchIRkW/i+EzbfTz33HMICQmBzWbDiBEjsHr1apw5cwZlZWXIycnB008/jYULF0JRFLz66quIjY11qZzO4hjNNU8++STCwsLQrFkzBAcH47bbbsOcOXNw+vRpp16/ePFifP311+jZsyeefPJJt8pCRERkx3Gatr755hsAQOfOnXH16lW88sorSExMRLNmzRAYGIj4+HjMnDnT6b//jcVxmntcOe7JMRrpnhY3ci0vL1c//vhjdcSIEWpgYKCqKIpqMBjUwMBAddiwYeo///lPtby8XItdEZGg1q1bO24iPmbMGPXq1atOvS4jI8Pxusb+OHvj+uzsbMdrlixZ0uC6Y8eOVQGorVq1cmrbzpg6dapj/+np6WLbcNYtt9zi2Nfw4cPr/KyeO3euY51169Zdd5veam/Jtn7++ecd+87NzdVtOYmIqOnj+Ez7fezcuVONiYmpt94///nP1Q0bNr
hUPldxjHZ9ubm51y3vDTfcoH700UcNbufkyZNqaGioajAY1D179tR4zr6dyZMnu1w+IiIijtMa5sqxlR9//NGx7uO
"text/plain": [
"<Figure size 1900x510 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fontsize = 19\n",
"\n",
"# One panel per ToM feature, left to right:\n",
"# (x values, y values, r, p, x-axis label, marker colour; None = default colour)\n",
"panels = [\n",
"    (tom6_common_f1, exp2_common_f1_diff_5_0, tom6correlation_coefficient, tom6p_value, 'F1(Task Status)', 'purple'),\n",
"    (tom7_common_f1, exp2_common_f1_diff_6_0, tom7correlation_coefficient, tom7p_value, 'F1(Task Knowledge)', 'green'),\n",
"    (tom8_common_f1, exp2_common_f1_diff_6_0, tom8correlation_coefficient, tom8p_value, 'F1(Task Intention)', None),\n",
"]\n",
"\n",
"fig, axes = plt.subplots(1, 3, figsize=(19, 5.1))\n",
"for ax, (x, y, r, p, xlabel, color) in zip(axes, panels):\n",
"    ax.scatter(x, y, c=color)\n",
"    ax.set_title(f'$r={r:.2f}$, $p={p:.2f}$', fontsize=fontsize)\n",
"    # degree-1 least-squares fit, drawn over the distinct x values\n",
"    x_fit = np.unique(x)\n",
"    ax.plot(x_fit, np.poly1d(np.polyfit(x, y, 1))(x_fit), color='red', linestyle='--')\n",
"    ax.set_xlabel(xlabel, fontsize=fontsize)\n",
"    ax.grid(True, linestyle='--')\n",
"    ax.tick_params(axis='both', labelsize=fontsize)\n",
"\n",
"# Only the leftmost panel carries the y-axis label. The plotted value is\n",
"# exp2_X - exp2_0; the label assumes exp2_0 is the run without ToM (TODO confirm).\n",
"axes[0].set_ylabel('F1(OMK w/ ToM) $-$ F1(OMK w/o ToM)', fontsize=fontsize)\n",
"\n",
"plt.tight_layout()\n",
"\n",
"plt.savefig('correlation_tom_feats_exp2.pdf', bbox_inches='tight')\n",
"\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Partner's missing knowledge (PMK, exp3)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# Same correlations against the exp3 difference (exp3_7 - exp3_0); all three\n",
"# ToM features are compared with that one difference.\n",
"tom_common_f1s = [tom6_common_f1, tom7_common_f1, tom8_common_f1]\n",
"exp3_corr = [pearsonr(tom_f1, exp3_common_f1_diff_7_0) for tom_f1 in tom_common_f1s]\n",
"tom6correlation_coefficient, tom6p_value = exp3_corr[0]\n",
"tom7correlation_coefficient, tom7p_value = exp3_corr[1]\n",
"tom8correlation_coefficient, tom8p_value = exp3_corr[2]"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.23820019887745394 0.07176536455128052\n",
"-0.0691280537719391 0.6061193964113574\n",
"0.10181632005409125 0.4469499039738408\n"
]
}
],
"source": [
"# r and p for Task Status, Task Knowledge, Task Intention (one line each)\n",
"for r_value, p_value in [(tom6correlation_coefficient, tom6p_value),\n",
"                         (tom7correlation_coefficient, tom7p_value),\n",
"                         (tom8correlation_coefficient, tom8p_value)]:\n",
"    print(r_value, p_value)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB1gAAAHzCAYAAACNAs8uAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzde3yT5f3/8fedlB6QNiIHBaGcnMhBf4hTBEUOGwxBwCHiYYo6nVO3qZNN5xybzjknzimeN+dEmXhmmygibnJyICqCgihDrBQQhII0BQs0yf37o98ESg8kzeGTNK/n49GHJb2bXPfrvtJe5m4Sx3VdVwAAAAAAAAAAAACAQ/JYDwAAAAAAAAAAAAAAMgUnWAEAAAAAAAAAAAAgSpxgBQAAAAAAAAAAAIAocYIVAAAAAAAAAAAAAKLECVYAAAAAAAAAAAAAiBInWAEAAAAAAAAAAAAgSpxgBQAAAAAAAAAAAIAo5STiStavX68FCxZowYIFWrVqlbZt26Zt27ZJktq2bas2bdqoV69eGjRokAYNGqROnTol4mYBAAAAAAAAAAAAIKUc13Xdxnzj7t27NWPGDP3lL3/R+++/H7m8vqtzHCfy+UknnaQrr7xSF1xwgQ477LDG3DwAAAAAAAAAAAAApFzMJ1h37dqlu+66S/fff7927doVOaHq9XrVu3dvdevWTUcccYSOOOIISdKOHTu0Y8cOrVu3TqtWrVIwGKy+YcfRYYcdpuuuu0433XSTWrRokeBdAwAAAAAAAAAAAIDEiukE60MPPaTbb79d27Ztk+u66tSpky688EKdeeaZOumkk1RQUNDg91dWVmrZsmV6/fXXNWPGDJWUlMhxHLVu3Vq//vWv9aMf/SjuHQIAAAAAAAAAAACAZPHEsvFPfvITbdu2TePHj9eiRYtUUlKiO+64Q6effvohT65KUkFBgU4//XTdfvvtWrdund566y2NHz9e27dv17XXXtvonQCA+rzyyis666yz1K5dO+Xn56tTp066/PLLtXLlykZd3549e/TPf/5T11xzjb75zW+qZcuWatasmVq3bq1Bgwbpnnvukd/vb9R1b9u2Ta1bt5bjOHIcR5deemmjridbJfpYJ/I2OnfuHDmu0XzceuutCRszAABNEb/3s08qjrkkzZs3T5dccom6deum5s2bq2XLlurZs6cuueQSPffcc3V+j+u6+vjjj/Xkk0/qRz/6kU4++WTl5eVFjvHnn3+e0DECAJAsyfp9m+jflalaFzR1qe4Yy2Ofn3/+edTr6bKysqSMFzgkNwbnn3+++9FHH8XyLVFZtWqVe/755yf8egFkt6uuusqVVOdHXl6e++STT8Z8nYWFhfVeZ/ijQ4cO7tKlS2O+7u9973s1rueSSy6J+TqyVTKOdSJvo1OnToecNwd+vPjii3GPFwCAporf+9knFcf866+/di+44IIGj1WnTp3q/N6SkpIGv6+kpCTu8QEAkGzJ/H2byN+VqVgXZAOLjrE89nmoOXPgx7Zt2xI+ViAaMT2D9ZlnnlHPnj1j+Zao9OrVS88880zCrxdA9poyZYoeffRRSdLZZ5+tZcuWaevWrXr99dfVu3dv7d27V5dffrn++9//xnS9FRUVys3N1YQJEzRjxgytXbtWO3bs0KpVq/SLX/xCOTk52rhxo0aMGKFNmzZFfb1z587V008/ra5du8Y0HiTvWCfyNlavXq2KiooGP/r06SNJatmypUaPHt3osQIA0JTxez/7pOKYBwIBnX322XrmmWfUrFkzXX/99Xr77be1detWbdmyRQsWLNCkSZPUvn37Q15Xhw4d9N3vflcDBw5s9HgAAEi1VPy+DYvnd2Uqx9mUWXSM57HP2bNnN7i2bt26dcLGCcTE+gwvACTa1q1b3RYtWriS3OHDh7uhUKjG18vKytwjjzzSleT269cvpuu+5ppr3M2bN9f79aeffjry11NXX311VNe5e/dut2vXrq4kd/
bs2TyDNQbJPNapvI3Vq1dHjvtVV13VqOsAAKCp4/d+9knF8XBd173zzjtdSW5+fr47b968mL/f7/e7//znP2v8f8JvfvMbnsEKAMgIqfh9m4jflalaFzR1Fh0b89jngc9gbcz6DEiFmJ7BCgCZ4Mknn9SuXbskSXfeeaccx6nx9VatWunGG2+UJC1dulTvv/9+1Nf90EMP6aijjqr36xdeeKGOP/54SdJrr70W1XXeeuut+uyzzzR+/HideeaZUY8FyT3WqbyNp556KvL5JZdcEvP3AwCQDfi9n31ScTy++uor/fa3v5Uk3XLLLRo8eHDM11FYWKixY8c2+P8JAACkq1T8vk3E78pUjDMbWHTksU80VZxgBRC1Z599Vo7jyOv16uuvv9bnn3+uG264Qccee6zy8/PlOI42bNhgPUzNmjVLktStWzf17du3zm0mTJgQ+fzll19O6O336tVLkvTFF18cctsVK1bo3nvvVWFhoaZOnZrQcUTjlVdeibwh/Pbt27VgwQJdeOGF6ty5swoKCuTz+TRs2DAtWLAg5WOLRiqOdbJvw3VdPf3005KkY489VqeeemrMYwQAZLZMWWNZ4/d+7FjrHdrf//53VVZWKjc3Vz/60Y8aN1AAAOqQKWs868fRopUu42R9FRvrxz6BZMqJZeOFCxcm9MbPOOOMhF4fgORavny5pOpfwC+//LKuuOIK7d69O/L1li1bqmPHjlbDiwj/ZVVDD1h16NBBRx99tDZt2qRly5Yl9Pa//PJLSZLP52twu1AopCuvvFKBQEC33357VO/plGgffPCBJKljx4666667dPfdd9f4+p49e/Tvf/9b8+bN04svvqizzz475WNsSCqOdbJvY968eZH/obr44otjHh8AIPNlyhrLGr/3Y8da79Bmz54tSfrmN7+pli1bRi4PBoNyHEceD3+XDgBonExZ41k/jhatdBkn66voJfqxz3379ik3Nzeu6wASKaYTrIMHD671lPHGchxHgUAgIdcFIDVWrFghSSovL9dFF12kPn366Fe/+pX69++vqqoqrVmzps7vCwaDqqysjOu2CwoK5PV6D7ndpk2bIi9zcag3Te/SpYs2bdpU77gb48svv9Rbb70lSRowYECD295///1699131bdvX/34xz9O2BhiEV4Ubt68WXfffbdOOOEE/frXv9aAAQO0c+dOvfDCC7r99tsVCAT0wx/+UGeeeaby8vIavM5UHe9UHOtU3Eb4ZQIdx+EEKwBkqUxYY1nj937jJHqt1xTX9e+9954kqWfPntq3b5/uvfdeTZs2TWvXrpXruurcubNGjRqlG2+8UR06dIj5+gEA2SsT1njWj6NFK53GyWNp0UvUY58//vGPtX79eu3atUt5eXk69thjNWLECF177bWsz2AqphOsYa7rJnocADJAeGG4detWnX322Xr++efVrFmzyNfr+4W2aNEiDRkyJK7bnjdvXlTvh1RWVhb5vG3btg1uG/769u3b4xrbgW6++WZVVVVJkq6++up6tystLdXkyZPl8Xj06KOPmj2wGV4UBgIBjRo1Si+99FJk0deuXTv9+te/jvyV2datW/Xvf/9bo0aNavA6U3W8U3Gsk30bX3/9tV566SVJ0qBBg9SpU6eYxgcAaBoyYY1ljd/7jZPotV5TW9dXVlZGbic3N1dnnHGGli5dWmObzz77TA888ICmT5+umTNnxr3/AIDskQlrPOvH0aKVTuPksbToJPKxz48++ijy+d69e7Vy5UqtXLlSDz/8sP7617/q/PPPb/R1A/Fo1AlWn8+nc889V9/97nfVvHnzRI8JQBr64osvtHXrVklS+/bt9be//a3GojBdHPhSK/n5+Q1uW1BQIEmRv9yK19NPP60nnnhCkjRmzBh95zvfqXfbH/3oR9q1a5euueYanXzyyQm5/Vjt3r1bn376qaTqRf2zzz5b51/UXXXVVbr99tslSatWrTrkojBVUnGsk30bM2fOjGw/ce
LEmMYGAGgaMmWNZY3f+7FjrXdo5eXlkc//+te/at++fRo7dqxuvfVW9ejRQzt27NCMGTN0yy23aOfOnTrnnHP04Yc
"text/plain": [
"<Figure size 1900x510 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fontsize = 19\n",
"# shared y-axis range for all three panels\n",
"t_ylim = 0.4\n",
"b_ylim = -0.12\n",
"\n",
"# One panel per ToM feature, left to right:\n",
"# (x values, y values, r, p, x-axis label, marker colour; None = default colour)\n",
"panels = [\n",
"    (tom6_common_f1, exp3_common_f1_diff_7_0, tom6correlation_coefficient, tom6p_value, 'F1(Task Status)', 'purple'),\n",
"    (tom7_common_f1, exp3_common_f1_diff_7_0, tom7correlation_coefficient, tom7p_value, 'F1(Task Knowledge)', 'green'),\n",
"    (tom8_common_f1, exp3_common_f1_diff_7_0, tom8correlation_coefficient, tom8p_value, 'F1(Task Intention)', None),\n",
"]\n",
"\n",
"fig, axes = plt.subplots(1, 3, figsize=(19, 5.1))\n",
"for ax, (x, y, r, p, xlabel, color) in zip(axes, panels):\n",
"    ax.scatter(x, y, c=color)\n",
"    ax.set_title(f'$r={r:.2f}$, $p={p:.2f}$', fontsize=fontsize)\n",
"    # degree-1 least-squares fit, drawn over the distinct x values\n",
"    x_fit = np.unique(x)\n",
"    ax.plot(x_fit, np.poly1d(np.polyfit(x, y, 1))(x_fit), color='red', linestyle='--')\n",
"    ax.set_xlabel(xlabel, fontsize=fontsize)\n",
"    ax.set_ylim(b_ylim, t_ylim)\n",
"    ax.grid(True, linestyle='--')\n",
"    ax.tick_params(axis='both', labelsize=fontsize)\n",
"\n",
"# Only the leftmost panel carries the y-axis label. The plotted value is\n",
"# exp3_7 - exp3_0; the label assumes exp3_0 is the run without ToM (TODO confirm).\n",
"axes[0].set_ylabel('F1(PMK w/ ToM) $-$ F1(PMK w/o ToM)', fontsize=fontsize)\n",
"\n",
"plt.tight_layout()\n",
"\n",
"plt.savefig('correlation_tom_feats_exp3.pdf', bbox_inches='tight')\n",
"\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mindcraft_",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}