evaluation code
parent fc7973a49b
commit 3d3cebb956
6 changed files with 660 additions and 13 deletions
101 06_baselines.py Normal file
@@ -0,0 +1,101 @@
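# Compares the F1 scores of our classifier (loaded from f1s.csv) against three
# baselines -- most frequent class per training set, uniform random guessing, and a
# label permutation test -- and writes the summary table (figure1.csv) and the bar
# plot (figure1.pdf) for Figure 1.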
import numpy as np
import matplotlib.pyplot as plt
from config import conf
import os, sys
import pandas as pns
from config import names as gs
import getopt
import matplotlib.gridspec as gridspec
from sklearn.metrics import f1_score, accuracy_score

import seaborn as sns
sns.set(style='whitegrid', color_codes=True)
sns.set_context('poster')

dark_color = sns.xkcd_rgb['charcoal grey']
light_color = sns.xkcd_rgb['cloudy blue']

max_n_feat = conf.max_n_feat
m_iter = conf.max_n_iter

featurelabels = gs.full_long_label_list
participant_ids = np.arange(0, conf.n_participants)


def plot_overview():
    # uses the module-level DataFrame all_baselines that is assembled in __main__ below
    all_baselines.groupby(by=['trait', 'clf_name'])['F1'].mean().to_csv(conf.figure_folder +
                                                                        '/figure1.csv')
    print 'Figure1.csv written'

    sns.set(font_scale=2.1)
    plt.figure(figsize=(20, 10))
    ax = plt.subplot(1, 1, 1)
    sns.barplot(x='trait', y='F1', hue='clf_name', data=all_baselines, capsize=.05, errwidth=3,
                linewidth=3, estimator=np.mean, edgecolor=dark_color,
                palette={'our classifier': sns.xkcd_rgb['windows blue'],
                         'most frequent class': sns.xkcd_rgb['faded green'],
                         'random guess': sns.xkcd_rgb['greyish brown'],
                         'label permutation': sns.xkcd_rgb['dusky pink']
                         }
                )
    plt.plot([-0.5, 6.5], [0.33, 0.33], c=dark_color, linestyle='--', linewidth=3, label='theoretical chance level')
    handles, labels = ax.get_legend_handles_labels()
    ax.legend([handles[1], handles[2], handles[3], handles[4], handles[0]],
              [labels[1], labels[2], labels[3], labels[4], labels[0]], fontsize=20)
    plt.xlabel('')
    plt.ylabel('F1 score', fontsize=20)
    plt.ylim((0, 0.55))
    filename = conf.figure_folder + '/figure1.pdf'
    plt.savefig(filename, bbox_inches='tight')
    plt.close()
    print 'wrote', filename.split('/')[-1]


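# Note: the exact layout of f1s.csv and most_frequ_class_baseline.csv is defined by
# evaluation_single_context.py and train_baseline.py; the plotting code above only
# relies on them providing 'F1' and 'trait' columns, and 'clf_name' is added here.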
if __name__ == "__main__":
    # collect F1 scores for classifiers on all data from a file that was written by evaluation_single_context.py
    datapath = conf.get_result_folder(conf.annotation_all) + '/f1s.csv'
    if not os.path.exists(datapath):
        print 'could not find', datapath
        print 'consider (re-)running evaluation_single_context.py'
        sys.exit(1)
    our_classifier = pns.read_csv(datapath)
    our_classifier['clf_name'] = 'our classifier'

    # baseline 1: guess the most frequent class from each training set that was written by train_baseline.py
    datapath = conf.result_folder + '/most_frequ_class_baseline.csv'
    if not os.path.exists(datapath):
        print 'could not find', datapath
        print 'consider (re-)running train_baseline.py'
        sys.exit(1)
    most_frequent_class_df = pns.read_csv(datapath)
    most_frequent_class_df['clf_name'] = 'most frequent class'

    # compute all other baselines ad hoc
    collection = []
    for trait in xrange(0, conf.n_traits):
        # baseline 2: random guess
        truth = np.genfromtxt(conf.binned_personality_file, skip_header=1, usecols=(trait + 1,), delimiter=',')
        for i in xrange(0, 100):
            rand_guess = np.random.randint(1, 4, conf.n_participants)
            f1 = f1_score(truth, rand_guess, average='macro')
            collection.append([f1, conf.medium_traitlabels[trait], i, 'random guess'])
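        # random guesses are drawn from the three classes 1-3, matching the dashed
        # 'theoretical chance level' line at 0.33 that plot_overview() draws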

        # baseline 3: label permutation test
        # was computed using label_permutation_test.sh and written into results, i.e. it is just loaded here
        for si in xrange(0, m_iter):
            filename_rand = conf.get_result_filename(conf.annotation_all, trait, True, si, add_suffix=True)
            if os.path.exists(filename_rand):
                data = np.load(filename_rand)
                pr = data['predictions']
                # entries of 0 presumably mark participants without a prediction; drop them
                dt = truth[pr > 0]
                pr = pr[pr > 0]
                f1 = f1_score(dt, pr, average='macro')
                collection.append([f1, conf.medium_traitlabels[trait], si, 'label permutation'])
            else:
                print 'did not find', filename_rand
                print 'consider (re-)running label_permutation_test.sh'
                sys.exit(1)

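    # each baseline contributes one row per trait and iteration, so the error bars drawn
    # for the baselines in plot_overview() reflect the variation across iterations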
    collectiondf = pns.DataFrame(data=collection, columns=['F1', 'trait', 'iteration', 'clf_name'])
    all_baselines = pns.concat([our_classifier, most_frequent_class_df, collectiondf])

    plot_overview()  # Figure 1