21 lines
823 B
Python
21 lines
823 B
Python
|
from torch.utils.data import Dataset
|
||
|
import torch
|
||
|
import pandas as pd
|
||
|
|
||
|
def loadDataset(conf, loaders=None):
    """Load the pretraining dataset named by ``conf.pretrainDataset``.

    The loader is resolved by name and called as ``loader(conf)``:
    - The function must be called ``load<pretrainDataset>``, e.g. ``loadFoo``
      when ``conf.pretrainDataset == 'Foo'``.
    - It is looked up in ``loaders`` if given, otherwise in this module's
      global namespace.

    To plug in your own dataset, define such a function in this module or
    pass it via ``loaders``.

    Args:
        conf: configuration object with a ``pretrainDataset`` attribute;
            passed unchanged to the loader.
        loaders: optional mapping from loader name (``'load<name>'``) to a
            callable. Defaults to this module's globals.

    Returns:
        ``(taskdata, tasklabels)``:
        - ``taskdata`` is whatever the loader returned as its first item
          (it must expose ``.shape``).
        - ``tasklabels`` is the loader's second item wrapped in a
          ``pandas.DataFrame`` with the column name ``'user'``.

    Raises:
        ValueError: if no callable named ``load<pretrainDataset>`` is found.
    """
    loader_name = 'load%s' % conf.pretrainDataset
    namespace = globals() if loaders is None else loaders
    loader = namespace.get(loader_name)
    if not callable(loader):
        raise ValueError(
            'No dataset loader %r found for pretrainDataset=%r'
            % (loader_name, conf.pretrainDataset))
    # Look the loader up by name instead of eval(): eval() cannot execute an
    # assignment statement (SyntaxError), and evaluating a string built from
    # configuration values is unsafe.
    taskdata, tasklabels = loader(conf)
    tasklabels = pd.DataFrame(tasklabels, columns=['user'])
    # Expected shape (per original author): (N, 2, window_length*sample_freq)
    print('taskdata.shape:', taskdata.shape)
    return taskdata, tasklabels
|
||
|
|
||
|
class SimpleSet(Dataset):
    """Map-style dataset pairing each sample in ``data`` with a label.

    Args:
        data: array-like of samples, converted to a ``torch.float`` (float32)
            tensor.
        labels: array-like of labels, presumably one per sample (not
            checked here).
        intflag: if True, labels are stored as a ``torch.long`` (int64)
            tensor; otherwise as a ``torch.float`` (float32) tensor.
    """

    def __init__(self, data, labels, intflag=True):
        label_dtype = torch.long if intflag else torch.float
        self.data = torch.tensor(data, dtype=torch.float)
        self.label = torch.tensor(labels, dtype=label_dtype)

    def __len__(self):
        # Number of entries along the first axis of the data tensor.
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        target = self.label[index]
        return sample, target
|