mtomnet/boss/models/utils.py

96 lines
3.5 KiB
Python
Raw Normal View History

2025-01-10 15:39:20 +01:00
import torch
# Two constant lists of 27 floats each. The values of each list sum to
# approximately 1.0, so they read like per-class distributions over 27 classes.
# NOTE(review): neither list is referenced in this module; presumably they are
# priors / bias initialisers for a 27-way output (27 also matches the node count
# in build_ocr_graph) -- confirm against the call sites.
# NOTE(review): left_bias entries 11 and 12 (0-based) are identical; verify this
# is intended and not a copy-paste slip.
left_bias = [0.10659290976303822,
0.025158905348262015,
0.02811095449589107,
0.026342384511050237,
0.025318475572458178,
0.02283183957873461,
0.021581872822531316,
0.08062285577511237,
0.03824366373234754,
0.04853594319300018,
0.09653998563867983,
0.02961357410707162,
0.02961357410707162,
0.03172787957767081,
0.029985904630196004,
0.02897529321028696,
0.06602218026116327,
0.015345336560197867,
0.026900880295736816,
0.024879657455918726,
0.028669450280577644,
0.01936118720246802,
0.02341693040078721,
0.014707055663413206,
0.027007260445200926,
0.04146166325363687,
0.04243238211749688]
# Counterpart of left_bias: 27 floats, also summing to approximately 1.0.
right_bias = [0.13147256721895695,
0.012433179968617855,
0.01623627031195979,
0.013683146724821148,
0.015252253929416771,
0.012579452674131008,
0.03127576394244834,
0.10325523257360177,
0.041155820323927554,
0.06563655221935587,
0.12684503071726816,
0.016156485199861705,
0.0176989973670913,
0.020238823435546928,
0.01918831945958884,
0.01791175766601952,
0.08768383819579266,
0.019002154198026647,
0.029600276588388607,
0.01578415467673732,
0.0176989973670913,
0.011834791627882237,
0.014919815962341426,
0.007552990611951809,
0.029759846812584773,
0.04981250498656951,
0.05533097524002021]
def build_ocr_graph(device):
    """Build the fixed OCR graph and move its tensors to ``device``.

    The graph is written below as an adjacency list in which every row is
    ``[source, [target, attr], [target, attr], ...]``. Each ``[target, attr]``
    pair becomes one directed edge ``source -> target`` that carries the scalar
    edge attribute ``attr``.

    Args:
        device: Destination passed to ``Tensor.to`` (e.g. ``"cpu"``,
            ``"cuda:0"`` or a ``torch.device``).

    Returns:
        ``(ocr_x, ocr_edge_index, ocr_edge_attr)`` where

        * ``ocr_x`` is a float32 ``(27, 27)`` identity matrix, i.e. one one-hot
          feature row per node;
        * ``ocr_edge_index`` is a contiguous int64 ``(2, E)`` tensor whose first
          row holds edge sources and whose second row holds edge targets;
        * ``ocr_edge_attr`` is a float32 ``(E, 1)`` tensor whose rows line up
          with the columns of ``ocr_edge_index``.
    """
    # The largest node id in the table below is 26, so ids span 0..26.
    num_nodes = 27
    ocr_graph = [
        [15, [10, 4], [17, 2]],
        [13, [16, 7], [18, 4]],
        [11, [16, 4], [7, 10]],
        [14, [10, 11], [7, 1]],
        [12, [10, 9], [16, 3]],
        [1, [7, 2], [9, 9], [10, 2]],
        [5, [8, 8], [6, 8]],
        [4, [9, 8], [7, 6]],
        [3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
        [2, [10, 1], [7, 7], [9, 3]],
        [19, [10, 2], [26, 6]],
        [20, [10, 7], [26, 5]],
        [22, [25, 4], [10, 8]],
        [23, [25, 15]],
        [21, [16, 5], [24, 8]],
    ]

    # Flatten the adjacency list into parallel (source, target) / attr lists.
    edge_index = []
    edge_attr = []
    for source_node, *neighbours in ocr_graph:
        for target_node, attr in neighbours:
            edge_index.append([source_node, target_node])
            edge_attr.append(attr)

    # Transpose the (E, 2) pair list into (2, E) and materialise it contiguously
    # rather than handing back a strided transpose view.
    ocr_edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    ocr_edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
    # The one-hot encoding of 0..num_nodes-1 is exactly the identity matrix.
    ocr_x = torch.eye(num_nodes, dtype=torch.float)
    return ocr_x.to(device), ocr_edge_index.to(device), ocr_edge_attr.to(device)
def pose_edge_index():
    """Return the fixed skeleton connectivity as a ``(2, 48)`` int64 tensor.

    The skeleton has 24 undirected bones over joint ids 0..24. Each bone
    ``(a, b)`` contributes two consecutive columns, ``a -> b`` followed by
    ``b -> a``, so row 0 holds sources and row 1 holds targets.
    NOTE(review): the joint layout looks like OpenPose BODY_25; confirm.
    """
    bones = (
        (17, 15), (15, 0), (0, 16), (16, 18),
        (0, 1), (4, 3), (3, 2), (2, 1),
        (1, 5), (5, 6), (6, 7), (1, 8),
        (8, 9), (9, 10), (10, 11), (11, 24),
        (11, 23), (23, 22), (8, 12), (12, 13),
        (13, 14), (14, 21), (14, 19), (19, 20),
    )
    sources = []
    targets = []
    for joint_a, joint_b in bones:
        # Emit the forward edge, then its reverse, to keep the graph symmetric.
        sources += [joint_a, joint_b]
        targets += [joint_b, joint_a]
    return torch.tensor([sources, targets], dtype=torch.long)