52 lines
1.7 KiB
Python
52 lines
1.7 KiB
Python
|
import pdb, os, random, copy
|
||
|
import pandas as pd
|
||
|
import numpy as np
|
||
|
import pickle as pkl
|
||
|
import torch
|
||
|
from torch.utils.data import DataLoader,Dataset
|
||
|
|
||
|
def str2bool(v):
    """Convert a textual yes/no flag into a ``bool``.

    Because invalid input is reported with ``argparse.ArgumentTypeError``,
    the function can be passed as the ``type=`` callable of an argparse
    argument.

    Args:
        v: A string such as ``"yes"``, ``"True"``, ``"0"`` (matched
            case-insensitively), or an actual ``bool``, which is returned
            unchanged.

    Returns:
        ``True`` for 'yes', 'true', 't', 'y', '1'; ``False`` for
        'no', 'false', 'f', 'n', '0'.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not one of the accepted
            spellings.
    """
    # argparse is not imported at module level, so bring it into scope here;
    # otherwise the raise below would fail with NameError.
    import argparse

    # Already a boolean (e.g. a default value routed through the converter).
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')
|
||
|
|
||
|
def set_seed(seed):
    """Seed every random number generator used by this module.

    Seeds Python's ``random`` module, NumPy's global generator and the
    PyTorch CPU/CUDA generators with the same value, then switches cuDNN
    to deterministic mode with benchmarking disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)       # Python random module.
    np.random.seed(seed)    # Numpy module.

    # CPU generator, current CUDA device, then all CUDA devices.
    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in torch_seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
||
|
|
||
|
def DFT(data, clicks=None):
    """Append per-channel FFT magnitude and phase features along axis 1.

    ``data`` is indexed as ``data[:, :, channel]``, i.e. it must be 3-D with
    the channel on the last axis. For every channel the FFT is taken along
    axis 1, only the first ``data.shape[1] // 2`` bins are kept (for
    real-valued input the remaining bins mirror them), and the magnitude and
    phase are each min-max scaled to ``[0, 1]`` over that whole channel.
    The scaled magnitudes and phases of all channels are then concatenated
    after the original samples on axis 1.

    Works for any number of channels >= 1.

    Args:
        data: 3-D NumPy array; axis 1 is transformed, the last axis
            enumerates channels.
        clicks: Passed through untouched and returned as the second element.

    Returns:
        Tuple ``(augmented, clicks)`` where ``augmented`` has shape
        ``(data.shape[0], data.shape[1] + 2 * (data.shape[1] // 2),
        data.shape[2])``.

    Note:
        If every magnitude (or every phase) of a channel is identical, the
        min-max scaling divides by zero and that channel's features become
        NaN.
    """
    half = data.shape[1] // 2  # keep only the first half of the spectrum
    mags, phases = [], []

    for channel in range(data.shape[-1]):
        freq = np.fft.fft(data[:, :, channel])[:, :half]
        mag = np.abs(freq)
        phase = np.angle(freq)
        mags.append((mag - np.min(mag)) / (np.max(mag) - np.min(mag)))
        phases.append(
            (phase - np.min(phase)) / (np.max(phase) - np.min(phase))
        )

    # Stack once at the end so every channel lands on the last axis,
    # whatever the channel count.
    mags = np.stack(mags, axis=-1)
    phases = np.stack(phases, axis=-1)

    data = np.concatenate([data, mags, phases], axis=1)
    return data, clicks
|
||
|
|
||
|
def MakeDir(dirName):
    """Create ``dirName`` (including missing parents) if it does not exist.

    Args:
        dirName: Path of the directory to create.

    Raises:
        FileExistsError: If ``dirName`` exists but is not a directory.
        OSError: If the directory cannot be created.
    """
    # exist_ok=True avoids the check-then-create race: another process may
    # create the directory between an os.path.exists() test and makedirs().
    os.makedirs(dirName, exist_ok=True)
|
||
|
|
||
|
class simpleSet(Dataset):
    """Minimal map-style dataset pairing feature rows with integer labels.

    Labels are cast to ``int`` and shifted so that the smallest label
    becomes 0. Features are stored as given.
    """

    def __init__(self, features, labels):
        """Store ``features`` and the zero-based integer version of ``labels``.

        Args:
            features: Indexable collection of samples; its ``len`` defines
                the dataset size.
            labels: Array-like exposing ``astype`` (e.g. a NumPy array),
                indexed in parallel with ``features``.
        """
        int_labels = labels.astype(int)
        lowest = np.min(int_labels)
        self.features = features
        # Shift so the smallest label maps to 0.
        self.labels = int_labels - lowest

    def __len__(self):
        """Return the number of samples (length of ``features``)."""
        return len(self.features)

    def __getitem__(self, index):
        """Return the ``(feature, label)`` pair stored at ``index``."""
        sample = self.features[index]
        target = self.labels[index]
        return sample, target
|