initial commit
commit 7be61f8c6d
137 changed files with 33491 additions and 0 deletions

@@ -0,0 +1,89 @@
# coding: utf-8

import sys
dataDir = '../../VQA'
sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir))
from vqa import VQA
from vqaEvaluation.vqaEval import VQAEval
import matplotlib.pyplot as plt
import skimage.io as io
import json
import random
import os

# set up file names and paths
versionType ='v2_' # this should be '' when using the VQA v1.0 dataset
taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
dataType ='mscoco' # 'mscoco' only for v2.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0
dataSubType ='train2014'
annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
imgDir ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
resultType ='fake'
fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']

# An example result json file has been provided in './Results' folder.

[resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \
resultType, fileType) for fileType in fileTypes]
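# With the defaults above, these resolve to paths like, for example:
#   resFile      -> '../../VQA/Results/v2_OpenEnded_mscoco_train2014_fake_results.json'
#   accuracyFile -> '../../VQA/Results/v2_OpenEnded_mscoco_train2014_fake_accuracy.json'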

# create vqa object and vqaRes object
vqa = VQA(annFile, quesFile)
vqaRes = vqa.loadRes(resFile, quesFile)
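# Note: the result file loaded here is expected to be a JSON list of entries
# of the form {"question_id": int, "answer": str}, one per question.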

# create vqaEval object by taking vqa and vqaRes
vqaEval = VQAEval(vqa, vqaRes, n=2) # n is precision of accuracy (number of places after decimal), default is 2

# evaluate results
"""
If you have a list of question ids on which you would like to evaluate your results, pass it as a list to the function below.
By default it uses all the question ids in the annotation file.
"""
vqaEval.evaluate()
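# For example, to evaluate on a subset of questions only (a sketch; getQuesIds()
# comes from the VQA helper used above):
#   quesIdSubset = vqa.getQuesIds()[:1000]
#   vqaEval.evaluate(quesIds=quesIdSubset)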

# print accuracies
print "\n"
print "Overall Accuracy is: %.02f\n" %(vqaEval.accuracy['overall'])
print "Per Question Type Accuracy is the following:"
for quesType in vqaEval.accuracy['perQuestionType']:
    print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType])
print "\n"
print "Per Answer Type Accuracy is the following:"
for ansType in vqaEval.accuracy['perAnswerType']:
    print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType])
print "\n"

# demo how to use evalQA to retrieve low-score results
evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId] < 35] # 35 is per-question percentage accuracy
if len(evals) > 0:
    print 'ground truth answers'
    randomEval = random.choice(evals)
    randomAnn = vqa.loadQA(randomEval)
    vqa.showQA(randomAnn)

    print '\n'
    print 'generated answer (accuracy %.02f)' %(vqaEval.evalQA[randomEval])
    ann = vqaRes.loadQA(randomEval)[0]
    print "Answer: %s\n" %(ann['answer'])

    imgId = randomAnn[0]['image_id']
    imgFilename = 'COCO_' + dataSubType + '_' + str(imgId).zfill(12) + '.jpg'
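    # e.g. image_id 42 in train2014 -> 'COCO_train2014_000000000042.jpg'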
    if os.path.isfile(imgDir + imgFilename):
        I = io.imread(imgDir + imgFilename)
        plt.imshow(I)
        plt.axis('off')
        plt.show()

# plot accuracy for various question types
plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].values(), align='center')
plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].keys(), rotation='0', fontsize=10)
plt.title('Per Question Type Accuracy', fontsize=10)
plt.xlabel('Question Types', fontsize=10)
plt.ylabel('Accuracy', fontsize=10)
plt.show()

# save evaluation results to ./Results folder
json.dump(vqaEval.accuracy, open(accuracyFile, 'w'))
json.dump(vqaEval.evalQA, open(evalQAFile, 'w'))
json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w'))
json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w'))
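# The saved files can be reloaded later, e.g.:
#   accuracy = json.load(open(accuracyFile))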

@@ -0,0 +1 @@
author='aagrawal'

@@ -0,0 +1,192 @@
# coding=utf-8

__author__ = 'aagrawal'

import re
# This code is based on the code written by Tsung-Yi Lin for the MSCOCO Python API, available at:
# https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py
import sys


class VQAEval:
    def __init__(self, vqa, vqaRes, n=2):
        self.n = n
        self.accuracy = {}
        self.evalQA = {}
        self.evalQuesType = {}
        self.evalAnsType = {}
        self.vqa = vqa
        self.vqaRes = vqaRes
        self.params = {'question_id': vqa.getQuesIds()}
        self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't",
            "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't",
            "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've",
            "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've",
            "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
            "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
            "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've",
            "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've",
            "somebodyd": "somebody'd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll",
            "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've",
            "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've",
            "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've",
            "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've",
            "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't",
            "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're",
            "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've",
            "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
            "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've",
            "youll": "you'll", "youre": "you're", "youve": "you've"}
        self.manualMap = {'none': '0',
                          'zero': '0',
                          'one': '1',
                          'two': '2',
                          'three': '3',
                          'four': '4',
                          'five': '5',
                          'six': '6',
                          'seven': '7',
                          'eight': '8',
                          'nine': '9',
                          'ten': '10'
                         }
        self.articles = ['a',
                         'an',
                         'the'
                        ]

        # strip periods that are not decimal points (negative lookbehind and
        # lookahead keep the '.' in e.g. '1.5')
        self.periodStrip = re.compile("(?<!\d)(\.)(?!\d)")
        self.commaStrip = re.compile("(\d)(\,)(\d)")
        self.punct = [';', r"/", '[', ']', '"', '{', '}',
                      '(', ')', '=', '+', '\\', '_', '-',
                      '>', '<', '@', '`', ',', '?', '!']
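
        # Answer strings are compared only after normalization: punctuation is
        # stripped (processPunctuation), number words are mapped to digits and
        # articles dropped (processDigitArticle), and missing apostrophes in
        # common contractions are restored via the table above.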

    def evaluate(self, quesIds=None):
        if quesIds is None:
            quesIds = [quesId for quesId in self.params['question_id']]
        gts = {}
        res = {}
        for quesId in quesIds:
            gts[quesId] = self.vqa.qa[quesId]
            res[quesId] = self.vqaRes.qa[quesId]

        # =================================================
        # Compute accuracy
        # =================================================
        accQA = []
        accQuesType = {}
        accAnsType = {}
        # print "computing accuracy"
        step = 0
        for quesId in quesIds:
            for ansDic in gts[quesId]['answers']:
                ansDic['answer'] = ansDic['answer'].replace('\n', ' ')
                ansDic['answer'] = ansDic['answer'].replace('\t', ' ')
                ansDic['answer'] = ansDic['answer'].strip()
            resAns = res[quesId]['answer']
            resAns = resAns.replace('\n', ' ')
            resAns = resAns.replace('\t', ' ')
            resAns = resAns.strip()
            gtAcc = []
            gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]

            if len(set(gtAnswers)) > 1:
                for ansDic in gts[quesId]['answers']:
                    ansDic['answer'] = self.processPunctuation(ansDic['answer'])
                    ansDic['answer'] = self.processDigitArticle(ansDic['answer'])
                resAns = self.processPunctuation(resAns)
                resAns = self.processDigitArticle(resAns)

            for gtAnsDatum in gts[quesId]['answers']:
                otherGTAns = [item for item in gts[quesId]['answers'] if item != gtAnsDatum]
                matchingAns = [item for item in otherGTAns if item['answer'].lower() == resAns.lower()]
                acc = min(1, float(len(matchingAns))/3)
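                # VQA metric: an answer counts as fully correct when at least
                # 3 of the other 9 annotators gave it, i.e. min(#matches/3, 1);
                # leaving each ground-truth answer out in turn and averaging
                # below averages the score over all ten 9-annotator subsets.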
                gtAcc.append(acc)
            quesType = gts[quesId]['question_type']
            ansType = gts[quesId]['answer_type']
            avgGTAcc = float(sum(gtAcc))/len(gtAcc)
            accQA.append(avgGTAcc)
            if quesType not in accQuesType:
                accQuesType[quesType] = []
            accQuesType[quesType].append(avgGTAcc)
            if ansType not in accAnsType:
                accAnsType[ansType] = []
            accAnsType[ansType].append(avgGTAcc)
            self.setEvalQA(quesId, avgGTAcc)
            self.setEvalQuesType(quesId, quesType, avgGTAcc)
            self.setEvalAnsType(quesId, ansType, avgGTAcc)
            if step % 100 == 0:
                self.updateProgress(step/float(len(quesIds)))
            step = step + 1

        self.setAccuracy(accQA, accQuesType, accAnsType)
        # print "Done computing accuracy"

    def processPunctuation(self, inText):
        outText = inText
        for p in self.punct:
            if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) is not None):
                outText = outText.replace(p, '')
            else:
                outText = outText.replace(p, ' ')
        # note: re.sub's optional third positional argument is a count, not a
        # flags value, so no extra argument is passed here
        outText = self.periodStrip.sub("", outText)
        return outText
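    # e.g. processPunctuation("10,000!") -> "10000": the comma inside the
    # number and the trailing punctuation are both removed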

    def processDigitArticle(self, inText):
        outText = []
        tempText = inText.lower().split()
        for word in tempText:
            # use get() rather than setdefault() so manualMap is not mutated
            word = self.manualMap.get(word, word)
            if word not in self.articles:
                outText.append(word)
        for wordId, word in enumerate(outText):
            if word in self.contractions:
                outText[wordId] = self.contractions[word]
        outText = ' '.join(outText)
        return outText
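    # e.g. processDigitArticle("the two dogs") -> "2 dogs"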

    def setAccuracy(self, accQA, accQuesType, accAnsType):
        self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n)
        self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
        self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}

    def setEvalQA(self, quesId, acc):
        self.evalQA[quesId] = round(100*acc, self.n)

    def setEvalQuesType(self, quesId, quesType, acc):
        if quesType not in self.evalQuesType:
            self.evalQuesType[quesType] = {}
        self.evalQuesType[quesType][quesId] = round(100*acc, self.n)

    def setEvalAnsType(self, quesId, ansType, acc):
        if ansType not in self.evalAnsType:
            self.evalAnsType[ansType] = {}
        self.evalAnsType[ansType][quesId] = round(100*acc, self.n)

    def updateProgress(self, progress):
        barLength = 20
        status = ""
        if isinstance(progress, int):
            progress = float(progress)
        if not isinstance(progress, float):
            progress = 0
            status = "error: progress var must be float\r\n"
        if progress < 0:
            progress = 0
            status = "Halt...\r\n"
        if progress >= 1:
            progress = 1
            status = "Done...\r\n"
        block = int(round(barLength*progress))
        text = "\rFinished Percent: [{0}] {1}% {2}".format("#"*block + "-"*(barLength-block), int(progress*100), status)
        sys.stdout.write(text)
        sys.stdout.flush()