ActionDiffusion_WACV2025/temp.py

119 lines
3.7 KiB
Python
Raw Normal View History

2024-12-02 15:42:58 +01:00
import json
import os
import random
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import numpy as np
from data_load_json import PlanningDataset
from utils import *
from utils.args import get_args
from train_mlp import ResMLP, head
def main():
    """Parse the run options, seed the RNGs and launch the evaluation worker(s).

    ``main_worker`` is either called directly in this process or, when
    ``args.multiprocessing_distributed`` is set, spawned once per CUDA device
    reported by ``torch.cuda.device_count()``.
    """
    args = get_args()

    if args.verbose:
        print(args)

    if args.seed is not None:
        # Export PYTHONHASHSEED only for a real seed. str(None) == "None" is
        # not a valid value for this variable, and interpreters started later
        # (e.g. by mp.spawn below) inherit the environment and would refuse
        # to start with it.
        os.environ['PYTHONHASHSEED'] = str(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
    ngpus_per_node = torch.cuda.device_count()

    if args.multiprocessing_distributed:
        # Scale world_size by the local GPU count before spawning one process
        # per GPU (presumably world_size counts nodes up to here -- confirm
        # against utils.args).
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)
def main_worker(gpu, ngpus_per_node, args):
    """Run the classification head over the dataset and dump its predictions.

    Sets up the (optionally distributed) CUDA context, loads the checkpoint
    ``save_max_mlp/<args.dataset>/<args.checkpoint_mlp>``, predicts a class for
    every sample of the ``PlanningDataset`` built with ``is_val=True``, writes
    the annotated samples to
    ``dataset/<args.dataset>/<args.dataset>_mlp_T<args.horizon>.json`` and
    prints the fraction of samples whose prediction equals ``task_id``.

    Args:
        gpu: CUDA device index for this process, or ``None``. Stored in
            ``args.gpu``.
        ngpus_per_node: Number of GPUs per node. Used to derive the global
            rank and to divide the batch sizes and reader-thread count.
        args: Parsed options namespace. It is mutated in place (``gpu``,
            ``rank``, ``batch_size``, ``batch_size_val``,
            ``num_thread_reader``).
    """
    args.gpu = gpu

    if args.distributed:
        if args.multiprocessing_distributed:
            # Turn the per-node rank into a global rank for this process.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            # Divide the configured totals by the number of GPUs per node.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
            args.num_thread_reader = int(args.num_thread_reader / ngpus_per_node)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    test_dataset = PlanningDataset(
        args.root,
        args=args,
        is_val=True,
        model=None,
    )

    # create model
    model = head(args.observation_dim, args.class_dim)

    if args.distributed:
        if args.gpu is not None:
            model.cuda(args.gpu)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(
                model, find_unused_parameters=True)
    elif args.gpu is not None:
        model = model.cuda(args.gpu)
    else:
        model = torch.nn.DataParallel(model).cuda()

    # NOTE(review): map_location is derived from args.rank, not args.gpu.
    # These can differ (e.g. global rank vs. local device index) -- confirm
    # this is intended.
    checkpoint_path = f"save_max_mlp/{args.dataset}/{args.checkpoint_mlp}"
    checkpoint_ = torch.load(
        checkpoint_path, map_location='cuda:{}'.format(args.rank))
    model.load_state_dict(checkpoint_["model"])

    if args.cudnn_benchmark:
        cudnn.benchmark = True

    model.eval()
    json_ret = []
    correct = 0
    num_samples = len(test_dataset)

    # Pure inference: disable autograd so no graph is built per sample.
    with torch.no_grad():
        for i in range(num_samples):
            frames_t, vid_names, frame_cnts = test_dataset[i]
            event_class = model(frames_t)
            event_class_id = torch.argmax(event_class)
            if event_class_id == vid_names['task_id']:
                correct += 1
            # Store the prediction on the sample record before exporting it.
            vid_names['event_class'] = event_class_id.item()
            json_ret.append({
                'id': vid_names,
                'instruction_len': frame_cnts,
            })

    data_name = f"dataset/{args.dataset}/{args.dataset}_mlp_T{args.horizon}.json"
    with open(data_name, 'w') as f:
        json.dump(json_ret, f)

    # An empty dataset reports 0.0 instead of raising ZeroDivisionError.
    accuracy = correct / num_samples if num_samples else 0.0
    print('acc: ', accuracy)
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()