1050 lines
38 KiB
Python
1050 lines
38 KiB
Python
|
|
||
|
"""Supporting script checks constraints for caption and question generation.
|
||
|
Author: Satwik Kottur
|
||
|
"""
|
||
|
|
||
|
import copy
|
||
|
import json
|
||
|
import random
|
||
|
import numpy as np
|
||
|
|
||
|
import global_vars as gvars
|
||
|
|
||
|
|
||
|
# Some quick methods.
|
||
|
def apply_immediate(hist):
    """Tells whether a history entry can serve as an immediate focus object.

    Args:
        hist: One dialog-history entry with 'objects', 'mergeable' and
            'template' keys.

    Returns:
        A truthy value when the entry holds exactly one object, is mergeable
        and its template label does not contain 'exist'; falsy otherwise.
    """
    single_object = len(hist['objects']) == 1
    return (single_object and hist['mergeable'] and
            'exist' not in hist['template'])
|
||
|
|
||
|
|
||
|
def apply_group(hist):
    """Tells whether a history entry is a group that can still be counted.

    Args:
        hist: One dialog-history entry with 'objects' and 'mergeable' keys.

    Returns:
        A truthy value when the entry holds two or more objects, is mergeable
        and has no 'count' recorded; falsy otherwise.
    """
    # The count check must look at the entry itself; the previous version
    # referenced an undefined name (prev_group) and raised NameError.
    return (len(hist['objects']) >= 2 and hist['mergeable'] and
            'count' not in hist)
|
||
|
|
||
|
|
||
|
def caption(scene, templates):
    """Constraints for caption generation.

    Collects four kinds of caption hypotheses from the scene graph and then
    samples captions from them with sample_from_hypotheses:
      1. 'extreme-loc': an object at an extreme location (or the center).
      2. 'unique-obj': an object with an attribute value unique in the scene.
      3. 'count-attr': an attribute value shared by more than one object.
      4. 'obj-relation': two objects adjacent in a relation ordering, one of
         which has a unique attribute value.

    Args:
        scene: CLEVR scene graph with 'objects' and 'relationships'.
        templates: List of caption templates; 'extreme-loc' templates carry
            'constraints' whose first entry's first arg is the extreme type.

    Returns:
        sample_captions: Samples from caption hypotheses.
    """
    caption_hypotheses = {}
    attributes = gvars.METAINFO['attributes']

    # Sweep through all templates to extract 'interesting' captions.
    n_objs = len(scene['objects'])
    rels = scene['relationships']

    # Caption Type 1: Extreme locations.
    ext_loc_templates = [ii for ii in templates if ii['type'] == 'extreme-loc']
    filter_objs = copy.deepcopy(scene['objects'])
    # Scene-wide counts keyed by (attribute, value).
    attr_counts = get_attribute_counts_for_objects(scene, filter_objs)
    hypotheses = []
    for template in ext_loc_templates:
        # absolute location based constraint
        constraint = template['constraints'][0]
        extreme_type = constraint['args'][0]

        if extreme_type == 'center':
            # Roughly in the middle: at most half of the objects lie in each
            # of the four directions.
            for ii, obj in enumerate(filter_objs):
                if all(len(rels[kk][ii]) <= n_objs / 2
                       for kk in ['front', 'behind', 'right', 'left']):
                    hypotheses.append((extreme_type, copy.deepcopy(obj)))
        else:
            # Nothing lies beyond the object in this direction.
            for ii, obj in enumerate(filter_objs):
                if len(rels[extreme_type][ii]) == 0:
                    hypotheses.append((extreme_type, copy.deepcopy(obj)))

    # Drop attributes that alone identify the object, so the location is what
    # disambiguates it. Hypotheses left without any describing attribute are
    # discarded: sample_from_hypotheses draws one of the remaining attributes
    # with random.choice, which raises IndexError on an empty list.
    for _, hypothesis in hypotheses:
        uniq_attrs = [attr for attr in attributes
                      if attr_counts[(attr, hypothesis[attr])] == 1]
        for attr in uniq_attrs:
            del hypothesis[attr]
    hypotheses = [hyp for hyp in hypotheses
                  if len(hyp[1]) > 1
                  and any(attr in hyp[1] for attr in attributes)]
    caption_hypotheses['extreme-loc'] = hypotheses

    # Caption Type 2: Unique object and attribute.
    # Each hypothesis is an (object, attribute) pair whose value occurs once.
    filter_objs = copy.deepcopy(scene['objects'])
    caption_hypotheses['unique-obj'] = [
        (obj, attr) for obj in filter_objs for attr in attributes
        if attr_counts[(attr, obj[attr])] == 1]

    # Caption Type 3: Unique attribute count based caption.
    # Each hypothesis is ((attribute, value), count) with count > 1.
    caption_hypotheses['count-attr'] = [(attr_val, count)
                                        for attr_val, count in attr_counts.items()
                                        if count > 1]

    # Caption Type 4: Relation between two objects.
    # Out of the two, one has a unique attribute.
    scene_objs = scene['objects']
    uniq_attr = [[attr for attr in attributes
                  if attr_counts[(attr, obj[attr])] == 1] for obj in scene_objs]
    non_uniq_attr = [[attr for attr in attributes
                      if attr_counts[(attr, obj[attr])] > 1] for obj in scene_objs]
    uniqueness = [len(ii) > 0 for ii in uniq_attr]

    # Hypothesis: a unique object and a non-unique one, adjacent in the
    # ordering of objects by how many others lie in direction rel.
    hypotheses = []
    for rel, order in rels.items():
        num_rel = sorted(((ii, len(order[ii])) for ii in range(n_objs)),
                         key=lambda x: x[1], reverse=True)
        # take only the ids
        num_rel = [ii[0] for ii in num_rel]

        for index, obj_id in enumerate(num_rel[:-1]):
            next_obj_id = num_rel[index + 1]
            if uniqueness[obj_id]:
                # Unique: the next one needs a non-unique attribute.
                if len(non_uniq_attr[next_obj_id]) > 0:
                    obj1 = (obj_id, random.choice(uniq_attr[obj_id]))
                    obj2 = (next_obj_id, random.choice(non_uniq_attr[next_obj_id]))
                    hypotheses.append((obj1, rel, obj2))
            else:
                # Not unique: the next one needs a unique attribute.
                if len(uniq_attr[next_obj_id]) > 0:
                    obj1 = (obj_id, random.choice(non_uniq_attr[obj_id]))
                    obj2 = (next_obj_id, random.choice(uniq_attr[next_obj_id]))
                    hypotheses.append((obj1, rel, obj2))
    caption_hypotheses['obj-relation'] = hypotheses

    return sample_from_hypotheses(caption_hypotheses, scene, templates)
|
||
|
|
||
|
|
||
|
def question(scene, dialog, template):
|
||
|
"""Constraints question generation.
|
||
|
Inputs:
|
||
|
scene:Partial scene graphs on CLEVR images with generated captions
|
||
|
template: List of question templates to use
|
||
|
Output:
|
||
|
list of object groups
|
||
|
"""
|
||
|
|
||
|
ques_round = len(dialog['graph']['history']) - 1
|
||
|
graph = dialog['graph']
|
||
|
|
||
|
# check for constraints and answer question
|
||
|
if 'group' in template['label']:
|
||
|
groups = []
|
||
|
# Pick a group hypothesis
|
||
|
for ii in graph['history']:
|
||
|
if 'count' in ii or len(ii['objects']) == 0:
|
||
|
groups.append(ii)
|
||
|
|
||
|
if template['label'] == 'count-all':
|
||
|
# Preliminary checks:
|
||
|
# (A) count-all cannot follow count-all, count-other
|
||
|
for prev_history in graph['history'][1:]:
|
||
|
if prev_history['template'] in ['count-all', 'count-other']:
|
||
|
return []
|
||
|
|
||
|
# create object group
|
||
|
obj_group = []
|
||
|
new_obj = {'required': [], 'optional': []}
|
||
|
for obj_id, ii in enumerate(scene['objects']):
|
||
|
obj_copy = copy.deepcopy(new_obj)
|
||
|
obj_copy['id'] = ii['id']
|
||
|
obj_group.append(obj_copy)
|
||
|
|
||
|
# create graph item
|
||
|
graph_item = {'round': ques_round + 1,
|
||
|
'objects': copy.deepcopy(obj_group),
|
||
|
'template': template['label'],
|
||
|
'mergeable': True, 'count': len(obj_group)}
|
||
|
# clean graph item
|
||
|
graph_item = clean_graph_item(graph_item)
|
||
|
# no constraints, count the number of objects in true scene
|
||
|
return [{'answer': len(obj_group), 'group_id': ques_round + 1,
|
||
|
'objects': [], 'graph': graph_item}]
|
||
|
|
||
|
elif (template['label'] == 'count-other' or
|
||
|
template['label'] == 'exist-other'):
|
||
|
# preliminary checks:
|
||
|
# (A) exist-other cannot follow exist-other, count-all, count-other
|
||
|
# (B) count-other cannot follow count-all, count-other
|
||
|
for prev_history in graph['history'][1:]:
|
||
|
if prev_history['template'] in ['count-all', 'count-other']:
|
||
|
return []
|
||
|
|
||
|
if (prev_history['template'] == 'exist-other' and
|
||
|
template['label'] == 'exist-other'):
|
||
|
return []
|
||
|
|
||
|
# get a list of all objects we know
|
||
|
known_ids = [jj['id'] for ii in graph['history'] for jj in ii['objects']]
|
||
|
known_ids = list(set(known_ids))
|
||
|
n_objs = len(scene['objects'])
|
||
|
difference = n_objs - len(known_ids)
|
||
|
diff_ids = [ii for ii in range(n_objs) if ii not in known_ids]
|
||
|
|
||
|
# create empty objects for these
|
||
|
obj_group = [{'id': ii} for ii in diff_ids]
|
||
|
|
||
|
# create graph item
|
||
|
graph_item = {'round': ques_round + 1, 'objects': obj_group,
|
||
|
'template': template['label'], 'mergeable': False}
|
||
|
|
||
|
if 'count' in template['label']:
|
||
|
graph_item['count'] = difference
|
||
|
graph_item['mergeable'] = True # merge if count is known
|
||
|
answer = difference
|
||
|
elif 'exist' in template['label']:
|
||
|
# If heads (> 0.5) -- difference > 0
|
||
|
if random.random() > 0.5:
|
||
|
if difference > 0:
|
||
|
answer = 'yes'
|
||
|
else:
|
||
|
return []
|
||
|
else:
|
||
|
if difference == 0:
|
||
|
answer = 'no'
|
||
|
else:
|
||
|
return []
|
||
|
|
||
|
# no constraints, count the number of objects in true scene
|
||
|
return [{'answer': answer, 'group_id': ques_round + 1,
|
||
|
'objects': [], 'graph': graph_item}]
|
||
|
|
||
|
elif template['label'] == 'count-all-group':
|
||
|
# we need a group in the previous round
|
||
|
prev_group = graph['history'][-1]
|
||
|
prev_label = prev_group['template']
|
||
|
if not (len(prev_group['objects']) > 1 and
|
||
|
'count' not in prev_group and
|
||
|
'obj-relation' not in prev_label):
|
||
|
return []
|
||
|
|
||
|
# check if count is not given before
|
||
|
attrs = [ii for ii in gvars.METAINFO['attributes'] if ii in prev_group]
|
||
|
count = 0
|
||
|
for obj in prev_group['objects']:
|
||
|
count += all([obj[ii] == prev_group['objects'][0][ii] for ii in attrs])
|
||
|
|
||
|
# create object group
|
||
|
obj_group = []
|
||
|
new_obj = {'required': [], 'optional': []}
|
||
|
for obj_id, ii in enumerate(scene['objects']):
|
||
|
obj_copy = copy.deepcopy(new_obj)
|
||
|
obj_copy['id'] = ii['id']
|
||
|
obj_group.append(obj_copy)
|
||
|
|
||
|
# create graph item
|
||
|
graph_item = {'round': ques_round + 1, 'objects': copy.deepcopy(obj_group),
|
||
|
'template': template['label'],
|
||
|
'mergeable': True, 'count': count}
|
||
|
# clean graph item
|
||
|
graph_item = clean_graph_item(graph_item)
|
||
|
# no constraints, count the number of objects in true scene
|
||
|
return [{'answer': count, 'group_id': ques_round + 1,
|
||
|
'objects': [], 'graph': graph_item}]
|
||
|
|
||
|
elif ('count-obj-exclude' in template['label'] or
|
||
|
'exist-obj-exclude' in template['label']):
|
||
|
# placeholder for object description, see below
|
||
|
obj_desc = None
|
||
|
prev_history = graph['history'][-1]
|
||
|
scene_counts = get_attribute_counts_for_objects(scene)
|
||
|
|
||
|
if 'imm' in template['label']:
|
||
|
# we need an immediate group in the previous round
|
||
|
if apply_immediate(prev_history):
|
||
|
focus_id = prev_history['objects'][0]['id']
|
||
|
else:
|
||
|
return []
|
||
|
|
||
|
elif 'early' in template['label']:
|
||
|
# search through history for an object with unique attribute
|
||
|
attr_counts = get_known_attribute_counts(graph)
|
||
|
# get attributes with just one count
|
||
|
single_count = [ii for ii, count in attr_counts.items() if count == 1]
|
||
|
# remove attributes that point to objects in the previous round
|
||
|
# TODO: re-think this again
|
||
|
obj_ids = get_unique_attribute_objects(graph, single_count)
|
||
|
prev_history_obj_ids = [ii['id'] for ii in prev_history['objects']]
|
||
|
single_count = [ii for ii in single_count if
|
||
|
obj_ids[ii] not in prev_history_obj_ids]
|
||
|
|
||
|
if len(single_count) == 0:
|
||
|
return []
|
||
|
|
||
|
# give preference to attributes with multiple counts in scene graph
|
||
|
#scene_counts = get_attribute_counts_for_objects(scene)
|
||
|
ambiguous_attrs = [ii for ii in single_count if scene_counts[ii] > 1]
|
||
|
if len(ambiguous_attrs) > 0:
|
||
|
focus_attr = random.choice(ambiguous_attrs)
|
||
|
else:
|
||
|
focus_attr = random.choice(single_count)
|
||
|
focus_id = obj_ids[focus_attr]
|
||
|
|
||
|
# unique object description
|
||
|
obj_desc = {'required': [focus_attr[0]], 'optional': [],
|
||
|
focus_attr[0]: focus_attr[1]}
|
||
|
|
||
|
# get the known attributes for the current object
|
||
|
focus_obj = graph['objects'][focus_id]
|
||
|
known_attrs = [attr for attr in gvars.METAINFO['attributes']
|
||
|
if attr in focus_obj and
|
||
|
'%s_exclude_count' % attr not in focus_obj]
|
||
|
|
||
|
# for count: only if existence if True, else count it trivially zero
|
||
|
if 'count' in template['label']:
|
||
|
for attr in known_attrs[::-1]:
|
||
|
if not focus_obj.get('%s_exclude_exist' % attr, True):
|
||
|
known_attrs.remove(attr)
|
||
|
# for exist: get relations without exist before
|
||
|
elif 'exist' in template['label']:
|
||
|
known_attrs = [attr for attr in known_attrs
|
||
|
if '%s_exclude_exist' % attr not in focus_obj]
|
||
|
|
||
|
# select an attribute
|
||
|
if len(known_attrs) == 0:
|
||
|
return[]
|
||
|
|
||
|
# split this into zero and non-zero
|
||
|
if 'exist' in template['label']:
|
||
|
focus_attrs = [(ii, scene['objects'][focus_id][ii])
|
||
|
for ii in known_attrs]
|
||
|
zero_count = [ii for ii in focus_attrs if scene_counts[ii] == 1]
|
||
|
nonzero_count = [ii for ii in focus_attrs if scene_counts[ii] > 1]
|
||
|
|
||
|
if random.random() > 0.5:
|
||
|
if len(zero_count) > 0:
|
||
|
attr = random.choice(zero_count)[0]
|
||
|
else:
|
||
|
return []
|
||
|
else:
|
||
|
if len(nonzero_count) > 0:
|
||
|
attr = random.choice(nonzero_count)[0]
|
||
|
else:
|
||
|
return []
|
||
|
else:
|
||
|
attr = random.choice(known_attrs)
|
||
|
|
||
|
# create the object group
|
||
|
obj_group = []
|
||
|
new_obj = {'required': ['attribute'], 'optional': []}
|
||
|
for obj in scene['objects']:
|
||
|
# add if same attribute value and not focus object
|
||
|
if obj[attr] == focus_obj[attr] and obj['id'] != focus_id:
|
||
|
obj_copy = copy.deepcopy(new_obj)
|
||
|
obj_copy['id'] = obj['id']
|
||
|
obj_copy[attr] = focus_obj[attr]
|
||
|
obj_group.append(obj_copy)
|
||
|
answer = len(obj_group)
|
||
|
|
||
|
ref_obj = copy.deepcopy(new_obj)
|
||
|
ref_obj['id'] = focus_id
|
||
|
ref_obj['volatile'] = True
|
||
|
if 'exist' in template['label']:
|
||
|
answer = 'yes' if answer > 0 else 'no'
|
||
|
ref_obj['%s_exclude_exist' % attr] = answer
|
||
|
elif 'count' in template['label']:
|
||
|
ref_obj['%s_exclude_count' % attr] = answer
|
||
|
obj_group.append(ref_obj)
|
||
|
|
||
|
graph_item = {'round': ques_round+1, 'objects': copy.deepcopy(obj_group),
|
||
|
'template': template['label'], 'mergeable': True,
|
||
|
'focus_id': focus_id, 'focus_desc': obj_desc}
|
||
|
if 'count' in template['label']:
|
||
|
graph_item['count'] = answer
|
||
|
graph_item = clean_graph_item(graph_item)
|
||
|
|
||
|
ref_obj['attribute'] = attr
|
||
|
return [{'answer': answer, 'group_id': ques_round + 1,
|
||
|
'required': [], 'optional': [],
|
||
|
'objects': [ref_obj, obj_desc], 'graph': graph_item}]
|
||
|
|
||
|
elif ('count-obj-rel' in template['label'] or
|
||
|
'exist-obj-rel' in template['label']):
|
||
|
# placeholder for object description, see below
|
||
|
obj_desc = None
|
||
|
prev_history = graph['history'][-1]
|
||
|
|
||
|
# we need a single object in the previous round
|
||
|
if 'imm2' in template['label']:
|
||
|
# we need a obj-rel-imm in previous label, same as the current one
|
||
|
prev_label = prev_history['template']
|
||
|
cur_label = template['label']
|
||
|
if 'obj-rel-imm' not in prev_label or cur_label[:5] != prev_label[:5]:
|
||
|
return []
|
||
|
else:
|
||
|
focus_id = prev_history['focus_id']
|
||
|
|
||
|
elif 'imm' in template['label']:
|
||
|
# we need an immediate group in the previous round
|
||
|
if apply_immediate(prev_history):
|
||
|
focus_id = prev_history['objects'][0]['id']
|
||
|
else:
|
||
|
return []
|
||
|
|
||
|
elif 'early' in template['label']:
|
||
|
# search through history for an object with unique attribute
|
||
|
attr_counts = get_known_attribute_counts(graph)
|
||
|
|
||
|
# get attributes with just one count
|
||
|
single_count = [ii for ii, count in attr_counts.items() if count == 1]
|
||
|
# remove attributes that point to objects in the previous round
|
||
|
# TODO: re-think this again
|
||
|
obj_ids = get_unique_attribute_objects(graph, single_count)
|
||
|
prev_history_obj_ids = [ii['id'] for ii in prev_history['objects']]
|
||
|
single_count = [ii for ii in single_count if
|
||
|
obj_ids[ii] not in prev_history_obj_ids]
|
||
|
|
||
|
if len(single_count) == 0:
|
||
|
return []
|
||
|
focus_attr = random.choice(single_count)
|
||
|
for focus_id, obj in graph['objects'].items():
|
||
|
if obj.get(focus_attr[0], None) == focus_attr[1]:
|
||
|
break
|
||
|
|
||
|
# unique object description
|
||
|
obj_desc = {'required': [focus_attr[0]], 'optional': [],
|
||
|
focus_attr[0]: focus_attr[1]}
|
||
|
|
||
|
# get relations with unknown counts
|
||
|
unknown_rels = [rel for rel in gvars.METAINFO['relations']
|
||
|
if '%s_count' % rel not in graph['objects'][focus_id]]
|
||
|
# for count: only if existence if True, else count it trivially zero
|
||
|
if 'count' in template['label']:
|
||
|
for ii in unknown_rels[::-1]:
|
||
|
if not graph['objects'][focus_id].get('%s_exist' % ii, True):
|
||
|
unknown_rels.remove(ii)
|
||
|
|
||
|
# for exist: get relations without exist before
|
||
|
elif 'exist' in template['label']:
|
||
|
unknown_rels = [rel for rel in unknown_rels
|
||
|
if '%s_exist' % rel not in graph['objects'][focus_id]]
|
||
|
|
||
|
# select an object with some known objects
|
||
|
if len(unknown_rels) == 0:
|
||
|
return []
|
||
|
|
||
|
# pick between yes/no for exist questions, 50% of times
|
||
|
if 'exist' in template['label']:
|
||
|
zero_count = [ii for ii in unknown_rels
|
||
|
if len(scene['relationships'][ii][focus_id]) == 0]
|
||
|
nonzero_count = [ii for ii in unknown_rels
|
||
|
if len(scene['relationships'][ii][focus_id]) > 0]
|
||
|
|
||
|
if random.random() > 0.5:
|
||
|
if len(zero_count) > 0:
|
||
|
rel = random.choice(zero_count)
|
||
|
else:
|
||
|
return []
|
||
|
else:
|
||
|
if len(nonzero_count) > 0:
|
||
|
rel = random.choice(nonzero_count)
|
||
|
else:
|
||
|
return []
|
||
|
else:
|
||
|
rel = random.choice(unknown_rels)
|
||
|
|
||
|
# create the object group
|
||
|
obj_group = []
|
||
|
new_obj = {'required': ['relation'], 'optional': []}
|
||
|
obj_pool = scene['relationships'][rel][focus_id]
|
||
|
for obj_id in obj_pool:
|
||
|
obj_copy = copy.deepcopy(new_obj)
|
||
|
obj_copy['id'] = obj_id
|
||
|
obj_group.append(obj_copy)
|
||
|
answer = len(obj_pool)
|
||
|
|
||
|
ref_obj = copy.deepcopy(new_obj)
|
||
|
ref_obj['id'] = focus_id
|
||
|
ref_obj['volatile'] = True
|
||
|
if 'exist' in template['label']:
|
||
|
answer = 'yes' if answer > 0 else 'no'
|
||
|
ref_obj['%s_exist' % rel] = answer
|
||
|
elif 'count' in template['label']:
|
||
|
ref_obj['%s_count' % rel] = answer
|
||
|
obj_group.append(ref_obj)
|
||
|
|
||
|
graph_item = {'round': ques_round+1, 'objects': copy.deepcopy(obj_group),
|
||
|
'template': template['label'], 'mergeable': True,
|
||
|
'focus_id': focus_id, 'focus_desc': obj_desc}
|
||
|
if 'count' in template['label']:
|
||
|
graph_item['count'] = answer
|
||
|
graph_item = clean_graph_item(graph_item)
|
||
|
|
||
|
#ref_obj['relation'] = rel
|
||
|
# add attribute as argument
|
||
|
return [{'answer': answer, 'group_id': ques_round + 1,
|
||
|
'required': [], 'optional': [], 'relation': rel,
|
||
|
'objects': [ref_obj, obj_desc], 'graph': graph_item}]
|
||
|
|
||
|
elif ('count-attribute' in template['label'] or
|
||
|
'exist-attribute' in template['label']):
|
||
|
if 'group' in template['label']:
|
||
|
# we need an immediate group in the previous round
|
||
|
prev_history = graph['history'][-1]
|
||
|
prev_label = prev_history['template']
|
||
|
|
||
|
# if exist: > 0 is good, else > 1 is needed
|
||
|
min_count = 0 if 'exist' in prev_label else 1
|
||
|
if (len(prev_history['objects']) > min_count and
|
||
|
prev_history['mergeable'] and
|
||
|
'obj-relation' not in prev_label):
|
||
|
obj_pool = graph['history'][-1]['objects']
|
||
|
else:
|
||
|
return []
|
||
|
else:
|
||
|
obj_pool = scene['objects']
|
||
|
|
||
|
# get counts for attributes, and sample evenly with 0 and other numbers
|
||
|
counts = get_attribute_counts_for_objects(scene, obj_pool)
|
||
|
|
||
|
# if exist, choose between zero and others wiht 0.5 probability
|
||
|
zero_prob = 0.5 if 'exist' in template['label'] else 0.7
|
||
|
if random.random() > zero_prob:
|
||
|
pool = [ii for ii in counts if counts[ii] == 0]
|
||
|
else:
|
||
|
pool = [ii for ii in counts if counts[ii] != 0]
|
||
|
|
||
|
# check if count is already known
|
||
|
attr_pool = filter_attributes_with_known_counts(graph, pool)
|
||
|
|
||
|
# for exist: get known attributes and remove them
|
||
|
if 'exist' in template['label']:
|
||
|
known_attr = get_known_attributes(graph)
|
||
|
attr_pool = [ii for ii in attr_pool if ii not in known_attr]
|
||
|
|
||
|
# if non-empty, sample it
|
||
|
if len(attr_pool) == 0:
|
||
|
return []
|
||
|
|
||
|
attr, value = random.choice(attr_pool)
|
||
|
# add a hypothesi, and return the answer
|
||
|
count = 0
|
||
|
obj_group = []
|
||
|
new_obj = {attr: value, 'required': [attr], 'optional': []}
|
||
|
for index, obj in enumerate(obj_pool):
|
||
|
if scene['objects'][obj['id']][attr] == value:
|
||
|
obj_copy = copy.deepcopy(new_obj)
|
||
|
obj_copy['id'] = obj['id']
|
||
|
obj_group.append(obj_copy)
|
||
|
count += 1
|
||
|
|
||
|
graph_item = {'round': ques_round + 1, 'objects': copy.deepcopy(obj_group),
|
||
|
'template': template['label'], 'mergeable': True, attr: value}
|
||
|
|
||
|
if 'count' in template['label']:
|
||
|
graph_item['count'] = count
|
||
|
answer = count
|
||
|
elif 'exist' in template['label']:
|
||
|
answer = 'yes' if count > 0 else 'no'
|
||
|
# Clean graph item.
|
||
|
graph_item = clean_graph_item(graph_item)
|
||
|
if count == 0:
|
||
|
# Fake object group, to serve for arguments.
|
||
|
obj_group = [{attr: value, 'required': [attr], 'optional': []}]
|
||
|
|
||
|
return [{'answer': answer, 'group_id': ques_round + 1,
|
||
|
'required': [attr], 'optional': [],
|
||
|
'count': 9999, 'objects': obj_group, 'graph': graph_item}]
|
||
|
|
||
|
elif 'seek-attr-rel' in template['label']:
|
||
|
# Placeholder for object description, see below.
|
||
|
obj_desc = None
|
||
|
prev_history = graph['history'][-1]
|
||
|
|
||
|
if 'imm' in template['label']:
|
||
|
# we need an immediate group in the previous round
|
||
|
if apply_immediate(prev_history):
|
||
|
focus_id = prev_history['objects'][0]['id']
|
||
|
else:
|
||
|
return []
|
||
|
|
||
|
elif 'early' in template['label']:
|
||
|
# search through history for an object with unique attribute
|
||
|
attr_counts = get_known_attribute_counts(graph)
|
||
|
|
||
|
# get attributes with just one count
|
||
|
single_count = [ii for ii, count in attr_counts.items() if count == 1]
|
||
|
# remove attributes that point to objects in the previous round
|
||
|
# TODO: re-think this again
|
||
|
obj_ids = get_unique_attribute_objects(graph, single_count)
|
||
|
prev_history_obj_ids = [ii['id'] for ii in prev_history['objects']]
|
||
|
single_count = [ii for ii in single_count if
|
||
|
obj_ids[ii] not in prev_history_obj_ids]
|
||
|
if len(single_count) == 0:
|
||
|
return []
|
||
|
|
||
|
# give preference to attributes with multiple counts in scene graph
|
||
|
scene_counts = get_attribute_counts_for_objects(scene)
|
||
|
ambiguous_attrs = [ii for ii in single_count if scene_counts[ii] > 1]
|
||
|
if len(ambiguous_attrs) > 0:
|
||
|
focus_attr = random.choice(ambiguous_attrs)
|
||
|
else:
|
||
|
focus_attr = random.choice(single_count)
|
||
|
focus_id = obj_ids[focus_attr]
|
||
|
|
||
|
# unique object description
|
||
|
obj_desc = {'required': [focus_attr[0]], 'optional': [],
|
||
|
focus_attr[0]: focus_attr[1]}
|
||
|
|
||
|
# for each relation, get the object, sample an attribute, and sample
|
||
|
hypotheses = []
|
||
|
for rel in gvars.METAINFO['relations']:
|
||
|
gt_relations = scene['relationships'][rel]
|
||
|
objs = [(ii, len(gt_relations[ii])) for ii in gt_relations[focus_id]]
|
||
|
objs = sorted(objs, key=lambda x: x[1], reverse=True)
|
||
|
if len(objs) == 0:
|
||
|
# add a null hypotheses
|
||
|
# check if the object is known to be extreme
|
||
|
if ('%s_count' % rel not in graph['objects'][focus_id] and
|
||
|
'%s_exist' % rel not in graph['objects'][focus_id]):
|
||
|
random_attr = random.choice(gvars.METAINFO['attributes'])
|
||
|
hypotheses.append((None, rel, random_attr))
|
||
|
continue
|
||
|
|
||
|
closest_obj = objs[0][0]
|
||
|
# check what attributes are known/unknown
|
||
|
known_info = graph['objects'].get(closest_obj, {})
|
||
|
for attr in gvars.METAINFO['attributes']:
|
||
|
if attr not in known_info:
|
||
|
hypotheses.append((closest_obj, rel, attr))
|
||
|
|
||
|
if len(hypotheses) == 0:
|
||
|
return []
|
||
|
sample_id, rel, attr = random.choice(hypotheses)
|
||
|
# add the new attribute to object
|
||
|
new_obj = {'required': ['attribute', 'relation'],
|
||
|
'optional': [], 'id': sample_id}
|
||
|
|
||
|
if sample_id is not None:
|
||
|
answer = scene['objects'][sample_id][attr]
|
||
|
else:
|
||
|
answer = 'none'
|
||
|
new_obj[attr] = answer
|
||
|
|
||
|
graph_item = {'round': ques_round+1, 'objects': [copy.deepcopy(new_obj)],
|
||
|
'template': template['label'], 'mergeable': True,
|
||
|
'focus_id': focus_id, 'focus_desc': obj_desc}
|
||
|
# remove objects if none
|
||
|
if sample_id is None:
|
||
|
graph_item['objects'] = []
|
||
|
graph_item = clean_graph_item(graph_item)
|
||
|
|
||
|
# Add attribute as argument.
|
||
|
new_obj['attribute'] = attr
|
||
|
return [{'answer': new_obj[attr], 'group_id': ques_round + 1,
|
||
|
'required': [], 'optional': [], 'relation': rel,
|
||
|
'objects': [new_obj, obj_desc], 'graph': graph_item}]
|
||
|
|
||
|
elif 'seek-attr' in template['label']:
|
||
|
# placeholder for object description, see below
|
||
|
obj_desc = None
|
||
|
prev_history = graph['history'][-1]
|
||
|
prev_label = prev_history['template']
|
||
|
implicit_attr = None
|
||
|
|
||
|
# we need a single object in the previous round
|
||
|
if 'imm2' in template['label']:
|
||
|
# we need a seek-attr-imm/seek-attr-rel-imm in previous label
|
||
|
if ('seek-attr-imm' not in prev_label and
|
||
|
'seek-attr-rel-imm' not in prev_label):
|
||
|
return []
|
||
|
elif len(prev_history['objects']) == 0:
|
||
|
return []
|
||
|
else:
|
||
|
focus_id = prev_history['objects'][0]['id']
|
||
|
|
||
|
elif 'imm' in template['label']:
|
||
|
# we need an immediate group in the previous round
|
||
|
if apply_immediate(prev_history):
|
||
|
focus_id = prev_history['objects'][0]['id']
|
||
|
else:
|
||
|
return []
|
||
|
|
||
|
elif 'sim' in template['label']:
|
||
|
if 'seek-attr-imm' not in prev_label:
|
||
|
return[]
|
||
|
else:
|
||
|
prev_obj = prev_history['objects'][0]
|
||
|
focus_id = prev_obj['id']
|
||
|
attr = [ii for ii in gvars.METAINFO['attributes'] if ii in prev_obj]
|
||
|
assert len(attr) == 1, 'Something wrong in previous history!'
|
||
|
implicit_attr = attr[0]
|
||
|
|
||
|
if 'early' in template['label']:
|
||
|
# search through history for an object with unique attribute
|
||
|
attr_counts = get_known_attribute_counts(graph)
|
||
|
|
||
|
# get attributes with just one count
|
||
|
single_count = [ii for ii, count in attr_counts.items() if count == 1]
|
||
|
# remove attributes that point to objects in the previous round
|
||
|
# TODO: re-think this again
|
||
|
obj_ids = get_unique_attribute_objects(graph, single_count)
|
||
|
prev_history_obj_ids = [ii['id'] for ii in prev_history['objects']]
|
||
|
single_count = [ii for ii in single_count if
|
||
|
obj_ids[ii] not in prev_history_obj_ids]
|
||
|
|
||
|
# if there is an attribute, eliminate those options
|
||
|
if implicit_attr is not None:
|
||
|
single_count = [ii for ii in single_count if ii[0] != implicit_attr]
|
||
|
obj_ids = get_unique_attribute_objects(graph, single_count)
|
||
|
|
||
|
# again rule out objects whose implicit_attr is known
|
||
|
single_count = [ii for ii in single_count
|
||
|
if implicit_attr not in graph['objects'][obj_ids[ii]]]
|
||
|
|
||
|
if len(single_count) == 0:
|
||
|
return []
|
||
|
|
||
|
# give preference to attributes with multiple counts in scene graph
|
||
|
scene_counts = get_attribute_counts_for_objects(scene)
|
||
|
ambiguous_attrs = [ii for ii in single_count if scene_counts[ii] > 1]
|
||
|
if len(ambiguous_attrs) > 0:
|
||
|
focus_attr = random.choice(ambiguous_attrs)
|
||
|
else:
|
||
|
focus_attr = random.choice(single_count)
|
||
|
focus_id = get_unique_attribute_objects(graph, [focus_attr])[focus_attr]
|
||
|
|
||
|
# unique object description
|
||
|
obj_desc = {'required': [focus_attr[0]], 'optional': [],
|
||
|
focus_attr[0]: focus_attr[1]}
|
||
|
|
||
|
# get unknown attributes, randomly sample one
|
||
|
if implicit_attr is None:
|
||
|
unknown_attrs = [attr for attr in gvars.METAINFO['attributes']
|
||
|
if attr not in graph['objects'][focus_id]]
|
||
|
|
||
|
# TODO: select an object with some known objects
|
||
|
if len(unknown_attrs) == 0:
|
||
|
return []
|
||
|
attr = random.choice(unknown_attrs)
|
||
|
else:
|
||
|
attr = implicit_attr
|
||
|
|
||
|
# add the new attribute to object
|
||
|
new_obj = {'required': ['attribute'], 'optional': [], 'id': focus_id}
|
||
|
if 'sim' in template['label']:
|
||
|
new_obj['required'] = []
|
||
|
new_obj[attr] = scene['objects'][focus_id][attr]
|
||
|
|
||
|
graph_item = {'round': ques_round+1, 'objects': [copy.deepcopy(new_obj)],
|
||
|
'template': template['label'], 'mergeable': True,
|
||
|
'focus_id': focus_id, 'focus_desc': obj_desc}
|
||
|
graph_item = clean_graph_item(graph_item)
|
||
|
|
||
|
# add attribute as argument
|
||
|
new_obj['attribute'] = attr
|
||
|
return [{'answer': new_obj[attr], 'group_id': ques_round + 1,
|
||
|
'required': [], 'optional': [],
|
||
|
'objects': [new_obj, obj_desc], 'graph': graph_item}]
|
||
|
return []
|
||
|
|
||
|
|
||
|
def sample_from_hypotheses(caption_hypotheses, scene, cap_templates):
|
||
|
"""Samples from caption hypotheses given the scene and caption templates.
|
||
|
Args:
|
||
|
caption_hypotheses: List of hypotheses for objects/object pairs
|
||
|
scene: CLEVR image scene graph
|
||
|
cap_templates: List of caption templates to sample captions
|
||
|
Returns:
|
||
|
obj_groups: List of object groups and corresponding sampled captions
|
||
|
"""
|
||
|
|
||
|
obj_groups = []
|
||
|
|
||
|
# Caption Type 1: Extreme location.
|
||
|
hypotheses = caption_hypotheses['extreme-loc']
|
||
|
if len(hypotheses) > 0:
|
||
|
# extreme location hypotheses
|
||
|
extreme_type, focus_obj = random.choice(hypotheses)
|
||
|
# sample optional attributes
|
||
|
obj_attrs = [attr for attr in gvars.METAINFO['attributes']
|
||
|
if attr in focus_obj]
|
||
|
focus_attr = random.choice(obj_attrs)
|
||
|
optional_attrs = [ii for ii in obj_attrs if ii != focus_attr]
|
||
|
sampled_attrs = sample_optional_tags(optional_attrs,
|
||
|
gvars.METAINFO['probabilities'])
|
||
|
|
||
|
# add additional attributes
|
||
|
req_attrs = sampled_attrs + [focus_attr]
|
||
|
filter_obj = {attr: val for attr, val in focus_obj.items()
|
||
|
if attr in req_attrs}
|
||
|
filter_obj['required'] = req_attrs
|
||
|
filter_obj['optional'] = req_attrs
|
||
|
filter_obj['id'] = focus_obj['id']
|
||
|
obj_group = {'required': req_attrs, 'optional': [], 'group_id': 0,
|
||
|
'objects': [filter_obj]}
|
||
|
|
||
|
# also create a clean graph object
|
||
|
graph_item = copy.deepcopy(obj_group)
|
||
|
graph_item = clean_graph_item(graph_item)
|
||
|
graph_item['mergeable'] = True
|
||
|
graph_item['objects'][0]['%s_count' % extreme_type] = 0
|
||
|
graph_item['objects'][0]['%s_exist' % extreme_type] = False
|
||
|
graph_item['template'] = 'extreme-%s' % extreme_type
|
||
|
obj_group['graph'] = graph_item
|
||
|
obj_groups.append([obj_group])
|
||
|
|
||
|
# Caption Type 2: Unique object.
|
||
|
hypotheses = caption_hypotheses['unique-obj']
|
||
|
if len(hypotheses) > 0:
|
||
|
# sample one at random, and create the graph item
|
||
|
focus_obj, focus_attr = random.choice(hypotheses)
|
||
|
# sample optional attributes
|
||
|
optional_attrs = [ii for ii in gvars.METAINFO['attributes']
|
||
|
if ii != focus_attr]
|
||
|
sampled_attrs = sample_optional_tags(optional_attrs,
|
||
|
gvars.METAINFO['probabilities'])
|
||
|
|
||
|
# add additional attributes
|
||
|
req_attrs = sampled_attrs + [focus_attr]
|
||
|
filter_obj = {attr: val for attr, val in focus_obj.items()
|
||
|
if attr in req_attrs}
|
||
|
filter_obj['required'] = req_attrs
|
||
|
filter_obj['optional'] = req_attrs
|
||
|
filter_obj['id'] = focus_obj['id']
|
||
|
obj_group = {'required': req_attrs, 'optional': [], 'group_id': 0,
|
||
|
'objects': [filter_obj]}
|
||
|
|
||
|
# also create a clean graph object
|
||
|
graph_item = copy.deepcopy(obj_group)
|
||
|
graph_item = clean_graph_item(graph_item)
|
||
|
graph_item['mergeable'] = True
|
||
|
graph_item['objects'][0]['unique'] = True
|
||
|
graph_item['template'] = 'unique-obj'
|
||
|
obj_group['graph'] = graph_item
|
||
|
obj_groups.append([obj_group])
|
||
|
|
||
|
# Caption Type 3: Unique attribute count based caption.
|
||
|
hypotheses = caption_hypotheses['count-attr']
|
||
|
if len(hypotheses) > 0:
|
||
|
# Randomly sample one hypothesis and one template.
|
||
|
(attr, value), count = random.choice(hypotheses)
|
||
|
# Segregate counting templates.
|
||
|
count_templates = [ii for ii in cap_templates if 'count' in ii['type']]
|
||
|
template = random.choice(count_templates)
|
||
|
obj_group = {'group_id': 0, 'count': count, attr: value,
|
||
|
'optional': [], 'required': [], 'objects': []}
|
||
|
|
||
|
# get a list of objects which are part of this collection
|
||
|
for ii, obj in enumerate(scene['objects']):
|
||
|
if obj[attr] == value:
|
||
|
new_obj = {'id': obj['id'], attr: value}
|
||
|
new_obj['required'] = [attr]
|
||
|
new_obj['optional'] = []
|
||
|
obj_group['objects'].append(new_obj)
|
||
|
|
||
|
if 'no' in template['label']:
|
||
|
# Count is not mentioned.
|
||
|
del obj_group['count']
|
||
|
graph_item = copy.deepcopy(obj_group)
|
||
|
graph_item['mergeable'] = False
|
||
|
else:
|
||
|
# Count is mentioned.
|
||
|
for index, ii in enumerate(obj_group['objects']):
|
||
|
obj_group['objects'][index]['required'].append('count')
|
||
|
graph_item = copy.deepcopy(obj_group)
|
||
|
graph_item['mergeable'] = True
|
||
|
|
||
|
# clean up graph item
|
||
|
graph_item['template'] = template['label']
|
||
|
graph_item = clean_graph_item(graph_item)
|
||
|
obj_group['graph'] = graph_item
|
||
|
obj_group['use_plural'] = True
|
||
|
obj_groups.append([obj_group])
|
||
|
|
||
|
# Caption Type 4: Relation between two objects (one of them is unique).
|
||
|
hypotheses = caption_hypotheses['obj-relation']
|
||
|
if len(hypotheses) > 0:
|
||
|
(obj_id1, attr1), rel, (obj_id2, attr2) = random.choice(hypotheses)
|
||
|
obj_group = {'group_id': 0, 'relation': rel}
|
||
|
|
||
|
# create object dictionaries
|
||
|
obj1 = {'optional': [], 'required': [attr1], 'id': obj_id1,
|
||
|
attr1: scene['objects'][obj_id1][attr1]}
|
||
|
obj2 = {'optional': [], 'required': [attr2], 'id': obj_id2,
|
||
|
attr2: scene['objects'][obj_id2][attr2]}
|
||
|
obj_group['objects'] = [obj2, obj1]
|
||
|
|
||
|
# also create a clean graph object
|
||
|
graph_item = copy.deepcopy(obj_group)
|
||
|
graph_item = clean_graph_item(graph_item)
|
||
|
graph_item['mergeable'] = True
|
||
|
graph_item['template'] = 'obj-relation'
|
||
|
obj_group['graph'] = graph_item
|
||
|
obj_groups.append([obj_group])
|
||
|
return obj_groups
|
||
|
|
||
|
|
||
|
def get_known_attributes(graph):
  """Fetches a list of known attributes given the scene graph.

  An (attribute, value) pair is "known" when it appears on any object in
  graph['objects'] or on any entry of graph['history'], considering only the
  attributes listed in gvars.METAINFO['attributes'].

  Args:
    graph: Scene graph to check known attributes from. Reads the dict
      graph['objects'] (values are object dicts) and the list
      graph['history'] (entries are dicts).
  Returns:
    known_attrs: List of unique (attribute, value) tuples, in order of first
      appearance (objects first, then history entries).
  """
  attributes = gvars.METAINFO['attributes']
  known_attrs = []
  # Objects already marked 'unique' are not skipped (that check is disabled),
  # so they contribute their attributes as well.
  for obj_info in graph['objects'].values():
    known_attrs.extend((attr, obj_info[attr])
                       for attr in attributes if attr in obj_info)

  # Groups from the dialog history contribute attributes whether or not
  # their count has been mentioned.
  for group in graph['history']:
    known_attrs.extend((attr, group[attr])
                       for attr in attributes if attr in group)

  # Deduplicate while keeping first-seen order. The iteration order of
  # list(set(...)) depends on string hash randomization, so it could differ
  # between interpreter runs even with seeded RNGs.
  return list(dict.fromkeys(known_attrs))
|
||
|
|
||
|
|
||
|
def get_known_attribute_counts(graph):
  """Counts, for each known (attribute, value) pair, the objects that carry it.

  The pairs come from get_known_attributes(graph). Each pair is matched
  against every object in graph['objects'].

  Args:
    graph: Scene graph to check known attributes from
  Returns:
    counts: Dict mapping (attribute, value) -> number of objects in
      graph['objects'] whose attribute equals that value
  """
  known_pairs = get_known_attributes(graph)
  tally = dict.fromkeys(known_pairs, 0)

  # Pair-major traversal: the totals are the same as object-major, because
  # every match simply adds one.
  for pair in known_pairs:
    attr, val = pair
    for obj in graph['objects'].values():
      if obj.get(attr, None) == val:
        tally[pair] += 1
  return tally
|
||
|
|
||
|
|
||
|
def filter_attributes_with_known_counts(graph, known_attrs):
  """Removes (attribute, value) pairs whose object count is already known.

  A pair's count is treated as known when some entry of graph['history']
  has a 'count' key and carries the same attribute value.

  Args:
    graph: Scene graph from the dialog generated so far
    known_attrs: List of known (attribute, value) tuples; modified in place
  Returns:
    known_attrs: The same list object, with pairs whose counts are known
      removed
  """
  # Iterate over a reversed copy so that removing from known_attrs is safe.
  for attr, val in known_attrs[::-1]:
    for group in graph['history']:
      # Skip groups whose count has not been mentioned.
      if 'count' not in group:
        continue
      # This group's count is known and it carries the pair.
      if group.get(attr, None) == val:
        known_attrs.remove((attr, val))
        # Remove once per occurrence. Another matching group would otherwise
        # call remove() again and raise ValueError.
        break
  return known_attrs
|
||
|
|
||
|
|
||
|
def clean_graph_item(graph_item):
  """Strips the 'required' / 'optional' bookkeeping tags from a graph item.

  The input is not modified: a deep copy is cleaned and returned. Tags are
  removed both from the top level and from every entry of
  graph_item['objects'].

  Args:
    graph_item: Input graph item to be cleaned; must have an 'objects' list.
  Returns:
    cleaned: Deep copy of the graph item without the tags.
  """
  tags = ('optional', 'required')
  cleaned = copy.deepcopy(graph_item)

  for tag in tags:
    cleaned.pop(tag, None)

  for obj in cleaned['objects']:
    for tag in tags:
      obj.pop(tag, None)
  return cleaned
|
||
|
|
||
|
|
||
|
def get_attribute_counts_for_objects(scene, objects=None):
  """Counts how many of the given objects take each (attribute, value) pair.

  Every pair declared in gvars.METAINFO['values'] starts at zero. Values not
  declared there are added on first occurrence.

  Args:
    scene: Scene graph for the dialog generated so far
    objects: List of object dicts with an 'id' key. Default = None counts
      all of scene['objects']. Attribute values are read from
      scene['objects'][obj['id']], so ids are assumed to index into
      scene['objects'].
  Returns:
    counts: Dict mapping (attribute, value) -> number of objects
  """
  counts = {(attr, val): 0
            for attr, vals in gvars.METAINFO['values'].items()
            for val in vals}

  pool = scene['objects'] if objects is None else objects
  for item in pool:
    for attr in gvars.METAINFO['attributes']:
      pair = (attr, scene['objects'][item['id']][attr])
      counts[pair] = counts.get(pair, 0) + 1
  return counts
|
||
|
|
||
|
|
||
|
def get_unique_attribute_objects(graph, uniq_attrs):
  """Maps each unique (attribute, value) pair to the object that carries it.

  Args:
    graph: Scene graph constructed from the dialog generated so far; reads
      the dict graph['objects'] (object id -> object dict)
    uniq_attrs: List of (attribute, value) pairs expected to be unique
  Returns:
    obj_ids: Dict mapping (attribute, value) -> id of the object in
      graph['objects'] having it; pairs matched by no object are absent
  Raises:
    AssertionError: if two objects carry the same pair.
  """
  obj_ids = {}
  for obj_id, obj in graph['objects'].items():
    for attr, val in uniq_attrs:
      if obj.get(attr, '') == val:
        key = (attr, val)
        # Raise explicitly instead of using `assert`, so that the check still
        # runs under `python -O`. AssertionError is kept for existing callers.
        if key in obj_ids:
          raise AssertionError('Attributes not unique!')
        obj_ids[key] = obj_id
  return obj_ids
|
||
|
|
||
|
|
||
|
def sample_optional_tags(optional, sample_probs):
  """Randomly picks zero or one extra tag from the optional list.

  The number of tags is drawn from {0, 1} using only the first two entries
  of sample_probs, so those two entries must sum to 1. The tag itself is
  then sampled uniformly from optional.

  Args:
    optional: List of optional tags to sample from.
    sample_probs: Probabilities of sampling 0 and 1 tags (first two entries
      are used).
  Returns:
    List of sampled tags from the optional list (empty if optional is empty).
  """
  if len(optional) == 0:
    return []

  # Draw the number of tags with numpy's RNG, then the tags with `random`.
  (how_many,) = np.random.choice([0, 1], 1, p=sample_probs[:2])
  how_many = min(how_many, len(optional))
  return random.sample(optional, how_many)
|