"""Dataset loading helpers and a minimal in-memory ``Dataset`` wrapper."""
from torch.utils.data import Dataset
import torch
import pandas as pd


def loadDataset(conf):
    """Load the pretraining dataset named by ``conf.pretrainDataset``.

    The loader is resolved by name. For ``conf.pretrainDataset == 'Foo'``,
    the module-level callable ``loadFoo(conf)`` is called. It must return a
    ``(taskdata, tasklabels)`` pair. To load your own dataset, plug in a
    function that follows this naming scheme.

    Args:
        conf: Configuration object. It must expose ``pretrainDataset`` and
            is passed through unchanged to the loader.

    Returns:
        ``(taskdata, tasklabels)``:
            - ``taskdata`` is whatever the loader returned; it must have a
              ``.shape``.
            - ``tasklabels`` is the loader's labels wrapped in a
              ``pandas.DataFrame`` with a single ``'user'`` column.

    Raises:
        NameError: If no callable named ``load<pretrainDataset>`` exists in
            this module's namespace.
    """
    loader_name = 'load%s' % (conf.pretrainDataset)
    # Resolve the loader by name rather than eval():
    # - eval() only accepts expressions, so the original assignment string
    #   raised SyntaxError on every call.
    # - eval() would also execute arbitrary code taken from the config.
    loader = globals().get(loader_name)
    if not callable(loader):
        raise NameError(
            "no dataset loader named %r; define a function %s(conf) returning "
            "(taskdata, tasklabels)" % (loader_name, loader_name))
    taskdata, tasklabels = loader(conf)
    tasklabels = pd.DataFrame(tasklabels, columns=['user'])
    # NOTE(review): the original author noted the expected layout as
    # (N, 2, window_length*sample_freq); the actual shape depends on the loader.
    print('taskdata.shape:', taskdata.shape)
    return taskdata, tasklabels


class SimpleSet(Dataset):
    """In-memory ``Dataset`` that pairs each data row with its label.

    Args:
        data: Array-like accepted by ``torch.tensor``. It is stored as a
            float32 tensor.
        labels: Array-like of labels, or a pandas ``DataFrame``/``Series``.
            Pandas containers are first converted with ``to_numpy()``. The
            labels are stored as int64 when ``intflag`` is true, otherwise as
            float32.
        intflag: Selects the label dtype, as described above.
    """

    def __init__(self, data, labels, intflag=True):
        self.data = torch.tensor(data, dtype=torch.float)
        # torch.tensor() does not reliably handle pandas containers, such as
        # the DataFrame returned by loadDataset, so hand it a NumPy array.
        if isinstance(labels, (pd.DataFrame, pd.Series)):
            labels = labels.to_numpy()
        label_dtype = torch.long if intflag else torch.float
        self.label = torch.tensor(labels, dtype=label_dtype)

    def __len__(self):
        """Return the number of samples (the first dimension of ``data``)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(data, label)`` pair at ``index``."""
        return self.data[index], self.label[index]