SalChartQA/Code/tokenizer_bert.py
import torch
from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("roberta-base")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print('bert-base-uncased tokenizer loaded')
def padding_fn(data):
    """Collate function: batch samples and pad the question text with the BERT tokenizer."""
    img, q, fix, hm, name = zip(*data)
    # Tokenize all questions in the batch, padding to the longest one.
    # The result is a BatchEncoding with input_ids and attention_mask tensors.
    input_ids = tokenizer(q, return_tensors="pt", padding=True)
    return torch.stack(img), input_ids, torch.stack(fix), torch.stack(hm), name
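
# Usage sketch (assumption, not part of the original file): padding_fn is intended to be
# passed as the collate_fn of a torch DataLoader. The in-memory samples and tensor shapes
# below are hypothetical and only illustrate the expected
# (img, question, fixation, heatmap, name) tuples.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    samples = [
        (torch.zeros(3, 224, 224), "What is the highest bar?",
         torch.zeros(224, 224), torch.zeros(224, 224), "chart_000"),
        (torch.zeros(3, 224, 224), "Which category has the lowest value?",
         torch.zeros(224, 224), torch.zeros(224, 224), "chart_001"),
    ]
    loader = DataLoader(samples, batch_size=2, collate_fn=padding_fn)
    img, enc, fix, hm, name = next(iter(loader))
    print(img.shape, enc["input_ids"].shape, fix.shape, hm.shape, name)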