18 lines
519 B
Python
18 lines
519 B
Python
import copy

import numpy as np
import torch
from torch.utils.data import Dataset
|
|
|
|
class ImagesWithSaliency(Dataset):
    """Map-style dataset whose samples are the elements of a pickled ``.npy`` file.

    Every element of the loaded array is returned as one sample. When
    ``dtype`` is set, the entries at positions 0 and 3 of the sample are cast
    with their ``.type(dtype)`` method before the sample is returned; the
    stored data itself is never modified.

    NOTE(review): positions 0 and 3 presumably hold the image and saliency
    tensors (inferred from the class name) — confirm against the code that
    writes the ``.npy`` file.
    """

    # Sample positions that are cast to ``self.dtype`` in ``__getitem__``.
    _CAST_FIELDS = (0, 3)

    def __init__(self, npy_path, dtype=None):
        """Load all samples from *npy_path* into memory.

        Args:
            npy_path: Path to a ``.npy`` file holding an object array of samples.
            dtype: Optional target passed to ``.type()`` on the cast fields
                (e.g. ``torch.float32``). A falsy value disables conversion.
        """
        self.dtype = dtype
        # allow_pickle=True is needed for object arrays, but unpickling can
        # execute arbitrary code: only load .npy files from trusted sources.
        self.datas = np.load(npy_path, allow_pickle=True)

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.datas)

    def __getitem__(self, idx):
        """Return sample *idx*, casting its tensor fields when ``dtype`` is set.

        Conversion is applied to a shallow copy of the sample, so repeated
        reads (or a later change of ``self.dtype``) always start from the
        original stored values. The returned container has the same type as
        the stored one (list, ndarray row or tuple).
        """
        sample = self.datas[idx]
        if not self.dtype:
            return sample

        # Tuples are immutable, so cast into a list and rebuild the tuple;
        # anything else (list, object-ndarray row, ...) is shallow-copied.
        is_tuple = isinstance(sample, tuple)
        converted = list(sample) if is_tuple else copy.copy(sample)
        for pos in self._CAST_FIELDS:
            converted[pos] = converted[pos].type(self.dtype)

        return tuple(converted) if is_tuple else converted
|