95 lines
3.5 KiB
Python
95 lines
3.5 KiB
Python
import torch
|
|
|
|
# Per-index weights for the "left" side (their exact meaning is not visible
# in this module). There are 27 entries, the same count as the OCR graph
# nodes built by build_ocr_graph, so presumably one weight per node --
# confirm with callers. The entries sum to roughly 1.0.
left_bias = [
    0.10659290976303822, 0.025158905348262015, 0.02811095449589107,
    0.026342384511050237, 0.025318475572458178, 0.02283183957873461,
    0.021581872822531316, 0.08062285577511237, 0.03824366373234754,
    0.04853594319300018, 0.09653998563867983, 0.02961357410707162,
    0.02961357410707162, 0.03172787957767081, 0.029985904630196004,
    0.02897529321028696, 0.06602218026116327, 0.015345336560197867,
    0.026900880295736816, 0.024879657455918726, 0.028669450280577644,
    0.01936118720246802, 0.02341693040078721, 0.014707055663413206,
    0.027007260445200926, 0.04146166325363687, 0.04243238211749688,
]
|
|
|
|
# Per-index weights for the "right" side (their exact meaning is not visible
# in this module). There are 27 entries, the same count as the OCR graph
# nodes built by build_ocr_graph, so presumably one weight per node --
# confirm with callers. The entries sum to roughly 1.0.
right_bias = [
    0.13147256721895695, 0.012433179968617855, 0.01623627031195979,
    0.013683146724821148, 0.015252253929416771, 0.012579452674131008,
    0.03127576394244834, 0.10325523257360177, 0.041155820323927554,
    0.06563655221935587, 0.12684503071726816, 0.016156485199861705,
    0.0176989973670913, 0.020238823435546928, 0.01918831945958884,
    0.01791175766601952, 0.08768383819579266, 0.019002154198026647,
    0.029600276588388607, 0.01578415467673732, 0.0176989973670913,
    0.011834791627882237, 0.014919815962341426, 0.007552990611951809,
    0.029759846812584773, 0.04981250498656951, 0.05533097524002021,
]
|
|
|
|
def build_ocr_graph(device):
    """Build the fixed OCR graph tensors and move them to ``device``.

    The graph is described by a hard-coded adjacency list whose rows have the
    form ``[source, [target, weight], [target, weight], ...]``. Each
    ``[target, weight]`` pair yields one directed edge ``source -> target``
    (no reverse edge is added), emitted in row order. Node ids run from 0 to
    26; node 0 has no incident edges in this list.

    Args:
        device: Target device, passed straight to ``torch.Tensor.to``
            (e.g. ``"cpu"``, ``"cuda"`` or a ``torch.device``).

    Returns:
        Tuple ``(ocr_x, ocr_edge_index, ocr_edge_attr)``, all on ``device``:

        * ``ocr_x`` -- float32 one-hot node features, shape ``(27, 27)``
          (row ``i`` encodes node ``i``).
        * ``ocr_edge_index`` -- contiguous int64 tensor, shape
          ``(2, num_edges)``; row 0 holds sources, row 1 targets.
        * ``ocr_edge_attr`` -- float32 edge weights, shape ``(num_edges, 1)``.
    """
    num_nodes = 27

    ocr_graph = [
        [15, [10, 4], [17, 2]],
        [13, [16, 7], [18, 4]],
        [11, [16, 4], [7, 10]],
        [14, [10, 11], [7, 1]],
        [12, [10, 9], [16, 3]],
        [1, [7, 2], [9, 9], [10, 2]],
        [5, [8, 8], [6, 8]],
        [4, [9, 8], [7, 6]],
        [3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
        [2, [10, 1], [7, 7], [9, 3]],
        [19, [10, 2], [26, 6]],
        [20, [10, 7], [26, 5]],
        [22, [25, 4], [10, 8]],
        [23, [25, 15]],
        [21, [16, 5], [24, 8]],
    ]

    edge_index = []
    edge_attr = []
    # Flatten the adjacency list: one (source, target) pair and one weight
    # per neighbour entry, preserving the listed edge order.
    for source_node, *neighbours in ocr_graph:
        for target_node, weight in neighbours:
            edge_index.append([source_node, target_node])
            edge_attr.append(weight)

    # (num_edges, 2) -> (2, num_edges). ``.t()`` only returns a strided view,
    # so materialise it with ``.contiguous()``; graph libraries such as
    # PyTorch Geometric recommend a contiguous edge_index.
    ocr_edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    ocr_edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)

    node_ids = torch.arange(num_nodes)
    ocr_x = torch.nn.functional.one_hot(node_ids, num_classes=num_nodes)
    ocr_x = ocr_x.to(torch.float)

    return ocr_x.to(device), ocr_edge_index.to(device), ocr_edge_attr.to(device)
|
|
|
|
def pose_edge_index():
    """Return the fixed, bidirectional skeleton edge index for pose graphs.

    Each undirected bone ``(a, b)`` listed below is emitted as two directed
    edges, ``a -> b`` followed by ``b -> a``, in the listed order. The result
    is an int64 tensor of shape ``(2, 48)`` (24 bones x 2 directions). Row 0
    holds the sources and row 1 the targets, and it references node ids 0-24.

    NOTE(review): the 25-node layout looks like the OpenPose BODY_25 keypoint
    scheme -- confirm before relying on that interpretation.
    """
    bones = (
        (17, 15), (15, 0), (0, 16), (16, 18),
        (0, 1),
        (4, 3), (3, 2), (2, 1),
        (1, 5), (5, 6), (6, 7),
        (1, 8),
        (8, 9), (9, 10), (10, 11), (11, 24), (11, 23), (23, 22),
        (8, 12), (12, 13), (13, 14), (14, 21), (14, 19), (19, 20),
    )
    sources = []
    targets = []
    for a, b in bones:
        # Forward edge first, then its reverse, so the graph is symmetric.
        sources.extend((a, b))
        targets.extend((b, a))
    return torch.tensor([sources, targets], dtype=torch.long)
|