neuro-symbolic-visual-dialog/executor/symbolic_executor.py

1679 lines
61 KiB
Python
Raw Normal View History

2022-08-10 16:49:55 +02:00
"""
author: Adnen Abdessaied
maintainer: "Adnen Abdessaied"
website: adnenabdessaied.de
version: 1.0.1
"""
import json
import numpy as np
from copy import deepcopy
from executor.clevr_statics import COLORS, MATERIALS, SHAPES, SIZES
from executor.clevr_statics import ANSWER_CANDIDATES as ANSWER_CANDIDATES_CLEVR
from executor.clevr_statics import ATTRIBUTES_ALL as ATTRIBUTES_ALL_CLEVR
from executor.minecraft_statics import DIRECTIONS, NATURES, CLASSES
from executor.minecraft_statics import ANSWER_CANDIDATES as ANSWER_CANDIDATES_MINECRAFT
from executor.minecraft_statics import ATTRIBUTES_ALL as ATTRIBUTES_ALL_MINECRAFT
from utils import load_clevr_scenes, load_minecraft_scenes
class SymbolicExecutorClevr(object):
    """Symbolic executor for clevr-dialog.

    Executes caption and question programs (see ``registerFunctions``) on one
    scene while keeping the dialog state:

    * ``objs``: single objects referred to so far, most recent one last.
    * ``currentObj``: the object the next ``*-imm`` program refers to.
    * ``groups`` / ``currentGrp``: object groups produced by earlier rounds.
    * ``visited``: every object mentioned so far (used by ``*-other``).

    Positions are read as ``obj["position"][0]`` (left < right) and
    ``obj["position"][1]`` (behind < front), as used by the filters below.
    """

    def __init__(self, scenesPath):
        """Register the functions and load the scenes.

        Args:
            scenesPath: Passed unchanged to ``load_clevr_scenes``.
        """
        super(SymbolicExecutorClevr, self).__init__()
        self.functions = {}
        self.registerFunctions()
        self.uniqueObjFlag = False
        self.colors = COLORS
        self.materials = MATERIALS
        self.shapes = SHAPES
        self.sizes = SIZES
        self.answer_candidates = ANSWER_CANDIDATES_CLEVR
        self.attribute_all = ATTRIBUTES_ALL_CLEVR
        self.scenes = load_clevr_scenes(scenesPath)

    def reset(self, sceneIdx):
        """Resets the scene

        Args:
            sceneIdx: The index of the new scene
        """
        self.scene = self.scenes[sceneIdx]
        for _obj in self.scene:
            _obj["identifier"] = None
        # store previous objects in a list to better answer
        # xxx-imm, xxx-imm2, xxx-group and xxx-early questions.
        self.objs = []
        self.groups = []
        self.visited = []
        self.currentObj = None
        self.currentGrp = []
        self.uniqueObjFlag = False

    def registerFunctions(self):
        """Registers the available functions of the executor."""
        self.functions.update({
            # Captions - extreme location
            "extreme-right": self.extremeRight,
            "extreme-left": self.extremeLeft,
            "extreme-behind": self.extremeBehind,
            "extreme-front": self.extremeFront,
            "extreme-center": self.extremeCenter,
            # Captions - multiple objects
            "count-att": self.countAttributeCaption,
            # Captions - object relations
            "obj-relation": self.objRelation,
            # Captions - unique object
            "unique-obj": self.uniqueObject,
            # Questions - Count
            "count-all": self.countAll,
            "count-other": self.countOther,
            "count-all-group": self.countAllGroup,
            "count-attribute": self.countAttribute,
            "count-attribute-group": self.countAttributeGroup,
            "count-obj-rel-imm": self.countObjRelImm,
            "count-obj-rel-imm2": self.countObjRelImm2,
            "count-obj-rel-early": self.countObjRelEarly,
            "count-obj-exclude-imm": self.countObjExcludeImm,
            "count-obj-exclude-early": self.countObjExcludeEarly,
            # Questions - Exist
            "exist-other": self.existOther,
            "exist-attribute": self.existAttribute,
            "exist-attribute-group": self.existAttributeGroup,
            "exist-obj-rel-imm": self.existObjRelImm,
            "exist-obj-rel-imm2": self.existObjRelImm,
            "exist-obj-rel-early": self.existObjRelEarly,
            "exist-obj-exclude-imm": self.existObjExcludeImm,
            "exist-obj-exclude-early": self.existObjExcludeEarly,
            # Questions - Seek
            "seek-attr-imm": self.seekAttrImm,
            "seek-attr-imm2": self.seekAttrImm,
            "seek-attr-early": self.seekAttributeEarly,
            "seek-attr-rel-imm": self.seekAttributeRelImm,
            "seek-attr-rel-early": self.seekAttributeRelEarly,
        })

    def getAttributeType(self, attribute):
        """Return "color", "material", "shape" or "size" for ``attribute``.

        Raises:
            AssertionError: If ``attribute`` is not in ``self.attribute_all``.
        """
        assert attribute in self.attribute_all, "The attribute {} is unkown".format(
            attribute)
        typed = (("color", self.colors), ("material", self.materials),
                 ("shape", self.shapes), ("size", self.sizes))
        for attributeType, values in typed:
            if attribute in values:
                return attributeType
        return None

    def execute(self, functionLabel, functionArgs):
        """Run the registered function ``functionLabel`` with ``functionArgs``.

        Returns:
            Whatever the function returns (captions return ``None``).
        """
        assert functionLabel in self.functions, "{} is not a valid function".format(
            functionLabel)
        return self.functions[functionLabel](*functionArgs)

    # ------------------------------------------------------------------ #
    # State bookkeeping
    # ------------------------------------------------------------------ #
    def updateCurrentObj(self, obj):
        """Make ``obj`` the current object and move it to the end of ``objs``."""
        self.currentObj = obj
        # Drop every stale entry with the same id in one pass; deleting by
        # index while scanning a snapshot removes the wrong items as soon as
        # more than one entry matches.
        self.objs[:] = [o for o in self.objs if o["id"] != obj["id"]]
        # Current obj is always kept at the end of the visited objs
        self.objs.append(obj)

    def updateVisited(self, obj):
        """Add ``obj`` to ``visited`` unless an object with its id is there."""
        if not any(seen["id"] == obj["id"] for seen in self.visited):
            self.visited.append(obj)

    def getOther(self):
        """Return the scene objects whose id is not in ``visited`` yet."""
        if len(self.visited) >= len(self.scene):
            return []
        return [
            _obj for _obj in self.scene
            if not any(seen["id"] == _obj["id"] for seen in self.visited)
        ]

    def updateIdentifier(self, obj, attribute):
        """Append ``attribute`` to the "-"-joined ``obj["identifier"]``."""
        if obj["identifier"] is None:
            obj["identifier"] = attribute
            return
        identifiers = obj["identifier"].split("-")
        if attribute not in identifiers:
            identifiers.append(attribute)
            obj["identifier"] = "-".join(identifiers)

    def _lookupTracked(self, obj):
        """Return ``(trackedObj, isNew)`` for ``obj`` with respect to ``objs``.

        The already tracked instance is preferred because it carries the
        identifiers collected in earlier rounds.
        """
        for tracked in self.objs:
            if tracked["id"] == obj["id"]:
                return tracked, False
        return obj, True

    def _trackSingle(self, obj, updateCurrentObj):
        """Record a uniquely determined ``obj``.

        It becomes the current object when ``updateCurrentObj`` is true;
        otherwise it is only appended to ``objs`` if it was not tracked yet.
        """
        obj, new = self._lookupTracked(obj)
        if updateCurrentObj:
            self.updateCurrentObj(obj)
        elif new:
            self.objs.append(obj)
        return obj

    def _findEarlyObj(self, earlyObjAttribute):
        """Return the most recent tracked object identified by the attribute.

        If no identifier matches, the oldest tracked object is returned (the
        behaviour of the original search loop).

        Raises:
            LookupError: If no object has been tracked yet.
        """
        if not self.objs:
            raise LookupError(
                "No previous object available to resolve <{}>".format(
                    earlyObjAttribute))
        objEarly = None
        for objEarly in reversed(self.objs):
            identifier = objEarly["identifier"]
            if identifier is not None and earlyObjAttribute in identifier.split("-"):
                break
        return objEarly

    # ------------------------------------------------------------------ #
    # Captions
    # ------------------------------------------------------------------ #
    def _registerExtreme(self, axis, index, attributes):
        """Select the object at ``index`` after sorting a scene copy by
        ``position[axis]`` and make it the current object.

        Raises:
            AssertionError: If that object lacks one of ``attributes``.
        """
        attributes = list(attributes)
        attributeTypes = [self.getAttributeType(att) for att in attributes]
        ordered = deepcopy(self.scene)
        ordered.sort(key=lambda o: o["position"][axis])
        extremeObj = ordered[index]
        for attributeType, attribute in zip(attributeTypes, attributes):
            assert extremeObj[attributeType] == attribute
            self.updateIdentifier(extremeObj, attribute)
        self.updateCurrentObj(extremeObj)
        self.updateVisited(extremeObj)

    def extremeRight(self, *attributes):
        """Caption: the right-most object has ``attributes``."""
        self._registerExtreme(0, -1, attributes)

    def extremeLeft(self, *attributes):
        """Caption: the left-most object has ``attributes``."""
        self._registerExtreme(0, 0, attributes)

    def extremeFront(self, *attributes):
        """Caption: the front-most object has ``attributes``."""
        self._registerExtreme(1, -1, attributes)

    def extremeBehind(self, *attributes):
        """Caption: the rear-most object has ``attributes``."""
        self._registerExtreme(1, 0, attributes)

    def extremeCenter(self, *attributes):
        """Caption: the object in the center has ``attributes``.

        An object is central when at most half of the objects lie on each
        side of it on both axes. If no central object carries the attributes,
        the last candidate examined is used (best effort, no assertion).
        """
        attributes = list(attributes)
        attributeTypes = [self.getAttributeType(att) for att in attributes]
        numObjs = len(self.scene)
        half = numObjs / 2
        frontToBack = deepcopy(self.scene)
        frontToBack.sort(key=lambda o: o["position"][1], reverse=True)
        rightToLeft = deepcopy(self.scene)
        rightToLeft.sort(key=lambda o: o["position"][0], reverse=True)
        prelimenaryCandidates = [
            candidate for rank, candidate in enumerate(frontToBack)
            if rank <= half and numObjs - rank - 1 <= half
        ]
        foundCenter = False
        for _obj in prelimenaryCandidates:
            for rank, objRightToLeft in enumerate(rightToLeft):
                if _obj["id"] != objRightToLeft["id"]:
                    continue
                matches = all(
                    _obj[attributeType] == attribute
                    for attributeType, attribute in zip(attributeTypes, attributes))
                central = rank <= half and numObjs - rank - 1 <= half
                foundCenter = central and matches
                break
            if foundCenter:
                break
        for attribute in attributes:
            self.updateIdentifier(_obj, attribute)
        self.updateCurrentObj(_obj)
        self.updateVisited(_obj)

    def countAttributeCaption(self, attribute):
        """Caption: several objects share ``attribute``.

        Copies of the matching objects become the current group and are
        marked as visited.
        """
        attributeType = self.getAttributeType(attribute)
        objs = [deepcopy(_obj) for _obj in self.scene
                if _obj[attributeType] == attribute]
        for _obj in objs:
            self.updateIdentifier(_obj, attribute)
        # update the current group
        self.currentGrp = objs
        # update the visited objects list
        for _obj in objs:
            self.updateVisited(_obj)

    def getAnchorAttribute(self, attribute_1, attribute_2, scene):
        """Return ``attribute_1`` if exactly one object in ``scene`` has it,
        otherwise ``attribute_2``."""
        # The anchor object is unique. If we filter the object list
        # based on the attribute anchor, we must find only one object.
        if len(self.filterAttribute(scene, attribute_1)) == 1:
            return attribute_1
        return attribute_2

    def objRelation(self, attribute, attributeAnchor, relation):
        """Caption: an ``attribute`` object is ``relation`` of the anchor.

        The anchor is tracked first and the related object last, so the
        related object ends up as the current object.
        """
        assert relation in ["left", "right", "front", "behind"]
        # find the anchor object; if the roles are the other way round, swap
        # them and mirror the relation
        if attributeAnchor != self.getAnchorAttribute(attribute, attributeAnchor, self.scene):
            attribute, attributeAnchor = attributeAnchor, attribute
            relation = {"left": "right", "right": "left",
                        "behind": "front", "front": "behind"}[relation]
        # Order the objects in the scene w.r.t. the relation
        axis = 0 if relation in ["left", "right"] else 1
        sceneCopy = deepcopy(self.scene)
        sceneCopy.sort(key=lambda o: o["position"][axis])
        # get the anchor object (falls through to the last one if none matches)
        attributeTypeAnchor = self.getAttributeType(attributeAnchor)
        for i, _obj in enumerate(sceneCopy):
            if _obj[attributeTypeAnchor] == attributeAnchor:
                break
        # save the anchor object before the main object
        anchorObj = _obj
        self.updateIdentifier(anchorObj, attributeAnchor)
        self.updateCurrentObj(anchorObj)
        self.updateVisited(anchorObj)
        # keep only the side of interest, closest object first
        if relation in ["left", "behind"]:
            sceneCopy = list(reversed(sceneCopy[:i]))
        else:
            sceneCopy = sceneCopy[i + 1:]
        attributeType = self.getAttributeType(attribute)
        # get the main object; `_obj` stays on the anchor if that side is empty
        for _obj in sceneCopy:
            if _obj[attributeType] == attribute:
                break
        self.updateIdentifier(_obj, attribute)
        self.updateCurrentObj(_obj)
        self.updateVisited(_obj)

    def uniqueObject(self, *attributes):
        """Caption: there is an object with all of ``attributes``.

        The first matching scene object (or the last scene object if none
        matches) becomes the current object.
        """
        attributes = list(attributes)
        attributeTypes = [self.getAttributeType(att) for att in attributes]
        for _obj in self.scene:
            if all(_obj[attributeType] == attribute
                   for attributeType, attribute in zip(attributeTypes, attributes)):
                break
        for att in attributes:
            self.updateIdentifier(_obj, att)
        self.updateCurrentObj(_obj)
        self.updateVisited(_obj)

    # ------------------------------------------------------------------ #
    # Filters
    # ------------------------------------------------------------------ #
    def filterOutObj(self, scene, obj):
        """Return a deep copy of ``scene`` without the first object whose id
        equals ``obj["id"]``; a full copy if there is no such object."""
        sceneCopy = deepcopy(scene)
        for i, _obj in enumerate(scene):
            if obj["id"] == _obj["id"]:
                del sceneCopy[i]
                break
        return sceneCopy

    def filterAttribute(self, scene, attribute):
        """Return the objects of ``scene`` that have ``attribute``."""
        attributeType = self.getAttributeType(attribute)
        return [_obj for _obj in scene if _obj[attributeType] == attribute]

    def excludeAttribute(self, scene, obj, attributeType):
        """Return the objects other than ``obj`` that share its value of
        ``attributeType`` and mark them as visited."""
        filtered = [
            _obj for _obj in scene
            if _obj["id"] != obj["id"] and obj[attributeType] == _obj[attributeType]
        ]
        # Update the visited objects list
        for _obj in filtered:
            self.updateVisited(_obj)
        return filtered

    def filterLeft(self, scene, obj):
        """Objects of ``scene`` with a smaller x-coordinate than ``obj``."""
        return [_obj for _obj in scene
                if _obj["position"][0] < obj["position"][0] and _obj["id"] != obj["id"]]

    def filterRight(self, scene, obj):
        """Objects of ``scene`` with a bigger x-coordinate than ``obj``."""
        return [_obj for _obj in scene
                if _obj["position"][0] > obj["position"][0] and _obj["id"] != obj["id"]]

    def filterFront(self, scene, obj):
        """Objects of ``scene`` with a bigger y-coordinate than ``obj``
        (a bigger y means closer to the front)."""
        return [_obj for _obj in scene
                if _obj["position"][1] > obj["position"][1] and _obj["id"] != obj["id"]]

    def filterBehind(self, scene, obj):
        """Objects of ``scene`` with a smaller y-coordinate than ``obj``
        (a smaller y means further behind)."""
        return [_obj for _obj in scene
                if _obj["position"][1] < obj["position"][1] and _obj["id"] != obj["id"]]

    def filterPosition(self, scene, obj, pos):
        """Dispatch to the directional filter named by ``pos``.

        Raises:
            AssertionError: If ``pos`` is not left/right/front/behind.
        """
        assert pos in ["left", "right", "front", "behind"]
        directional = {
            "left": self.filterLeft,
            "right": self.filterRight,
            "front": self.filterFront,
            "behind": self.filterBehind,
        }
        return directional[pos](scene, obj)

    @staticmethod
    def _closest(candidates, pos):
        """Return the candidate nearest to the reference object.

        ``candidates`` must come from ``filterPosition`` for the same ``pos``;
        the list is sorted in place (stable, so ties resolve as before).
        """
        axis = 0 if pos in ("left", "right") else 1
        candidates.sort(key=lambda x: x["position"][axis])
        # left/behind candidates have smaller coordinates: nearest is the last
        return candidates[-1] if pos in ("left", "behind") else candidates[0]

    ###########################################################################
    #                          Counting questions                             #
    ###########################################################################
    def countAll(self):
        """Count all objects; a copy of the scene becomes the current group."""
        self.currentGrp = deepcopy(self.scene)
        self.groups.append(deepcopy(self.scene))
        return len(self.scene)

    def countOther(self):
        """Count the not yet visited objects.

        They become the current group; a single remaining object also becomes
        the current object.
        """
        others = self.getOther()
        if len(others) > 0:
            self.currentGrp = others
            self.groups.append(others)
        if len(others) == 1:
            obj, _ = self._lookupTracked(others[0])
            self.updateCurrentObj(obj)
            self.updateVisited(obj)
        return len(others)

    def countAllGroup(self):
        """Return the size of the current group."""
        return len(self.currentGrp)

    def _countAttributeIn(self, scene, attribute, updateCurrentObj):
        """Shared body of ``countAttribute`` and ``countAttributeGroup``."""
        filtered = self.filterAttribute(scene, attribute)
        if len(filtered) == 0:
            return 0
        # Update the visited objects list
        for _obj in filtered:
            self.updateVisited(_obj)
        if len(filtered) == 1:
            obj, new = self._lookupTracked(filtered[0])
            self.updateIdentifier(obj, attribute)
            self.updateVisited(obj)
            if updateCurrentObj:
                self.updateCurrentObj(obj)
            elif new:
                self.objs.append(obj)
        self.groups.append(filtered)
        self.currentGrp = filtered
        return len(filtered)

    def countAttribute(self, attribute, updateCurrentObj=True):
        """Count the scene objects that have ``attribute``."""
        return self._countAttributeIn(self.scene, attribute, updateCurrentObj)

    def countAttributeGroup(self, attribute, updateCurrentObj=True):
        """Count the objects of the current group that have ``attribute``."""
        return self._countAttributeIn(self.currentGrp, attribute, updateCurrentObj)

    def countObjRelImm(self, pos, updateCurrentObj=True):
        """Count the objects located ``pos`` of the current object."""
        filtered = self.filterPosition(self.scene, self.currentObj, pos)
        if len(filtered) == 0:
            return 0
        # Update the visited objects list
        for _obj in filtered:
            self.updateVisited(_obj)
        self.currentGrp = filtered
        self.groups.append(filtered)
        if len(filtered) == 1:
            self._trackSingle(filtered[0], updateCurrentObj)
            if updateCurrentObj:
                # remember that the focus moved, see countObjRelImm2
                self.uniqueObjFlag = True
        return len(filtered)

    def countObjRelImm2(self, pos):
        """Like ``countObjRelImm`` but relative to the object that was current
        before the previous relation question moved the focus."""
        if self.uniqueObjFlag:
            self.updateCurrentObj(self.objs[-2])
            self.uniqueObjFlag = False
        return self.countObjRelImm(pos)

    def countObjRelEarly(self, pos, earlyObjAttribute, updateCurrentObj=True):
        """Count the objects located ``pos`` of an earlier mentioned object."""
        objEarly = self._findEarlyObj(earlyObjAttribute)
        filtered = self.filterPosition(self.scene, objEarly, pos)
        if len(filtered) == 0:
            return 0
        # Update the visited objects list
        for _obj in filtered:
            self.updateVisited(_obj)
        if len(filtered) == 1:
            self._trackSingle(filtered[0], updateCurrentObj)
        else:
            # a group was found: the early object gets the focus back
            self.updateCurrentObj(objEarly)
        self.currentGrp = filtered
        self.groups.append(filtered)
        return len(filtered)

    def countObjExcludeImm(self, attributeType, updateCurrentObj=True):
        """Count the other objects sharing the current object's value of
        ``attributeType``."""
        filtered = self.excludeAttribute(
            self.scene, self.currentObj, attributeType)
        if len(filtered) == 0:
            return 0
        if len(filtered) == 1:
            self._trackSingle(filtered[0], updateCurrentObj)
        self.currentGrp = filtered
        self.groups.append(filtered)
        return len(filtered)

    def countObjExcludeEarly(self, attributeType, earlyObjAttribute, updateCurrentObj=True):
        """Count the other objects sharing an earlier object's value of
        ``attributeType``."""
        objEarly = self._findEarlyObj(earlyObjAttribute)
        filtered = self.excludeAttribute(self.scene, objEarly, attributeType)
        if len(filtered) == 0:
            return 0
        if len(filtered) == 1:
            self._trackSingle(filtered[0], updateCurrentObj)
        else:
            # a group was found: the early object gets the focus back
            self.updateCurrentObj(objEarly)
        self.currentGrp = filtered
        self.groups.append(filtered)
        return len(filtered)

    ###########################################################################
    #                          Existence questions                            #
    ###########################################################################
    def existOther(self):
        """Answer "yes" if unvisited objects remain; they become the current
        group and are marked as visited."""
        others = self.getOther()
        if len(others) > 0:
            self.currentGrp = others
            self.groups.append(others)
            for _obj in others:
                self.updateVisited(_obj)
        return "yes" if len(others) > 0 else "no"

    def existAttribute(self, attribute):
        """Answer whether a scene object has ``attribute``.

        A single match is tracked and identified by ``attribute`` but does
        not become the current object.
        """
        filtered = self.filterAttribute(self.scene, attribute)
        if len(filtered) == 0:
            return "no"
        # Update the visited objects list
        for _obj in filtered:
            self.updateVisited(_obj)
        if len(filtered) == 1:
            obj, new = self._lookupTracked(filtered[0])
            self.updateIdentifier(obj, attribute)
            if new:
                self.objs.append(obj)
        self.currentGrp = filtered
        self.groups.append(filtered)
        return "yes"

    def existAttributeGroup(self, attribute):
        """Answer whether the current group has an object with ``attribute``."""
        numAttributeGrp = self.countAttributeGroup(
            attribute, updateCurrentObj=False)
        return "yes" if numAttributeGrp > 0 else "no"

    def existObjRelImm(self, pos):
        """Answer whether anything is located ``pos`` of the current object."""
        numObjs = self.countObjRelImm(pos, updateCurrentObj=False)
        return "yes" if numObjs > 0 else "no"

    def existObjRelEarly(self, pos, earlyObjAttribute):
        """Answer whether anything is located ``pos`` of an earlier object."""
        numObjs = self.countObjRelEarly(
            pos, earlyObjAttribute, updateCurrentObj=False)
        return "yes" if numObjs > 0 else "no"

    def existObjExcludeImm(self, attributeType):
        """Answer whether another object shares the current object's
        ``attributeType`` value."""
        numObjs = self.countObjExcludeImm(
            attributeType, updateCurrentObj=False)
        return "yes" if numObjs > 0 else "no"

    def existObjExcludeEarly(self, attributeType, earlyObjAttribute):
        """Answer whether another object shares an earlier object's
        ``attributeType`` value (the tracked objects are left untouched)."""
        objEarly = self._findEarlyObj(earlyObjAttribute)
        filtered = self.excludeAttribute(self.scene, objEarly, attributeType)
        if len(filtered) == 0:
            return "no"
        self.currentGrp = filtered
        self.groups.append(filtered)
        return "yes"

    ###########################################################################
    #                             Seek questions                              #
    ###########################################################################
    def seekAttrImm(self, attributeType):
        """Return the current object's value of ``attributeType`` and add it
        to the object's identifiers."""
        assert attributeType in self.currentObj, "Attributre <{}> is not valid".format(
            attributeType)
        self.updateIdentifier(self.currentObj, self.currentObj[attributeType])
        return self.currentObj[attributeType]

    def seekAttributeEarly(self, attributeType, earlyObjAttribute):
        """Return an earlier object's value of ``attributeType``; that object
        becomes the current object."""
        objEarly = self._findEarlyObj(earlyObjAttribute)
        self.updateIdentifier(objEarly, objEarly[attributeType])
        self.updateCurrentObj(objEarly)
        self.updateVisited(objEarly)
        return objEarly[attributeType]

    def _seekRelative(self, attributeType, anchor, pos):
        """Return ``attributeType`` of the object closest to ``anchor`` in
        direction ``pos`` ("none" if there is no such object); that object
        becomes the current object."""
        filtered = self.filterPosition(self.scene, anchor, pos)
        if len(filtered) == 0:
            return "none"
        obj = self._closest(filtered, pos)
        # carry over the identifiers collected on a tracked copy of the object
        for _obj in self.objs:
            if _obj["id"] == obj["id"]:
                obj["identifier"] = _obj["identifier"]
                break
        self.updateIdentifier(obj, obj[attributeType])
        self.updateCurrentObj(obj)
        self.updateVisited(obj)
        return obj[attributeType]

    def seekAttributeRelImm(self, attributeType, pos):
        """Seek ``attributeType`` of the object next to the current object."""
        return self._seekRelative(attributeType, self.currentObj, pos)

    def seekAttributeRelEarly(self, attributeType, pos, earlyObjAttribute):
        """Seek ``attributeType`` of the object next to an earlier object."""
        objEarly = self._findEarlyObj(earlyObjAttribute)
        return self._seekRelative(attributeType, objEarly, pos)
class SymbolicExecutorMinecraft(object):
"""Symbolic executor for minecraft-dialog
"""
def __init__(self, scenesPath):
    """Register the functions and load the minecraft scenes.

    Args:
        scenesPath: Passed unchanged to ``load_minecraft_scenes``.
    """
    super(SymbolicExecutorMinecraft, self).__init__()
    # label -> bound method table, filled by registerFunctions()
    self.functions = {}
    self.registerFunctions()
    self.uniqueObjFlag = False
    # attribute vocabularies used by getAttributeType()
    self.classes = CLASSES
    self.natures = NATURES
    self.directions = DIRECTIONS
    self.answer_candidates = ANSWER_CANDIDATES_MINECRAFT
    self.attribute_all = ATTRIBUTES_ALL_MINECRAFT
    self.scenes = load_minecraft_scenes(scenesPath)
def reset(self, sceneIdx):
    """Resets the scene

    Selects ``self.scenes[sceneIdx]``, clears the identifier of each of its
    objects and empties all dialog state.

    Args:
        sceneIdx: The index of the new scene
    """
    self.scene = self.scenes[sceneIdx]
    for _obj in self.scene:
        _obj["identifier"] = None
    # store previous objects in a list to better answer
    # xxx-imm, xxx-imm2, xxx-group and xxx-early questions.
    self.objs = []
    self.groups = []
    self.visited = []
    self.currentObj = None
    self.currentGrp = []
    self.uniqueObjFlag = False
def registerFunctions(self):
    """Registers the available functions of the executor.

    Maps every program label to the bound method that executes it. The
    ``*-imm2`` variants of exist/seek reuse the ``*-imm`` implementation.
    """
    self.functions.update({
        # Captions - extreme location
        "extreme-right": self.extremeRight,
        "extreme-left": self.extremeLeft,
        "extreme-behind": self.extremeBehind,
        "extreme-front": self.extremeFront,
        "extreme-center": self.extremeCenter,
        # Captions - multiple objects
        "count-att": self.countAttributeCaption,
        # Captions - object relations
        "obj-relation": self.objRelation,
        # Captions - unique object
        "unique-obj": self.uniqueObject,
        # Questions - Count
        "count-all": self.countAll,
        "count-other": self.countOther,
        "count-all-group": self.countAllGroup,
        "count-attribute": self.countAttribute,
        "count-attribute-group": self.countAttributeGroup,
        "count-obj-rel-imm": self.countObjRelImm,
        "count-obj-rel-imm2": self.countObjRelImm2,
        "count-obj-rel-early": self.countObjRelEarly,
        "count-obj-exclude-imm": self.countObjExcludeImm,
        "count-obj-exclude-early": self.countObjExcludeEarly,
        # Questions - Exist
        "exist-other": self.existOther,
        "exist-attribute": self.existAttribute,
        "exist-attribute-group": self.existAttributeGroup,
        "exist-obj-rel-imm": self.existObjRelImm,
        "exist-obj-rel-imm2": self.existObjRelImm,
        "exist-obj-rel-early": self.existObjRelEarly,
        "exist-obj-exclude-imm": self.existObjExcludeImm,
        "exist-obj-exclude-early": self.existObjExcludeEarly,
        # Questions - Seek
        "seek-attr-imm": self.seekAttrImm,
        "seek-attr-imm2": self.seekAttrImm,
        "seek-attr-early": self.seekAttributeEarly,
        "seek-attr-rel-imm": self.seekAttributeRelImm,
        "seek-attr-rel-early": self.seekAttributeRelEarly,
    })
def getAttributeType(self, attribute):
    """Return "class", "direction" or "nature" for ``attribute``.

    The vocabularies are checked in that order. ``None`` is returned for an
    attribute that is in ``self.attribute_all`` but in none of them.

    Raises:
        AssertionError: If ``attribute`` is not in ``self.attribute_all``.
    """
    assert attribute in self.attribute_all, "The attribute {} is unkown".format(
        attribute)
    typed = (("class", self.classes),
             ("direction", self.directions),
             ("nature", self.natures))
    for attributeType, values in typed:
        if attribute in values:
            return attributeType
    return None
def execute(self, functionLabel, functionArgs):
    """Run the registered function ``functionLabel``.

    Args:
        functionLabel: Key of ``self.functions``.
        functionArgs: Iterable of positional arguments for the function.

    Returns:
        Whatever the selected function returns.

    Raises:
        AssertionError: If ``functionLabel`` is not registered.
    """
    assert functionLabel in self.functions, "{} is not a valid function".format(
        functionLabel)
    return self.functions[functionLabel](*functionArgs)
def updateCurrentObj(self, obj):
    """Make ``obj`` the current object and move it to the end of ``objs``.

    Every tracked entry with the same id is dropped before ``obj`` is
    appended, so each id occurs at most once afterwards.
    """
    self.currentObj = obj
    # Filter in one pass: deleting by an index taken from a snapshot removes
    # the wrong entries as soon as more than one entry matches.
    self.objs[:] = [tracked for tracked in self.objs if tracked["id"] != obj["id"]]
    # Current obj is always kept at the end of the visited objs
    self.objs.append(obj)
def updateVisited(self, obj):
    """Append ``obj`` to ``self.visited`` unless an object with the same id
    is already in the list."""
    alreadySeen = any(seen["id"] == obj["id"] for seen in self.visited)
    if not alreadySeen:
        self.visited.append(obj)
def getOther(self):
    """Return the scene objects whose id is not in ``self.visited``.

    The result keeps scene order and is empty as soon as ``visited`` holds
    at least as many objects as the scene.
    """
    if len(self.visited) >= len(self.scene):
        return []
    return [
        _obj for _obj in self.scene
        if not any(seen["id"] == _obj["id"] for seen in self.visited)
    ]
def updateIdentifier(self, obj, attribute):
    """Add ``attribute`` to the identifiers of ``obj``.

    ``obj["identifier"]`` is either ``None`` or the known attributes joined
    by "-"; an attribute that is already listed is not added again.
    """
    if obj["identifier"] is None:
        obj["identifier"] = attribute
        return
    identifiers = obj["identifier"].split("-")
    if attribute not in identifiers:
        identifiers.append(attribute)
        obj["identifier"] = "-".join(identifiers)
# Captions
def extremeRight(self, *attributes):
    """Caption: the right-most object has ``attributes``.

    Some objects in the minecraft dataset share the same coordinate values,
    so the right-most object is not unique. To reduce the error risk, the
    first object in right-to-left order that carries all attributes is
    chosen; if no object matches, the last object in that order is used.

    Raises:
        AssertionError: If the scene is empty or the chosen object does not
            share the largest x-coordinate.
    """
    attributes = list(attributes)
    attributeTypes = [self.getAttributeType(att) for att in attributes]
    rightToLeft = deepcopy(self.scene)
    rightToLeft.sort(key=lambda o: o["position"][0], reverse=True)
    assert rightToLeft, "Cannot select an extreme object in an empty scene"
    extremeRightObj = next(
        (_obj for _obj in rightToLeft
         if all(_obj[t] == a for t, a in zip(attributeTypes, attributes))),
        rightToLeft[-1])
    assert extremeRightObj["position"][0] == rightToLeft[0]["position"][0], \
        "The matching object is not located at the extreme right"
    for att in attributes:
        self.updateIdentifier(extremeRightObj, att)
    self.updateCurrentObj(extremeRightObj)
    self.updateVisited(extremeRightObj)
def extremeLeft(self, *attributes):
    """Caption: the left-most object has ``attributes``.

    Some objects in the minecraft dataset share the same coordinate values,
    so the left-most object is not unique. To reduce the error risk, the
    first object in left-to-right order that carries all attributes is
    chosen; if no object matches, the last object in that order is used.

    Raises:
        AssertionError: If the scene is empty or the chosen object does not
            share the smallest x-coordinate.
    """
    attributes = list(attributes)
    attributeTypes = [self.getAttributeType(att) for att in attributes]
    leftToRight = deepcopy(self.scene)
    leftToRight.sort(key=lambda o: o["position"][0])
    assert leftToRight, "Cannot select an extreme object in an empty scene"
    extremeLeftObj = next(
        (_obj for _obj in leftToRight
         if all(_obj[t] == a for t, a in zip(attributeTypes, attributes))),
        leftToRight[-1])
    assert extremeLeftObj["position"][0] == leftToRight[0]["position"][0], \
        "The matching object is not located at the extreme left"
    for att in attributes:
        self.updateIdentifier(extremeLeftObj, att)
    self.updateCurrentObj(extremeLeftObj)
    self.updateVisited(extremeLeftObj)
def extremeFront(self, *attributes):
    """Caption: the front-most object has ``attributes``.

    Objects are ordered by ascending ``position[1]`` and the front-most one
    is taken from the start of that order. Some objects in the minecraft
    dataset share the same coordinate values, so the first object in that
    order that carries all attributes is chosen; if no object matches, the
    last object in that order is used.

    Raises:
        AssertionError: If the scene is empty or the chosen object does not
            share the smallest y-coordinate.
    """
    attributes = list(attributes)
    attributeTypes = [self.getAttributeType(att) for att in attributes]
    frontToBack = deepcopy(self.scene)
    frontToBack.sort(key=lambda o: o["position"][1])
    assert frontToBack, "Cannot select an extreme object in an empty scene"
    extremeFrontObj = next(
        (_obj for _obj in frontToBack
         if all(_obj[t] == a for t, a in zip(attributeTypes, attributes))),
        frontToBack[-1])
    assert extremeFrontObj["position"][1] == frontToBack[0]["position"][1], \
        "The matching object is not located at the extreme front"
    for att in attributes:
        self.updateIdentifier(extremeFrontObj, att)
    self.updateCurrentObj(extremeFrontObj)
    self.updateVisited(extremeFrontObj)
def extremeBehind(self, *attributes):
    """Caption: the rear-most object has ``attributes``.

    Objects are ordered by descending ``position[1]`` and the rear-most one
    is taken from the start of that order. Some objects in the minecraft
    dataset share the same coordinate values, so the first object in that
    order that carries all attributes is chosen; if no object matches, the
    last object in that order is used.

    Raises:
        AssertionError: If the scene is empty or the chosen object does not
            share the largest y-coordinate.
    """
    attributes = list(attributes)
    attributeTypes = [self.getAttributeType(att) for att in attributes]
    backToFront = deepcopy(self.scene)
    backToFront.sort(key=lambda o: o["position"][1], reverse=True)
    assert backToFront, "Cannot select an extreme object in an empty scene"
    extremeRearObj = next(
        (_obj for _obj in backToFront
         if all(_obj[t] == a for t, a in zip(attributeTypes, attributes))),
        backToFront[-1])
    assert extremeRearObj["position"][1] == backToFront[0]["position"][1], \
        "The matching object is not located at the extreme rear"
    for att in attributes:
        self.updateIdentifier(extremeRearObj, att)
    self.updateCurrentObj(extremeRearObj)
    self.updateVisited(extremeRearObj)
def extremeCenter(self, *attributes):
attributes = list(attributes)
attributeTypes = list(
map(lambda att: self.getAttributeType(att), attributes))
numObjs = len(self.scene)
frontToBack = deepcopy(self.scene)
frontToBack.sort(key=lambda o: o["position"][1])
rightToLeft = deepcopy(self.scene)
rightToLeft.sort(key=lambda o: o["position"][0], reverse=True)
prelimenaryCandidates = []
for i, objFrontToBack in enumerate(frontToBack):
numObjsInFront = i
numObjsBehind = len(rightToLeft) - i - 1