"""Supporting script checks constraints for caption and question generation. Author: Satwik Kottur """ import copy import json import random import numpy as np import global_vars as gvars # Some quick methods. def apply_immediate(hist): return (len(hist['objects']) == 1 and hist['mergeable'] and 'exist' not in hist['template']) def apply_group(hist): return (len(hist['objects']) >= 2 and hist['mergeable'] and 'count' not in prev_group) def caption(scene, templates): """Constraints for caption generation. Args: scene: CLEVR Scene graphs to generate captions with constraints template: List of caption templates Returns: sample_captions: Samples from caption hypotheses """ caption_hypotheses = {} # Sweep through all templates to extract 'interesting' captions. n_objs = len(scene['objects']) rels = scene['relationships'] # Caption Type 1: Extreme locations. ext_loc_templates = [ii for ii in templates if ii['type'] == 'extreme-loc'] # number of objects in the scene filter_objs = copy.deepcopy(scene['objects']) attr_counts = get_attribute_counts_for_objects(scene, filter_objs) hypotheses = [] for template in ext_loc_templates: # absolute location based constraint constraint = template['constraints'][0] extreme_type = constraint['args'][0] # check if there is an object that is at the center of the image # roughly in the middle along front-back and right-left dim if extreme_type == 'center': for ii, obj in enumerate(filter_objs): bla = [len(rels[kk][ii]) <= n_objs / 2 for kk in ['front', 'behind', 'right', 'left']] matches = np.sum([len(rels[kk][ii]) <= n_objs / 2 for kk in ['front', 'behind', 'right', 'left']]) if matches == 4: hypotheses.append((extreme_type, copy.deepcopy(obj))) else: for ii, obj in enumerate(filter_objs): if len(rels[extreme_type][ii]) == 0: hypotheses.append((extreme_type, copy.deepcopy(obj))) # sample one at random, and create the graph item # Filter hypothesis which are ambiguous otherwise. for index, (_, hypothesis) in enumerate(hypotheses): uniq_attr = [attr for attr in gvars.METAINFO['attributes'] if attr_counts[(attr, hypothesis[attr])] == 1] for attr in uniq_attr: del hypotheses[index][1][attr] hypotheses = [ii for ii in hypotheses if len(ii[1]) > 1] caption_hypotheses['extreme-loc'] = hypotheses # Caption Type 2: Unique object and attribute. filter_objs = copy.deepcopy(scene['objects']) # each hypothesis is (object, attribute) pair hypotheses = [] for ii, obj in enumerate(filter_objs): # get unique set of attributes uniq_attrs = [ii for ii in gvars.METAINFO['attributes'] if attr_counts[(ii, obj[ii])] == 1] # for each, add it to hypothesis for attr in uniq_attrs: hypotheses.append((obj, attr)) caption_hypotheses['unique-obj'] = hypotheses # Caption Type 3: Unique attribute count based caption. # count unique object based constraint # Each hypothesis is object collection. caption_hypotheses['count-attr'] = [(attr_val, count) for attr_val, count in attr_counts.items() if count > 1] # Caption Type 4: Relation between two objects. # Out of the two, one has a unique attribute. # find a pair of objects sharing a relation, unique filter_objs = copy.deepcopy(scene['objects']) n_objs = len(filter_objs) # get a dict of unique attributes for each object uniq_attr = [[] for ii in range(n_objs)] non_uniq_attr = [[] for ii in range(n_objs)] for ind, obj in enumerate(filter_objs): uniq_attr[ind] = [attr for attr in gvars.METAINFO['attributes'] if attr_counts[(attr, obj[attr])] == 1] non_uniq_attr[ind] = [attr for attr in gvars.METAINFO['attributes'] if attr_counts[(attr, obj[attr])] > 1] uniqueness = [len(ii) > 0 for ii in uniq_attr] # Hypothesis is a uniq object and non-unique obj2 sharing relation R # global ordering for uniqueness hypotheses = [] for rel, order in scene['relationships'].items(): num_rel = [(ii, len(order[ii])) for ii in range(n_objs)] num_rel = sorted(num_rel, key=lambda x: x[1], reverse=True) # take only the ids num_rel = [ii[0] for ii in num_rel] for index, obj_id in enumerate(num_rel[:-1]): next_obj_id = num_rel[index + 1] # if unique, check if the next one has non-unique attributes if uniqueness[obj_id]: if len(non_uniq_attr[next_obj_id]) > 0: obj1 = (obj_id, random.choice(uniq_attr[obj_id])) obj2 = (next_obj_id, random.choice(non_uniq_attr[next_obj_id])) hypotheses.append((obj1, rel, obj2)) # if not unique, check if the next one has unique attributes else: if len(uniq_attr[next_obj_id]) > 0: obj1 = (obj_id, random.choice(non_uniq_attr[obj_id])) obj2 = (next_obj_id, random.choice(uniq_attr[next_obj_id])) hypotheses.append((obj1, rel, obj2)) caption_hypotheses['obj-relation'] = hypotheses sample_captions = sample_from_hypotheses( caption_hypotheses, scene, templates) return sample_captions def question(scene, dialog, template): """Constraints question generation. Inputs: scene:Partial scene graphs on CLEVR images with generated captions template: List of question templates to use Output: list of object groups """ ques_round = len(dialog['graph']['history']) - 1 graph = dialog['graph'] # check for constraints and answer question if 'group' in template['label']: groups = [] # Pick a group hypothesis for ii in graph['history']: if 'count' in ii or len(ii['objects']) == 0: groups.append(ii) if template['label'] == 'count-all': # Preliminary checks: # (A) count-all cannot follow count-all, count-other for prev_history in graph['history'][1:]: if prev_history['template'] in ['count-all', 'count-other']: return [] # create object group obj_group = [] new_obj = {'required': [], 'optional': []} for obj_id, ii in enumerate(scene['objects']): obj_copy = copy.deepcopy(new_obj) obj_copy['id'] = ii['id'] obj_group.append(obj_copy) # create graph item graph_item = {'round': ques_round + 1, 'objects': copy.deepcopy(obj_group), 'template': template['label'], 'mergeable': True, 'count': len(obj_group)} # clean graph item graph_item = clean_graph_item(graph_item) # no constraints, count the number of objects in true scene return [{'answer': len(obj_group), 'group_id': ques_round + 1, 'objects': [], 'graph': graph_item}] elif (template['label'] == 'count-other' or template['label'] == 'exist-other'): # preliminary checks: # (A) exist-other cannot follow exist-other, count-all, count-other # (B) count-other cannot follow count-all, count-other for prev_history in graph['history'][1:]: if prev_history['template'] in ['count-all', 'count-other']: return [] if (prev_history['template'] == 'exist-other' and template['label'] == 'exist-other'): return [] # get a list of all objects we know known_ids = [jj['id'] for ii in graph['history'] for jj in ii['objects']] known_ids = list(set(known_ids)) n_objs = len(scene['objects']) difference = n_objs - len(known_ids) diff_ids = [ii for ii in range(n_objs) if ii not in known_ids] # create empty objects for these obj_group = [{'id': ii} for ii in diff_ids] # create graph item graph_item = {'round': ques_round + 1, 'objects': obj_group, 'template': template['label'], 'mergeable': False} if 'count' in template['label']: graph_item['count'] = difference graph_item['mergeable'] = True # merge if count is known answer = difference elif 'exist' in template['label']: # If heads (> 0.5) -- difference > 0 if random.random() > 0.5: if difference > 0: answer = 'yes' else: return [] else: if difference == 0: answer = 'no' else: return [] # no constraints, count the number of objects in true scene return [{'answer': answer, 'group_id': ques_round + 1, 'objects': [], 'graph': graph_item}] elif template['label'] == 'count-all-group': # we need a group in the previous round prev_group = graph['history'][-1] prev_label = prev_group['template'] if not (len(prev_group['objects']) > 1 and 'count' not in prev_group and 'obj-relation' not in prev_label): return [] # check if count is not given before attrs = [ii for ii in gvars.METAINFO['attributes'] if ii in prev_group] count = 0 for obj in prev_group['objects']: count += all([obj[ii] == prev_group['objects'][0][ii] for ii in attrs]) # create object group obj_group = [] new_obj = {'required': [], 'optional': []} for obj_id, ii in enumerate(scene['objects']): obj_copy = copy.deepcopy(new_obj) obj_copy['id'] = ii['id'] obj_group.append(obj_copy) # create graph item graph_item = {'round': ques_round + 1, 'objects': copy.deepcopy(obj_group), 'template': template['label'], 'mergeable': True, 'count': count} # clean graph item graph_item = clean_graph_item(graph_item) # no constraints, count the number of objects in true scene return [{'answer': count, 'group_id': ques_round + 1, 'objects': [], 'graph': graph_item}] elif ('count-obj-exclude' in template['label'] or 'exist-obj-exclude' in template['label']): # placeholder for object description, see below obj_desc = None prev_history = graph['history'][-1] scene_counts = get_attribute_counts_for_objects(scene) if 'imm' in template['label']: # we need an immediate group in the previous round if apply_immediate(prev_history): focus_id = prev_history['objects'][0]['id'] else: return [] elif 'early' in template['label']: # search through history for an object with unique attribute attr_counts = get_known_attribute_counts(graph) # get attributes with just one count single_count = [ii for ii, count in attr_counts.items() if count == 1] # remove attributes that point to objects in the previous round # TODO: re-think this again obj_ids = get_unique_attribute_objects(graph, single_count) prev_history_obj_ids = [ii['id'] for ii in prev_history['objects']] single_count = [ii for ii in single_count if obj_ids[ii] not in prev_history_obj_ids] if len(single_count) == 0: return [] # give preference to attributes with multiple counts in scene graph #scene_counts = get_attribute_counts_for_objects(scene) ambiguous_attrs = [ii for ii in single_count if scene_counts[ii] > 1] if len(ambiguous_attrs) > 0: focus_attr = random.choice(ambiguous_attrs) else: focus_attr = random.choice(single_count) focus_id = obj_ids[focus_attr] # unique object description obj_desc = {'required': [focus_attr[0]], 'optional': [], focus_attr[0]: focus_attr[1]} # get the known attributes for the current object focus_obj = graph['objects'][focus_id] known_attrs = [attr for attr in gvars.METAINFO['attributes'] if attr in focus_obj and '%s_exclude_count' % attr not in focus_obj] # for count: only if existence if True, else count it trivially zero if 'count' in template['label']: for attr in known_attrs[::-1]: if not focus_obj.get('%s_exclude_exist' % attr, True): known_attrs.remove(attr) # for exist: get relations without exist before elif 'exist' in template['label']: known_attrs = [attr for attr in known_attrs if '%s_exclude_exist' % attr not in focus_obj] # select an attribute if len(known_attrs) == 0: return[] # split this into zero and non-zero if 'exist' in template['label']: focus_attrs = [(ii, scene['objects'][focus_id][ii]) for ii in known_attrs] zero_count = [ii for ii in focus_attrs if scene_counts[ii] == 1] nonzero_count = [ii for ii in focus_attrs if scene_counts[ii] > 1] if random.random() > 0.5: if len(zero_count) > 0: attr = random.choice(zero_count)[0] else: return [] else: if len(nonzero_count) > 0: attr = random.choice(nonzero_count)[0] else: return [] else: attr = random.choice(known_attrs) # create the object group obj_group = [] new_obj = {'required': ['attribute'], 'optional': []} for obj in scene['objects']: # add if same attribute value and not focus object if obj[attr] == focus_obj[attr] and obj['id'] != focus_id: obj_copy = copy.deepcopy(new_obj) obj_copy['id'] = obj['id'] obj_copy[attr] = focus_obj[attr] obj_group.append(obj_copy) answer = len(obj_group) ref_obj = copy.deepcopy(new_obj) ref_obj['id'] = focus_id ref_obj['volatile'] = True if 'exist' in template['label']: answer = 'yes' if answer > 0 else 'no' ref_obj['%s_exclude_exist' % attr] = answer elif 'count' in template['label']: ref_obj['%s_exclude_count' % attr] = answer obj_group.append(ref_obj) graph_item = {'round': ques_round+1, 'objects': copy.deepcopy(obj_group), 'template': template['label'], 'mergeable': True, 'focus_id': focus_id, 'focus_desc': obj_desc} if 'count' in template['label']: graph_item['count'] = answer graph_item = clean_graph_item(graph_item) ref_obj['attribute'] = attr return [{'answer': answer, 'group_id': ques_round + 1, 'required': [], 'optional': [], 'objects': [ref_obj, obj_desc], 'graph': graph_item}] elif ('count-obj-rel' in template['label'] or 'exist-obj-rel' in template['label']): # placeholder for object description, see below obj_desc = None prev_history = graph['history'][-1] # we need a single object in the previous round if 'imm2' in template['label']: # we need a obj-rel-imm in previous label, same as the current one prev_label = prev_history['template'] cur_label = template['label'] if 'obj-rel-imm' not in prev_label or cur_label[:5] != prev_label[:5]: return [] else: focus_id = prev_history['focus_id'] elif 'imm' in template['label']: # we need an immediate group in the previous round if apply_immediate(prev_history): focus_id = prev_history['objects'][0]['id'] else: return [] elif 'early' in template['label']: # search through history for an object with unique attribute attr_counts = get_known_attribute_counts(graph) # get attributes with just one count single_count = [ii for ii, count in attr_counts.items() if count == 1] # remove attributes that point to objects in the previous round # TODO: re-think this again obj_ids = get_unique_attribute_objects(graph, single_count) prev_history_obj_ids = [ii['id'] for ii in prev_history['objects']] single_count = [ii for ii in single_count if obj_ids[ii] not in prev_history_obj_ids] if len(single_count) == 0: return [] focus_attr = random.choice(single_count) for focus_id, obj in graph['objects'].items(): if obj.get(focus_attr[0], None) == focus_attr[1]: break # unique object description obj_desc = {'required': [focus_attr[0]], 'optional': [], focus_attr[0]: focus_attr[1]} # get relations with unknown counts unknown_rels = [rel for rel in gvars.METAINFO['relations'] if '%s_count' % rel not in graph['objects'][focus_id]] # for count: only if existence if True, else count it trivially zero if 'count' in template['label']: for ii in unknown_rels[::-1]: if not graph['objects'][focus_id].get('%s_exist' % ii, True): unknown_rels.remove(ii) # for exist: get relations without exist before elif 'exist' in template['label']: unknown_rels = [rel for rel in unknown_rels if '%s_exist' % rel not in graph['objects'][focus_id]] # select an object with some known objects if len(unknown_rels) == 0: return [] # pick between yes/no for exist questions, 50% of times if 'exist' in template['label']: zero_count = [ii for ii in unknown_rels if len(scene['relationships'][ii][focus_id]) == 0] nonzero_count = [ii for ii in unknown_rels if len(scene['relationships'][ii][focus_id]) > 0] if random.random() > 0.5: if len(zero_count) > 0: rel = random.choice(zero_count) else: return [] else: if len(nonzero_count) > 0: rel = random.choice(nonzero_count) else: return [] else: rel = random.choice(unknown_rels) # create the object group obj_group = [] new_obj = {'required': ['relation'], 'optional': []} obj_pool = scene['relationships'][rel][focus_id] for obj_id in obj_pool: obj_copy = copy.deepcopy(new_obj) obj_copy['id'] = obj_id obj_group.append(obj_copy) answer = len(obj_pool) ref_obj = copy.deepcopy(new_obj) ref_obj['id'] = focus_id ref_obj['volatile'] = True if 'exist' in template['label']: answer = 'yes' if answer > 0 else 'no' ref_obj['%s_exist' % rel] = answer elif 'count' in template['label']: ref_obj['%s_count' % rel] = answer obj_group.append(ref_obj) graph_item = {'round': ques_round+1, 'objects': copy.deepcopy(obj_group), 'template': template['label'], 'mergeable': True, 'focus_id': focus_id, 'focus_desc': obj_desc} if 'count' in template['label']: graph_item['count'] = answer graph_item = clean_graph_item(graph_item) #ref_obj['relation'] = rel # add attribute as argument return [{'answer': answer, 'group_id': ques_round + 1, 'required': [], 'optional': [], 'relation': rel, 'objects': [ref_obj, obj_desc], 'graph': graph_item}] elif ('count-attribute' in template['label'] or 'exist-attribute' in template['label']): if 'group' in template['label']: # we need an immediate group in the previous round prev_history = graph['history'][-1] prev_label = prev_history['template'] # if exist: > 0 is good, else > 1 is needed min_count = 0 if 'exist' in prev_label else 1 if (len(prev_history['objects']) > min_count and prev_history['mergeable'] and 'obj-relation' not in prev_label): obj_pool = graph['history'][-1]['objects'] else: return [] else: obj_pool = scene['objects'] # get counts for attributes, and sample evenly with 0 and other numbers counts = get_attribute_counts_for_objects(scene, obj_pool) # if exist, choose between zero and others wiht 0.5 probability zero_prob = 0.5 if 'exist' in template['label'] else 0.7 if random.random() > zero_prob: pool = [ii for ii in counts if counts[ii] == 0] else: pool = [ii for ii in counts if counts[ii] != 0] # check if count is already known attr_pool = filter_attributes_with_known_counts(graph, pool) # for exist: get known attributes and remove them if 'exist' in template['label']: known_attr = get_known_attributes(graph) attr_pool = [ii for ii in attr_pool if ii not in known_attr] # if non-empty, sample it if len(attr_pool) == 0: return [] attr, value = random.choice(attr_pool) # add a hypothesi, and return the answer count = 0 obj_group = [] new_obj = {attr: value, 'required': [attr], 'optional': []} for index, obj in enumerate(obj_pool): if scene['objects'][obj['id']][attr] == value: obj_copy = copy.deepcopy(new_obj) obj_copy['id'] = obj['id'] obj_group.append(obj_copy) count += 1 graph_item = {'round': ques_round + 1, 'objects': copy.deepcopy(obj_group), 'template': template['label'], 'mergeable': True, attr: value} if 'count' in template['label']: graph_item['count'] = count answer = count elif 'exist' in template['label']: answer = 'yes' if count > 0 else 'no' # Clean graph item. graph_item = clean_graph_item(graph_item) if count == 0: # Fake object group, to serve for arguments. obj_group = [{attr: value, 'required': [attr], 'optional': []}] return [{'answer': answer, 'group_id': ques_round + 1, 'required': [attr], 'optional': [], 'count': 9999, 'objects': obj_group, 'graph': graph_item}] elif 'seek-attr-rel' in template['label']: # Placeholder for object description, see below. obj_desc = None prev_history = graph['history'][-1] if 'imm' in template['label']: # we need an immediate group in the previous round if apply_immediate(prev_history): focus_id = prev_history['objects'][0]['id'] else: return [] elif 'early' in template['label']: # search through history for an object with unique attribute attr_counts = get_known_attribute_counts(graph) # get attributes with just one count single_count = [ii for ii, count in attr_counts.items() if count == 1] # remove attributes that point to objects in the previous round # TODO: re-think this again obj_ids = get_unique_attribute_objects(graph, single_count) prev_history_obj_ids = [ii['id'] for ii in prev_history['objects']] single_count = [ii for ii in single_count if obj_ids[ii] not in prev_history_obj_ids] if len(single_count) == 0: return [] # give preference to attributes with multiple counts in scene graph scene_counts = get_attribute_counts_for_objects(scene) ambiguous_attrs = [ii for ii in single_count if scene_counts[ii] > 1] if len(ambiguous_attrs) > 0: focus_attr = random.choice(ambiguous_attrs) else: focus_attr = random.choice(single_count) focus_id = obj_ids[focus_attr] # unique object description obj_desc = {'required': [focus_attr[0]], 'optional': [], focus_attr[0]: focus_attr[1]} # for each relation, get the object, sample an attribute, and sample hypotheses = [] for rel in gvars.METAINFO['relations']: gt_relations = scene['relationships'][rel] objs = [(ii, len(gt_relations[ii])) for ii in gt_relations[focus_id]] objs = sorted(objs, key=lambda x: x[1], reverse=True) if len(objs) == 0: # add a null hypotheses # check if the object is known to be extreme if ('%s_count' % rel not in graph['objects'][focus_id] and '%s_exist' % rel not in graph['objects'][focus_id]): random_attr = random.choice(gvars.METAINFO['attributes']) hypotheses.append((None, rel, random_attr)) continue closest_obj = objs[0][0] # check what attributes are known/unknown known_info = graph['objects'].get(closest_obj, {}) for attr in gvars.METAINFO['attributes']: if attr not in known_info: hypotheses.append((closest_obj, rel, attr)) if len(hypotheses) == 0: return [] sample_id, rel, attr = random.choice(hypotheses) # add the new attribute to object new_obj = {'required': ['attribute', 'relation'], 'optional': [], 'id': sample_id} if sample_id is not None: answer = scene['objects'][sample_id][attr] else: answer = 'none' new_obj[attr] = answer graph_item = {'round': ques_round+1, 'objects': [copy.deepcopy(new_obj)], 'template': template['label'], 'mergeable': True, 'focus_id': focus_id, 'focus_desc': obj_desc} # remove objects if none if sample_id is None: graph_item['objects'] = [] graph_item = clean_graph_item(graph_item) # Add attribute as argument. new_obj['attribute'] = attr return [{'answer': new_obj[attr], 'group_id': ques_round + 1, 'required': [], 'optional': [], 'relation': rel, 'objects': [new_obj, obj_desc], 'graph': graph_item}] elif 'seek-attr' in template['label']: # placeholder for object description, see below obj_desc = None prev_history = graph['history'][-1] prev_label = prev_history['template'] implicit_attr = None # we need a single object in the previous round if 'imm2' in template['label']: # we need a seek-attr-imm/seek-attr-rel-imm in previous label if ('seek-attr-imm' not in prev_label and 'seek-attr-rel-imm' not in prev_label): return [] elif len(prev_history['objects']) == 0: return [] else: focus_id = prev_history['objects'][0]['id'] elif 'imm' in template['label']: # we need an immediate group in the previous round if apply_immediate(prev_history): focus_id = prev_history['objects'][0]['id'] else: return [] elif 'sim' in template['label']: if 'seek-attr-imm' not in prev_label: return[] else: prev_obj = prev_history['objects'][0] focus_id = prev_obj['id'] attr = [ii for ii in gvars.METAINFO['attributes'] if ii in prev_obj] assert len(attr) == 1, 'Something wrong in previous history!' implicit_attr = attr[0] if 'early' in template['label']: # search through history for an object with unique attribute attr_counts = get_known_attribute_counts(graph) # get attributes with just one count single_count = [ii for ii, count in attr_counts.items() if count == 1] # remove attributes that point to objects in the previous round # TODO: re-think this again obj_ids = get_unique_attribute_objects(graph, single_count) prev_history_obj_ids = [ii['id'] for ii in prev_history['objects']] single_count = [ii for ii in single_count if obj_ids[ii] not in prev_history_obj_ids] # if there is an attribute, eliminate those options if implicit_attr is not None: single_count = [ii for ii in single_count if ii[0] != implicit_attr] obj_ids = get_unique_attribute_objects(graph, single_count) # again rule out objects whose implicit_attr is known single_count = [ii for ii in single_count if implicit_attr not in graph['objects'][obj_ids[ii]]] if len(single_count) == 0: return [] # give preference to attributes with multiple counts in scene graph scene_counts = get_attribute_counts_for_objects(scene) ambiguous_attrs = [ii for ii in single_count if scene_counts[ii] > 1] if len(ambiguous_attrs) > 0: focus_attr = random.choice(ambiguous_attrs) else: focus_attr = random.choice(single_count) focus_id = get_unique_attribute_objects(graph, [focus_attr])[focus_attr] # unique object description obj_desc = {'required': [focus_attr[0]], 'optional': [], focus_attr[0]: focus_attr[1]} # get unknown attributes, randomly sample one if implicit_attr is None: unknown_attrs = [attr for attr in gvars.METAINFO['attributes'] if attr not in graph['objects'][focus_id]] # TODO: select an object with some known objects if len(unknown_attrs) == 0: return [] attr = random.choice(unknown_attrs) else: attr = implicit_attr # add the new attribute to object new_obj = {'required': ['attribute'], 'optional': [], 'id': focus_id} if 'sim' in template['label']: new_obj['required'] = [] new_obj[attr] = scene['objects'][focus_id][attr] graph_item = {'round': ques_round+1, 'objects': [copy.deepcopy(new_obj)], 'template': template['label'], 'mergeable': True, 'focus_id': focus_id, 'focus_desc': obj_desc} graph_item = clean_graph_item(graph_item) # add attribute as argument new_obj['attribute'] = attr return [{'answer': new_obj[attr], 'group_id': ques_round + 1, 'required': [], 'optional': [], 'objects': [new_obj, obj_desc], 'graph': graph_item}] return [] def sample_from_hypotheses(caption_hypotheses, scene, cap_templates): """Samples from caption hypotheses given the scene and caption templates. Args: caption_hypotheses: List of hypotheses for objects/object pairs scene: CLEVR image scene graph cap_templates: List of caption templates to sample captions Returns: obj_groups: List of object groups and corresponding sampled captions """ obj_groups = [] # Caption Type 1: Extreme location. hypotheses = caption_hypotheses['extreme-loc'] if len(hypotheses) > 0: # extreme location hypotheses extreme_type, focus_obj = random.choice(hypotheses) # sample optional attributes obj_attrs = [attr for attr in gvars.METAINFO['attributes'] if attr in focus_obj] focus_attr = random.choice(obj_attrs) optional_attrs = [ii for ii in obj_attrs if ii != focus_attr] sampled_attrs = sample_optional_tags(optional_attrs, gvars.METAINFO['probabilities']) # add additional attributes req_attrs = sampled_attrs + [focus_attr] filter_obj = {attr: val for attr, val in focus_obj.items() if attr in req_attrs} filter_obj['required'] = req_attrs filter_obj['optional'] = req_attrs filter_obj['id'] = focus_obj['id'] obj_group = {'required': req_attrs, 'optional': [], 'group_id': 0, 'objects': [filter_obj]} # also create a clean graph object graph_item = copy.deepcopy(obj_group) graph_item = clean_graph_item(graph_item) graph_item['mergeable'] = True graph_item['objects'][0]['%s_count' % extreme_type] = 0 graph_item['objects'][0]['%s_exist' % extreme_type] = False graph_item['template'] = 'extreme-%s' % extreme_type obj_group['graph'] = graph_item obj_groups.append([obj_group]) # Caption Type 2: Unique object. hypotheses = caption_hypotheses['unique-obj'] if len(hypotheses) > 0: # sample one at random, and create the graph item focus_obj, focus_attr = random.choice(hypotheses) # sample optional attributes optional_attrs = [ii for ii in gvars.METAINFO['attributes'] if ii != focus_attr] sampled_attrs = sample_optional_tags(optional_attrs, gvars.METAINFO['probabilities']) # add additional attributes req_attrs = sampled_attrs + [focus_attr] filter_obj = {attr: val for attr, val in focus_obj.items() if attr in req_attrs} filter_obj['required'] = req_attrs filter_obj['optional'] = req_attrs filter_obj['id'] = focus_obj['id'] obj_group = {'required': req_attrs, 'optional': [], 'group_id': 0, 'objects': [filter_obj]} # also create a clean graph object graph_item = copy.deepcopy(obj_group) graph_item = clean_graph_item(graph_item) graph_item['mergeable'] = True graph_item['objects'][0]['unique'] = True graph_item['template'] = 'unique-obj' obj_group['graph'] = graph_item obj_groups.append([obj_group]) # Caption Type 3: Unique attribute count based caption. hypotheses = caption_hypotheses['count-attr'] if len(hypotheses) > 0: # Randomly sample one hypothesis and one template. (attr, value), count = random.choice(hypotheses) # Segregate counting templates. count_templates = [ii for ii in cap_templates if 'count' in ii['type']] template = random.choice(count_templates) obj_group = {'group_id': 0, 'count': count, attr: value, 'optional': [], 'required': [], 'objects': []} # get a list of objects which are part of this collection for ii, obj in enumerate(scene['objects']): if obj[attr] == value: new_obj = {'id': obj['id'], attr: value} new_obj['required'] = [attr] new_obj['optional'] = [] obj_group['objects'].append(new_obj) if 'no' in template['label']: # Count is not mentioned. del obj_group['count'] graph_item = copy.deepcopy(obj_group) graph_item['mergeable'] = False else: # Count is mentioned. for index, ii in enumerate(obj_group['objects']): obj_group['objects'][index]['required'].append('count') graph_item = copy.deepcopy(obj_group) graph_item['mergeable'] = True # clean up graph item graph_item['template'] = template['label'] graph_item = clean_graph_item(graph_item) obj_group['graph'] = graph_item obj_group['use_plural'] = True obj_groups.append([obj_group]) # Caption Type 4: Relation between two objects (one of them is unique). hypotheses = caption_hypotheses['obj-relation'] if len(hypotheses) > 0: (obj_id1, attr1), rel, (obj_id2, attr2) = random.choice(hypotheses) obj_group = {'group_id': 0, 'relation': rel} # create object dictionaries obj1 = {'optional': [], 'required': [attr1], 'id': obj_id1, attr1: scene['objects'][obj_id1][attr1]} obj2 = {'optional': [], 'required': [attr2], 'id': obj_id2, attr2: scene['objects'][obj_id2][attr2]} obj_group['objects'] = [obj2, obj1] # also create a clean graph object graph_item = copy.deepcopy(obj_group) graph_item = clean_graph_item(graph_item) graph_item['mergeable'] = True graph_item['template'] = 'obj-relation' obj_group['graph'] = graph_item obj_groups.append([obj_group]) return obj_groups def get_known_attributes(graph): """Fetches a list of known attributes given the scene graph. Args: graph: Scene graph to check unique attributes from Returns: known_attrs: List of known attributes from the scene graph """ known_attrs = [] for obj_id, obj_info in graph['objects'].items(): # The attribute is unique already. # if obj_info.get('unique', False): continue for attr in gvars.METAINFO['attributes']: if attr in obj_info: known_attrs.append((attr, obj_info[attr])) # also go over the groups for ii in graph['history']: # a group of objects, with unknown count #if 'count' not in ii: continue for attr in gvars.METAINFO['attributes']: if attr in ii: known_attrs.append((attr, ii[attr])) known_attrs = list(set(known_attrs)) return known_attrs def get_known_attribute_counts(graph): """Fetches a count of known attributes given the scene graph. Calls get_known_attributes method internally. Args: graph: Scene graph to check unique attributes from Returns: counts: Count of known attributes from the scene graph """ known_attrs = get_known_attributes(graph) # Go through objects and count. counts = {ii: 0 for ii in known_attrs} for _, obj in graph['objects'].items(): for attr, val in known_attrs: if obj.get(attr, None) == val: counts[(attr, val)] += 1 return counts def filter_attributes_with_known_counts(graph, known_attrs): """Filters attributes whose counts are known, given the scene graph. Args: graph: Scene graph from the dialog generated so far known_attrs: List of known attributes from the ground truth scene graph Returns: known_attrs: List of attributes with unknown counts removed inplace """ for attr, val in known_attrs[::-1]: for ii in graph['history']: # A group of objects, with unknown count. if 'count' not in ii: continue # Count is absent. if ii.get(attr, None) == val: known_attrs.remove((attr, val)) return known_attrs def clean_graph_item(graph_item): """Cleans up graph item (remove 'required' and 'optional' tags). Args: graph_item: Input graph item to be cleaned. Returns: clean_graph_item: Copy of the graph item after cleaning. """ clean_graph_item = copy.deepcopy(graph_item) if 'optional' in clean_graph_item: del clean_graph_item['optional'] if 'required' in clean_graph_item: del clean_graph_item['required'] for index, ii in enumerate(clean_graph_item['objects']): if 'optional' in ii: del clean_graph_item['objects'][index]['optional'] if 'required' in ii: del clean_graph_item['objects'][index]['required'] return clean_graph_item def get_attribute_counts_for_objects(scene, objects=None): """Counts attributes for a given set of objects. Args: scene: Scene graph for the dialog generated so far objects: List of objects. Default = None selects all objects Returns: counts: Counts for the attributes for attributes """ # Initialize the dictionary. counts = {} for attr, vals in gvars.METAINFO['values'].items(): for val in vals: counts[(attr, val)] = 0 # Now count for each given object. if objects is None: objects = scene['objects'] for obj in objects: for attr in gvars.METAINFO['attributes']: key = (attr, scene['objects'][obj['id']][attr]) counts[key] = counts.get(key, 0) + 1 return counts def get_unique_attribute_objects(graph, uniq_attrs): """Fetches objects from given scene graph with unique attributes. Args: graph: Scene graph constructed from the dialog generated so far uniq_attrs: List of unique attributes to get attributes Returns: obj_ids: List of object ids with the unique attributes """ obj_ids = {} for obj_id, obj in graph['objects'].items(): for attr, val in uniq_attrs: if obj.get(attr, '') == val: # At this point the key should not be present. assert (attr, val) not in obj_ids, 'Attributes not unique!' obj_ids[(attr, val)] = obj_id return obj_ids def sample_optional_tags(optional, sample_probs): """Samples additional tags depending on given sample probabilities. Args: optional: List of optional tags to sample from. sample_probs: Probabilities of sampling 'n' tags. Returns: sampled: Sampled tags from the optional list """ sampled = [] if len(optional) > 0: n_sample = np.random.choice([0, 1], 1, p=sample_probs[:2])[0] n_sample = min(n_sample, len(optional)) sampled = random.sample(optional, n_sample) return sampled