import torch

# 27 per-index values. Each list sums to ~1.0, so these look like class priors
# for the 27 OCR graph nodes below (same count) -- TODO confirm against callers.
left_bias = [
    0.10659290976303822, 0.025158905348262015, 0.02811095449589107,
    0.026342384511050237, 0.025318475572458178, 0.02283183957873461,
    0.021581872822531316, 0.08062285577511237, 0.03824366373234754,
    0.04853594319300018, 0.09653998563867983, 0.02961357410707162,
    0.02961357410707162, 0.03172787957767081, 0.029985904630196004,
    0.02897529321028696, 0.06602218026116327, 0.015345336560197867,
    0.026900880295736816, 0.024879657455918726, 0.028669450280577644,
    0.01936118720246802, 0.02341693040078721, 0.014707055663413206,
    0.027007260445200926, 0.04146166325363687, 0.04243238211749688,
]

# Same layout as ``left_bias`` (27 entries, summing to ~1.0).
right_bias = [
    0.13147256721895695, 0.012433179968617855, 0.01623627031195979,
    0.013683146724821148, 0.015252253929416771, 0.012579452674131008,
    0.03127576394244834, 0.10325523257360177, 0.041155820323927554,
    0.06563655221935587, 0.12684503071726816, 0.016156485199861705,
    0.0176989973670913, 0.020238823435546928, 0.01918831945958884,
    0.01791175766601952, 0.08768383819579266, 0.019002154198026647,
    0.029600276588388607, 0.01578415467673732, 0.0176989973670913,
    0.011834791627882237, 0.014919815962341426, 0.007552990611951809,
    0.029759846812584773, 0.04981250498656951, 0.05533097524002021,
]

# Number of nodes in the OCR graph; every index in _OCR_GRAPH is below this.
_NUM_OCR_NODES = 27

# Directed, weighted adjacency of the OCR graph.
# Each row is ``(source, (target, weight), (target, weight), ...)``.
# The integer weights become the edge attributes; their meaning is not visible
# here (presumably a co-occurrence / confusion count -- TODO confirm).
_OCR_GRAPH = (
    (15, (10, 4), (17, 2)),
    (13, (16, 7), (18, 4)),
    (11, (16, 4), (7, 10)),
    (14, (10, 11), (7, 1)),
    (12, (10, 9), (16, 3)),
    (1, (7, 2), (9, 9), (10, 2)),
    (5, (8, 8), (6, 8)),
    (4, (9, 8), (7, 6)),
    (3, (10, 1), (8, 3), (7, 4), (9, 2), (6, 1)),
    (2, (10, 1), (7, 7), (9, 3)),
    (19, (10, 2), (26, 6)),
    (20, (10, 7), (26, 5)),
    (22, (25, 4), (10, 8)),
    (23, (25, 15)),
    (21, (16, 5), (24, 8)),
)


def build_ocr_graph(device):
    """Build the fixed OCR graph as tensors on ``device``.

    Edges are emitted exactly as listed in ``_OCR_GRAPH``: directed
    source -> target, in row order, with no reverse edges added.

    Args:
        device: Any value accepted by ``torch.Tensor.to`` (e.g. ``"cpu"``).

    Returns:
        A tuple ``(x, edge_index, edge_attr)``:
          * ``x``: float32 ``(27, 27)`` identity matrix, i.e. a one-hot
            feature vector per node.
          * ``edge_index``: contiguous int64 ``(2, E)``; row 0 holds sources
            and row 1 holds targets.
          * ``edge_attr``: float32 ``(E, 1)`` edge weights.
    """
    pairs = []
    weights = []
    for source, *targets in _OCR_GRAPH:
        for target, weight in targets:
            pairs.append((source, target))
            weights.append(weight)

    # Build as (E, 2) and transpose. .contiguous() so the CPU result is not a
    # strided view (.to("cpu") would otherwise return the view unchanged).
    edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(weights, dtype=torch.float).unsqueeze(1)
    # Identity == one_hot(arange(N)); explicit dtype matches the old .to(torch.float).
    x = torch.eye(_NUM_OCR_NODES, dtype=torch.float)
    return x.to(device), edge_index.to(device), edge_attr.to(device)


# Undirected skeleton edges over 25 joints (indices 0..24). The layout looks
# like OpenPose BODY_25 keypoints -- NOTE(review): confirm.
_POSE_BONES = (
    (17, 15), (15, 0), (0, 16), (16, 18), (0, 1),
    (4, 3), (3, 2), (2, 1), (1, 5), (5, 6), (6, 7),
    (1, 8), (8, 9), (9, 10), (10, 11), (11, 24), (11, 23), (23, 22),
    (8, 12), (12, 13), (13, 14), (14, 21), (14, 19), (19, 20),
)


def pose_edge_index():
    """Return the pose skeleton as a symmetric int64 ``(2, 48)`` edge index.

    Each bone in ``_POSE_BONES`` contributes ``(a, b)`` immediately followed
    by ``(b, a)``, so the graph is undirected by construction.
    """
    directed = [edge for a, b in _POSE_BONES for edge in ((a, b), (b, a))]
    return torch.tensor(directed, dtype=torch.long).t().contiguous()