DisMouse/dataset.py

21 lines
823 B
Python
Raw Normal View History

2024-10-08 14:18:47 +02:00
from torch.utils.data import Dataset
import torch
import pandas as pd
def loadDataset(conf):
    """Load the pre-training dataset named by ``conf.pretrainDataset``.

    The loader is resolved by name. For ``conf.pretrainDataset == "Foo"`` this
    calls ``loadFoo(conf)``, which must be defined in this module's global
    namespace (plug in the function to load your own dataset there). The
    loader must return a ``(taskdata, tasklabels)`` pair.

    Args:
        conf: Configuration object. Only ``conf.pretrainDataset`` is read here;
            the whole object is forwarded to the loader.

    Returns:
        tuple: ``(taskdata, tasklabels)``. ``taskdata`` is whatever the loader
        returned (it must expose ``.shape``). ``tasklabels`` is wrapped in a
        ``pandas.DataFrame`` with a single ``'user'`` column.

    Raises:
        NameError: If no callable ``load<pretrainDataset>`` is defined in this
            module.
    """
    loader_name = 'load%s' % conf.pretrainDataset
    # Resolve the loader by name instead of eval(). eval() only evaluates
    # expressions, so 'a, b = f(x)' raised SyntaxError on every call, and
    # eval would also execute arbitrary code taken from a config string.
    loader = globals().get(loader_name)
    if not callable(loader):
        raise NameError(
            'dataset loader %r is not defined in this module' % loader_name
        )
    taskdata, tasklabels = loader(conf)
    tasklabels = pd.DataFrame(tasklabels, columns=['user'])
    # Author's note: expected shape is (N, 2, window_length*sample_freq).
    print('taskdata.shape:', taskdata.shape)
    return taskdata, tasklabels
class SimpleSet(Dataset):
    """Map-style dataset that pairs a float feature tensor with a label tensor.

    Args:
        data: Array-like features, copied into a ``torch.float`` tensor.
        labels: Array-like targets. Presumably aligned with ``data`` along the
            first axis; this is not checked here.
        intflag: If truthy, labels are stored as ``torch.long`` (integer
            targets); otherwise as ``torch.float``.
    """

    def __init__(self, data, labels, intflag=True):
        target_dtype = torch.long if intflag else torch.float
        # Features are converted first, matching the original construction order.
        self.data = torch.tensor(data, dtype=torch.float)
        self.label = torch.tensor(labels, dtype=target_dtype)

    def __len__(self):
        # The dataset size is the length of the feature tensor.
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        target = self.label[index]
        return sample, target