383 lines
19 KiB
Python
383 lines
19 KiB
Python
import argparse
|
|
import numpy as np
|
|
import pickle
|
|
import os
|
|
from sklearn.linear_model import LogisticRegression
|
|
from sklearn.ensemble import RandomForestClassifier
|
|
from sklearn.metrics import accuracy_score, classification_report, f1_score
|
|
from src.data.game_parser_graphs_new import GameParser, make_splits, onehot, set_seed
|
|
from tqdm import tqdm
|
|
from scipy.stats import wilcoxon
|
|
from sklearn.exceptions import ConvergenceWarning
|
|
import matplotlib.pyplot as plt
|
|
from sklearn.manifold import TSNE
|
|
from sklearn.preprocessing import StandardScaler
|
|
from sklearn.decomposition import PCA
|
|
import umap
|
|
import warnings
|
|
# Silence scikit-learn's ConvergenceWarning for the whole process.
# NOTE(review): presumably needed because the probes in this script are fitted
# with deliberately small max_iter values (e.g. 6, 9, 11) — confirm.
warnings.simplefilter("ignore", category=ConvergenceWarning)
|
|
|
|
|
|
def parse_q(q, game):
    """Encode one question entry as a flat vector and return it with its label.

    Parameters:
    - q: either None, or a pair ``(question, label)``. Fields 2 and 3 of
      ``question`` are one-hot encoded with 2 classes; fields 4 and 5 each hold
      two ``(flag, material_name)`` pairs followed by one trailing flag.
    - game: object exposing ``materials_dict`` (material name -> index).

    Returns:
    - (vector, label). When ``q`` is None the vector is ``np.zeros(100)`` and
      the label is None.
      NOTE(review): 100 presumably equals the encoded length for the material
      vocabulary in use — confirm against ``onehot`` / ``materials_dict``.
    """
    if q is None:
        return np.zeros(100), None

    question, label = q
    n_materials = len(game.materials_dict)

    pieces = [onehot(question[2], 2), onehot(question[3], 2)]
    # Fields 4 and 5 share the same layout, so encode them with one loop.
    for slot in (question[4], question[5]):
        for flag, material in (slot[0], slot[1]):
            # Flags are shifted by +1 before being one-hot encoded.
            pieces.append(onehot(flag + 1, 2))
            pieces.append(onehot(game.materials_dict[material], n_materials))
        pieces.append(onehot(slot[2] + 1, 2))

    return np.concatenate(pieces), label
|
|
|
|
def cosine_similarity(array1, array2):
    """
    Compute the cosine similarity between two arrays.

    Parameters:
    - array1: First input array
    - array2: Second input array

    Returns:
    - similarity: Cosine similarity between the two arrays. If either array
      has zero norm the similarity is undefined; 0.0 is returned in that case
      instead of NaN (which previously also emitted a RuntimeWarning and would
      poison any later max/mean over the similarities).
    """
    norm_product = np.linalg.norm(array1) * np.linalg.norm(array2)
    if norm_product == 0:
        # Zero vector: avoid 0/0 and report "no similarity".
        return 0.0
    return np.dot(array1, array2) / norm_product
|
|
|
|
def compute_and_plot_pca(data1, data2, labels=None, fname='pca'):
    """
    Compute and plot a 2-component PCA for two datasets.

    Each dataset is standardized and projected independently (its own
    StandardScaler and its own PCA); the two projections are then stacked,
    data1 rows first, and drawn in a single scatter plot saved to
    ``figures/{fname}.pdf``.

    Parameters:
    - data1: First input dataset (rows are samples)
    - data2: Second input dataset (rows are samples)
    - labels: Optional labels, one per row of the stacked result
      (data1 rows followed by data2 rows). Any sequence is accepted.
    - fname: Base name of the output PDF inside ``figures/``

    Returns:
    - pca_result: Stacked PCA projections of data1 and data2
    """
    projections = []
    for data in (data1, data2):
        standardized = StandardScaler().fit_transform(data)
        projections.append(PCA(n_components=2).fit_transform(standardized))
    pca_result = np.concatenate(projections)

    if labels is not None:
        # A plain list compared with `==` yields a single bool, not a mask,
        # which silently plotted nothing; force elementwise comparison.
        labels = np.asarray(labels)
    unique_labels = np.unique(labels) if labels is not None else [None]

    fig = plt.figure(figsize=(8, 6))
    for unique_label in unique_labels:
        mask = (labels == unique_label) if labels is not None else slice(None)
        plt.scatter(pca_result[mask, 0], pca_result[mask, 1], label=unique_label)
    plt.xlabel('Principal Component 1')
    plt.ylabel('Principal Component 2')
    if labels is not None:
        plt.legend()
    os.makedirs("figures/", exist_ok=True)
    plt.savefig(f"figures/{fname}.pdf", bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close(fig)
    return pca_result
|
|
|
|
def compute_and_plot_tsne(data1, data2, labels=None, fname='tsne'):
    """
    Compute and plot a 2-component t-SNE embedding for two datasets.

    Each dataset is standardized and embedded independently (its own
    StandardScaler and its own TSNE); the two embeddings are then stacked,
    data1 rows first, and drawn in a single scatter plot saved to
    ``figures/{fname}.pdf``.

    Parameters:
    - data1: First input dataset (rows are samples)
    - data2: Second input dataset (rows are samples)
    - labels: Optional labels, one per row of the stacked result
      (data1 rows followed by data2 rows). Any sequence is accepted.
    - fname: Base name of the output PDF inside ``figures/``

    Returns:
    - tsne_result: Stacked t-SNE embeddings of data1 and data2
    """
    embeddings = []
    for data in (data1, data2):
        standardized = StandardScaler().fit_transform(data)
        embeddings.append(TSNE(n_components=2).fit_transform(standardized))
    tsne_result = np.concatenate(embeddings)

    if labels is not None:
        # A plain list compared with `==` yields a single bool, not a mask,
        # which silently plotted nothing; force elementwise comparison.
        labels = np.asarray(labels)
    unique_labels = np.unique(labels) if labels is not None else [None]

    fig = plt.figure(figsize=(8, 6))
    for unique_label in unique_labels:
        mask = (labels == unique_label) if labels is not None else slice(None)
        plt.scatter(tsne_result[mask, 0], tsne_result[mask, 1], label=unique_label)
    plt.xlabel('t-SNE Component 1')
    plt.ylabel('t-SNE Component 2')
    if labels is not None:
        plt.legend()
    # The output directory was only created by the PCA helper; create it here
    # too so this function works when called on its own.
    os.makedirs("figures/", exist_ok=True)
    plt.savefig(f"figures/{fname}.pdf", bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close(fig)
    return tsne_result
|
|
|
|
def compute_and_plot_umap(data1, data2, labels=None, fname='umap'):
    """
    Compute and plot a 2-component UMAP embedding for two datasets.

    Each dataset is standardized and embedded independently (its own
    StandardScaler and its own UMAP model); the two embeddings are then
    stacked, data1 rows first, and drawn in a single scatter plot saved to
    ``figures/{fname}.pdf``.

    Parameters:
    - data1: First input dataset (rows are samples)
    - data2: Second input dataset (rows are samples)
    - labels: Optional labels, one per row of the stacked result
      (data1 rows followed by data2 rows). Any sequence is accepted.
    - fname: Base name of the output PDF inside ``figures/``

    Returns:
    - umap_result: Stacked UMAP embeddings of data1 and data2
    """
    embeddings = []
    for data in (data1, data2):
        standardized = StandardScaler().fit_transform(data)
        embeddings.append(umap.UMAP(n_components=2).fit_transform(standardized))
    umap_result = np.concatenate(embeddings)

    if labels is not None:
        # A plain list compared with `==` yields a single bool, not a mask,
        # which silently plotted nothing; force elementwise comparison.
        labels = np.asarray(labels)
    unique_labels = np.unique(labels) if labels is not None else [None]

    fig = plt.figure(figsize=(8, 6))
    for unique_label in unique_labels:
        mask = (labels == unique_label) if labels is not None else slice(None)
        plt.scatter(umap_result[mask, 0], umap_result[mask, 1], label=unique_label)
    plt.xlabel('UMAP Component 1')
    plt.ylabel('UMAP Component 2')
    if labels is not None:
        plt.legend()
    # The output directory was only created by the PCA helper; create it here
    # too so this function works when called on its own.
    os.makedirs("figures/", exist_ok=True)
    plt.savefig(f"figures/{fname}.pdf", bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close(fig)
    return umap_result
|
|
|
|
def prepare_data_tom(mode):
    """Collect ToM6/ToM7/ToM8 representations and labels for one data split.

    Every game file of the split is parsed twice (third GameParser argument 1,
    then 2). For each game, only the steps whose question carries a label are
    kept.

    Parameters:
    - mode: 'train' (uses the 'training' split) or 'test'

    Returns:
    - (tom6_x, tom6_labels, tom7_x, tom7_labels, tom8_x, tom8_labels) where
      each ``*_x`` is the row-wise concatenation of the selected intermediate
      slices and each ``*_labels`` is a flat list of integer class ids.

    Raises:
    - ValueError: if ``mode`` is neither 'train' nor 'test'
    """
    dataset_splits = make_splits('config/dataset_splits_new.json')
    if mode == 'train':
        files = dataset_splits['training']
    elif mode == 'test':
        files = dataset_splits['test']
    else:
        raise ValueError('train or test are supported')

    # One full pass over the files per value of the third GameParser argument.
    # NOTE(review): 1/2 presumably selects the player perspective — confirm.
    data = [GameParser(f, False, which, 7, False) for which in (1, 2) for f in tqdm(files)]

    answer_names = ['NO', 'MAYBE', 'YES']
    # Column boundaries of the five 1024-wide chunks in the intermediate vector.
    boundaries = [1024 * k for k in range(1, 5)]

    tom6repr, tom7repr, tom8repr = [], [], []
    tom6labels, tom7labels, tom8labels = [], [], []
    for game in data:
        _, _, _, questions, _, _, interm, _ = zip(*game)
        # intermediate = np.concatenate([ToM6,ToM7,ToM8,DAct,DMove])
        tom6, tom7, tom8, _, _ = np.split(np.array(interm), boundaries, axis=1)

        step_labels = [parse_q(entry, game)[1] for entry in questions]
        keep = [pos for pos, lab in enumerate(step_labels) if lab is not None]
        tom6repr.append(tom6[keep])
        tom7repr.append(tom7[keep])
        tom8repr.append(tom8[keep])

        # Element 1 of each label holds the (ToM6, ToM7, ToM8) answers.
        answers = [lab[1] for lab in step_labels if lab is not None]
        tom6labels.extend(answer_names.index(ans[0]) for ans in answers)
        tom7labels.extend(answer_names.index(ans[1]) for ans in answers)
        # Materials missing from the game's dictionary fall back to class 0.
        tom8labels.extend(
            game.materials_dict[ans[2]] if ans[2] in game.materials_dict else 0
            for ans in answers
        )

    return (
        np.concatenate(tom6repr), tom6labels,
        np.concatenate(tom7repr), tom7labels,
        np.concatenate(tom8repr), tom8labels,
    )
|
|
|
|
def prepare_data_cpa(mode, experiment):
    """Collect pre-computed CPA features and ToM6/7/8 labels for one split.

    Every game file of the split is parsed twice (third GameParser argument 1,
    then 2) to recover which steps carry a labelled question; the matching
    rows are then selected from a pickled list of ``(features, game_name)``
    entries that must be in the same order as the parsed games.

    Parameters:
    - mode: 'train' (uses the 'training' split) or 'test'
    - experiment: 2 or 3; selects which pickled feature file is loaded

    Returns:
    - (selected_feats, tom6_labels, tom7_labels, tom8_labels): the row-wise
      concatenation of the selected features and three flat lists of integer
      class ids.

    Raises:
    - ValueError: if ``mode`` or ``experiment`` is not supported. Both are
      checked before any game is parsed, so a typo fails immediately.
    - AssertionError: if a feature entry does not belong to the game at the
      same position.
    """
    split_names = {"train": "training", "test": "test"}
    if mode not in split_names:
        raise ValueError('train or test are supported')
    if experiment not in (2, 3):
        raise ValueError(f"unsupported experiment {experiment!r}; expected 2 or 3")

    # Feature files per (mode, experiment). The paths are placeholders in this
    # copy of the script and must be filled in before running.
    feature_paths = {
        ("train", 2): 'XXX',
        ("train", 3): 'XXX',
        ("test", 2): 'XXX',
        ("test", 3): 'XXX',
    }

    dataset_splits = make_splits('config/dataset_splits_new.json')
    d_flag = False
    d_move_flag = False
    files = dataset_splits[split_names[mode]]
    data = [GameParser(f, d_flag, 1, 7, d_move_flag) for f in tqdm(files)]
    data += [GameParser(f, d_flag, 2, 7, d_move_flag) for f in tqdm(files)]

    # SECURITY: pickle.load executes arbitrary code from the file; only point
    # this at feature dumps you produced yourself.
    with open(feature_paths[(mode, experiment)], 'rb') as f:
        feats = pickle.load(f)

    features = [item[0] for item in feats]
    game_names = [item[1] for item in feats]

    answer_names = ['NO', 'MAYBE', 'YES']
    tom6labels = []
    tom7labels = []
    tom8labels = []
    selected_feats = []
    for i, game in enumerate(data):
        _, _, _, q, _, _, _, _ = zip(*list(game))
        step_labels = [parse_q(x, game)[1] for x in q]
        indexes = [idx for idx, element in enumerate(step_labels) if element is not None]

        # Features are matched to games purely by position; guard the pairing.
        assert game.game_path.split("/")[-1] == game_names[i].split("/")[-1], (
            f"feature entry {i} ({game_names[i]}) does not match game {game.game_path}"
        )
        selected_feats.append(features[i][indexes])

        # Element 1 of each label holds the (ToM6, ToM7, ToM8) answers.
        answers = [item[1] for item in step_labels if item is not None]
        tom6labels.extend(answer_names.index(item[0]) for item in answers)
        tom7labels.extend(answer_names.index(item[1]) for item in answers)
        # Materials missing from the game's dictionary fall back to class 0.
        tom8labels.extend(
            game.materials_dict[item[2]] if item[2] in game.materials_dict else 0
            for item in answers
        )

    selected_feats = np.concatenate(selected_feats)
    return selected_feats, tom6labels, tom7labels, tom8labels
|
|
|
|
def fit_and_test_LR(X_train, y_train, X_test, y_test, max_iter=100):
    """Fit a logistic-regression probe and report its weighted F1 on the test set.

    Parameters:
    - X_train, y_train: Training features and labels
    - X_test, y_test: Held-out features and labels used for scoring
    - max_iter: Iteration cap passed to LogisticRegression

    Returns:
    - The fitted LogisticRegression model. The weighted F1 score is printed,
      not returned.
    """
    classifier = LogisticRegression(max_iter=max_iter)
    classifier.fit(X_train, y_train)
    weighted_f1 = f1_score(
        y_test,
        classifier.predict(X_test),
        average="weighted",
        zero_division=1,
    )
    print("F1 score:", weighted_f1)
    return classifier
|
|
|
|
def fit_and_test_RF(X_train, y_train, X_test, y_test, n_estimators):
    """Fit a random-forest probe and report its weighted F1 on the test set.

    Parameters:
    - X_train, y_train: Training features and labels
    - X_test, y_test: Held-out features and labels used for scoring
    - n_estimators: Number of trees passed to RandomForestClassifier

    Returns:
    - The fitted RandomForestClassifier. The weighted F1 score is printed,
      not returned.
    """
    forest = RandomForestClassifier(n_estimators=n_estimators)
    forest.fit(X_train, y_train)
    predictions = forest.predict(X_test)
    score = f1_score(y_test, predictions, average="weighted", zero_division=1)
    print("F1 score:", score)
    return forest
|
|
|
|
def wilcoxon_test(model1, model2, X_test_1, X_test_2):
    """Run a Wilcoxon signed-rank test on two models' paired predictions.

    For each model, the predicted probability in column 1 of
    ``predict_proba`` is taken on its own test matrix; the test is applied to
    the per-sample differences, so both matrices must yield the same number of
    rows.

    Parameters:
    - model1, model2: Fitted models exposing ``predict_proba``
    - X_test_1, X_test_2: Test inputs for model1 and model2 respectively

    Returns:
    - The p-value of the test (also printed).
    """
    first = model1.predict_proba(X_test_1)
    second = model2.predict_proba(X_test_2)
    # Only the probability assigned to the class at index 1 is compared.
    paired_gap = first[:, 1] - second[:, 1]
    p_value = wilcoxon(paired_gap).pvalue
    print("Wilcoxon signed-rank test p-value:", p_value)
    return p_value
|
|
|
|
|
|
# max_iter for the (ToM6, ToM7, ToM8) probes trained on ToM representations.
_TOM_MAX_ITERS = (100, 100, 6)
# max_iter for the (ToM6, ToM7, ToM8) probes trained on random features.
_RANDOM_MAX_ITERS = (100, 100, 100)

# Weighted-F1 values recorded in the original script's comments, in the order
# (ToM6, ToM7, ToM8); the CPA entries were annotated with their max_iter.
#   tom:    0.6056079527261083, 0.5090737845776365, 0.10206891928130866
#   cpa:    0.5157497361676466 (139), 0.49755418256915795 (307),
#           0.14099639490943838 (23)
#   random: 0.4573518645097593, 0.45066310491597705, 0.09281225255303022


def _cpa_max_iters(experiment):
    """Return the (ToM6, ToM7, ToM8) max_iter values for the CPA probes."""
    return (139, 307, 23) if experiment == 2 else (11, 25, 9)


def _load_tom(mode):
    """Load the ToM split as ([x6, x7, x8], [y6, y7, y8])."""
    x6, y6, x7, y7, x8, y8 = prepare_data_tom(mode)
    return [x6, x7, x8], [y6, y7, y8]


def _load_cpa(mode, experiment):
    """Load the CPA split as ([x, x, x], [y6, y7, y8]); all probes share x."""
    x, y6, y7, y8 = prepare_data_cpa(mode, experiment)
    return [x, x, x], [y6, y7, y8]


def _random_like(arrays):
    """Return Gaussian noise (scaled by 0.1) shaped like each array, in order."""
    return [np.random.randn(*a.shape) * 0.1 for a in arrays]


def _fit_probes(train, test, max_iters):
    """Fit one logistic-regression probe per ToM question and return the models."""
    train_xs, train_ys = train
    test_xs, test_ys = test
    models = []
    for exp_id, x_tr, y_tr, x_te, y_te, n_iter in zip(
        (6, 7, 8), train_xs, train_ys, test_xs, test_ys, max_iters
    ):
        print(f"=========== EXP {exp_id} ========================================")
        models.append(fit_and_test_LR(x_tr, y_tr, x_te, y_te, n_iter))
    return models


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str)
    parser.add_argument("--seed", type=int)
    parser.add_argument("--experiment", type=int)
    # The script used to call breakpoint() unconditionally after every task,
    # which blocks non-interactive runs; it is now opt-in.
    parser.add_argument(
        "--debug",
        action="store_true",
        help="drop into the debugger once the selected task has finished",
    )
    args = parser.parse_args()
    set_seed(args.seed)
    task = args.task
    experiment = args.experiment

    if task == "tom":
        tom_train = _load_tom("train")
        tom_test = _load_tom("test")
        _fit_probes(tom_train, tom_test, _TOM_MAX_ITERS)

    elif task == "cpa":
        cpa_train = _load_cpa("train", experiment)
        cpa_test = _load_cpa("test", experiment)
        _fit_probes(cpa_train, cpa_test, _cpa_max_iters(experiment))

    elif task == "random":
        tom_train_xs, tom_train_ys = _load_tom("train")
        tom_test_xs, tom_test_ys = _load_tom("test")
        # Draw order (train 6, 7, 8 then test 6, 7, 8) is kept so seeded runs
        # reproduce the same noise as before.
        rand_train_xs = _random_like(tom_train_xs)
        rand_test_xs = _random_like(tom_test_xs)
        _fit_probes(
            (rand_train_xs, tom_train_ys),
            (rand_test_xs, tom_test_ys),
            _RANDOM_MAX_ITERS,
        )

    elif task == "all":
        ############## TOM
        print("############## TOM")
        tom_train_xs, tom_train_ys = _load_tom("train")
        tom_test_xs, tom_test_ys = _load_tom("test")
        tom_models = _fit_probes(
            (tom_train_xs, tom_train_ys), (tom_test_xs, tom_test_ys), _TOM_MAX_ITERS
        )

        ############## CPA
        print("############## CPA")
        cpa_train = _load_cpa("train", experiment)
        cpa_test = _load_cpa("test", experiment)
        # Use the same experiment-dependent iteration caps as the "cpa" task;
        # previously the experiment-2 values were hard-coded here.
        cpa_models = _fit_probes(cpa_train, cpa_test, _cpa_max_iters(experiment))
        test_x_cpa = cpa_test[0][0]

        ############## RANDOM
        print("############## RANDOM")
        rand_train_xs = _random_like(tom_train_xs)
        rand_test_xs = _random_like(tom_test_xs)
        rand_models = _fit_probes(
            (rand_train_xs, tom_train_ys),
            (rand_test_xs, tom_test_ys),
            _RANDOM_MAX_ITERS,
        )

        # Pairwise significance tests per ToM question: ToM vs CPA,
        # random vs CPA, random vs ToM.
        for k in range(3):
            wilcoxon_test(tom_models[k], cpa_models[k], tom_test_xs[k], test_x_cpa)
            wilcoxon_test(rand_models[k], cpa_models[k], rand_test_xs[k], test_x_cpa)
            wilcoxon_test(rand_models[k], tom_models[k], rand_test_xs[k], tom_test_xs[k])

        # Row-wise cosine similarity between standardized ToM and CPA features.
        scaled_test_x_cpa = StandardScaler().fit_transform(test_x_cpa)
        for name, tom_x in zip(("tom6", "tom7", "tom8"), tom_test_xs):
            scaled_tom_x = StandardScaler().fit_transform(tom_x)
            sims = [cosine_similarity(t, c) for t, c in zip(scaled_tom_x, scaled_test_x_cpa)]
            print(f"[{name}] max sim: {np.max(sims)}, mean sim: {np.mean(sims)}")

    else:
        raise ValueError(f"unknown task {task!r}; expected one of: tom, cpa, random, all")

    if args.debug:
        breakpoint()
|