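# Inference script for the MLP task classifier: predict a task class for every
# validation video, report classification accuracy, and write the predictions
# to a JSON file (dataset/<dataset>/<dataset>_mlp_T<horizon>.json).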
import json
import os
import random

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed

import numpy as np

from data_load_json import PlanningDataset
from utils import *
from utils.args import get_args
from train_mlp import ResMLP, head
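

# Entry point: parse arguments, seed all RNG sources, then either spawn one
# worker process per GPU (multiprocessing distributed mode) or run a single
# worker directly.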
def main():
    args = get_args()
    os.environ['PYTHONHASHSEED'] = str(args.seed)

    if args.verbose:
        print(args)
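    # Seed every RNG source so the evaluation is reproducible.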
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

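    # Decide whether to run distributed; with multiprocessing_distributed,
    # one worker process is spawned per local GPU.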
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
    ngpus_per_node = torch.cuda.device_count()

    if args.multiprocessing_distributed:
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)

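
# Per-process worker: set up (optional) distributed state, build the validation
# dataset and the classification head, load the MLP checkpoint, and run
# inference over every sample.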
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

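    # Initialise the process group and divide the per-node budgets
    # (batch sizes, dataloader workers) across the local GPUs.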
    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
            args.num_thread_reader = int(args.num_thread_reader / ngpus_per_node)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)

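    # Validation split of the planning dataset; no backbone model is attached
    # here (model=None).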
    test_dataset = PlanningDataset(
        args.root,
        args=args,
        is_val=True,
        model=None,
    )

    # create model
    model = head(args.observation_dim, args.class_dim)

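    # Wrap the model for the chosen execution mode: DistributedDataParallel per
    # GPU when distributed, a single GPU when args.gpu is set, otherwise
    # DataParallel across all visible GPUs.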
    if args.distributed:
        if args.gpu is not None:
            model.cuda(args.gpu)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
    elif args.gpu is not None:
        model = model.cuda(args.gpu)
    else:
        model = torch.nn.DataParallel(model).cuda()

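    # Restore the trained MLP weights for the selected dataset.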
    # checkpoint_ = torch.load("save_max_mlp/epoch0022_act_T3.pth.tar", map_location='cuda:{}'.format(args.rank))
    checkpoint_ = torch.load("save_max_mlp/" + args.dataset + "/" + args.checkpoint_mlp,
                             map_location='cuda:{}'.format(args.rank))
    model.load_state_dict(checkpoint_["model"])

    if args.cudnn_benchmark:
        cudnn.benchmark = True

    model.eval()

    json_ret = []
    correct = 0

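    # Run inference under torch.no_grad(): classify each validation video,
    # count correct task predictions, and keep the predicted class id alongside
    # the original annotation.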
    with torch.no_grad():
        for i in range(len(test_dataset)):
            frames_t, vid_names, frame_cnts = test_dataset[i]
            event_class = model(frames_t)
            event_class_id = torch.argmax(event_class)
            if event_class_id == vid_names['task_id']:
                correct += 1
            vid_names['event_class'] = event_class_id.item()
            json_current = {}
            json_current['id'] = vid_names
            json_current['instruction_len'] = frame_cnts

            json_ret.append(json_current)

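    # Dump the per-video predictions next to the dataset and report accuracy.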
    data_name = "dataset/" + args.dataset + "/" + args.dataset + "_mlp_T" + str(args.horizon) + ".json"

    with open(data_name, 'w') as f:
        json.dump(json_ret, f)

    print('acc: ', correct / len(test_dataset))

if __name__ == "__main__":
    main()