Mouse2Vec/utils.py

52 lines
1.7 KiB
Python
Raw Normal View History

2024-05-07 17:01:12 +02:00
import argparse
import copy
import os
import pdb
import pickle as pkl
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
def str2bool(v):
    """Parse a command-line flag value into a ``bool``.

    Suitable as an ``argparse`` ``type=`` callable. Accepts the
    case-insensitive spellings ``yes/true/t/y/1`` and ``no/false/f/n/0``,
    ignoring surrounding whitespace. A value that is already a ``bool`` is
    returned unchanged, so the helper is safe to call on parsed values.

    Args:
        v: The raw value (normally a ``str``).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v
    text = str(v).strip().lower()
    if text in ('yes', 'true', 't', 'y', '1'):
        return True
    if text in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')
def set_seed(seed):
    """Seed every random generator used by the project and pin cuDNN.

    Seeds PyTorch (CPU, current CUDA device, all CUDA devices), NumPy's
    global generator and Python's ``random`` module with ``seed``. It then
    turns off cuDNN benchmark autotuning and requests deterministic cuDNN
    kernels, favouring reproducibility over speed.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,  # NumPy global RNG
        random.seed,     # stdlib RNG
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
def _minmax_scale(arr):
    """Rescale ``arr`` linearly to [0, 1] using its global min and max.

    A constant array (max == min) maps to all zeros instead of NaN.
    """
    lo = np.min(arr)
    span = np.max(arr) - lo
    if span == 0:
        return np.zeros_like(arr, dtype=float)
    return (arr - lo) / span


def DFT(data, clicks=None):
    """Append min-max scaled FFT magnitude and phase features to ``data``.

    For each channel (last axis), the FFT is taken along axis 1 of ``data``.
    Only the first ``data.shape[1] // 2`` frequency bins are kept; the input
    is presumably real-valued, so the upper half mirrors the lower half
    (TODO confirm). Magnitude and phase are each scaled to [0, 1] per
    channel, using the min/max over all samples of that channel.

    Args:
        data: 3-D array indexed as ``data[:, :, channel]``; presumably
            (samples, time, channels) -- confirm against callers.
        clicks: Passed through unchanged.

    Returns:
        ``(features, clicks)``, where ``features`` is ``data`` with the
        magnitudes and then the phases concatenated along axis 1, giving
        shape ``(N, T + 2 * (T // 2), C)``.

    Raises:
        ValueError: If ``data`` has no channels (nothing to stack).
    """
    half = data.shape[1] // 2
    mags, phases = [], []
    for channel in range(data.shape[-1]):
        freq = np.fft.fft(data[:, :, channel])[:, :half]  # keep lower half
        mags.append(_minmax_scale(np.abs(freq)))
        phases.append(_minmax_scale(np.angle(freq)))
    # Stack all channels at once so any channel count yields (N, T//2, C).
    mags = np.stack(mags, axis=-1)
    phases = np.stack(phases, axis=-1)
    return np.concatenate([data, mags, phases], axis=1), clicks
def MakeDir(dirName):
    """Create directory ``dirName`` and any missing parents, if needed.

    ``exist_ok=True`` makes the call idempotent. It also avoids the race
    in a separate exists-then-create check, where another process can
    create the directory in between and make ``os.makedirs`` fail.

    Raises:
        FileExistsError: If ``dirName`` exists but is not a directory.
    """
    os.makedirs(dirName, exist_ok=True)
class simpleSet(Dataset):
    """Map-style dataset pairing ``features[i]`` with an integer label.

    Labels are cast to ``int`` and shifted so the smallest becomes 0
    (e.g. labels ``{1, 2, 3}`` become ``{0, 1, 2}``).
    """

    def __init__(self, features, labels):
        """Store features and zero-based integer labels.

        Args:
            features: Any container supporting ``len`` and ``features[i]``.
            labels: Array-like of label values, presumably one per feature
                row (not checked). Converted to a NumPy array, so
                ``self.labels[i]`` is always positional.
        """
        self.features = features
        labels = np.asarray(labels).astype(int)
        # np.min raises on an empty array; an empty dataset needs no shift.
        self.labels = labels - labels.min() if labels.size else labels

    def __getitem__(self, index):
        """Return the ``(feature, label)`` pair at position ``index``."""
        return self.features[index], self.labels[index]

    def __len__(self):
        """Number of samples, taken from ``features``."""
        return len(self.features)