Mouse2Vec/utils.py
2024-05-07 17:01:12 +02:00

52 lines
No EOL
1.7 KiB
Python

import pdb, os, random, copy
import pandas as pd
import numpy as np
import pickle as pkl
import torch
from torch.utils.data import DataLoader,Dataset
def str2bool(v):
    """Parse a yes/no style command-line value into a ``bool``.

    Intended for use as an ``argparse`` ``type=`` converter. Matching is
    case-insensitive and ignores surrounding whitespace:

    * ``yes``, ``true``, ``t``, ``y``, ``1``  -> ``True``
    * ``no``, ``false``, ``f``, ``n``, ``0``  -> ``False``

    Values that are already ``bool`` are returned unchanged. Other non-string
    values are converted with ``str()`` first, so ``1``/``0`` are accepted too.

    Raises:
        argparse.ArgumentTypeError: if the value is not one of the above.
    """
    # argparse is not imported at module level; without this import the
    # error branch below would raise NameError instead of ArgumentTypeError.
    import argparse

    if isinstance(v, bool):
        return v
    text = str(v).strip().lower()
    if text in ('yes', 'true', 't', 'y', '1'):
        return True
    if text in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')
def set_seed(seed):
    """Seed every random number generator used here and make cuDNN deterministic.

    Seeds torch's CPU generator, torch's CUDA generators (current device and
    all devices), NumPy's global RNG and Python's ``random`` module with the
    same ``seed``. It then turns off cuDNN benchmarking and requests
    deterministic cuDNN algorithms, so runs are reproducible.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,  # NumPy global RNG
        random.seed,     # Python stdlib RNG
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    # Autotuning may pick different kernels between runs; disable it.
    cudnn.benchmark = False
    cudnn.deterministic = True
def _min_max_scale(arr):
    """Scale ``arr`` linearly to [0, 1]; return zeros if ``arr`` is constant."""
    lo = np.min(arr)
    span = np.max(arr) - lo
    if span == 0:
        # Avoid 0/0 -> NaN for a constant channel (e.g. all-zero input).
        return np.zeros_like(arr, dtype=float)
    return (arr - lo) / span


def DFT(data, clicks=None):
    """Append min-max-scaled FFT magnitude and phase features to ``data``.

    For each channel ``c``, the FFT of ``data[:, :, c]`` is taken along axis 1.
    Only the first ``T // 2`` bins are kept, where ``T = data.shape[1]``. For
    real-valued input the spectrum is conjugate-symmetric, so the rest is
    redundant. Magnitude and phase are each scaled to [0, 1] using the
    minimum/maximum over that whole channel (all rows and bins together).

    Args:
        data: rank-3 array indexed as ``data[:, :, channel]``; presumably
            (samples, time steps, channels) -- confirm against callers.
        clicks: passed through unchanged.

    Returns:
        ``(augmented, clicks)``. ``augmented`` has shape
        ``(N, T + 2 * (T // 2), C)``: the original ``data`` followed by the
        magnitudes and then the phases, concatenated along axis 1.
    """
    half = data.shape[1] // 2
    mags, phases = [], []
    for channel in range(data.shape[-1]):
        freq = np.fft.fft(data[:, :, channel])[:, :half]  # keep one half
        mags.append(_min_max_scale(np.abs(freq)))
        phases.append(_min_max_scale(np.angle(freq)))
    # Stack once at the end so any channel count works; stacking pairwise
    # inside the loop only produced the right shape for exactly 2 channels.
    mags = np.stack(mags, axis=-1)
    phases = np.stack(phases, axis=-1)
    data = np.concatenate([data, mags, phases], axis=1)
    return data, clicks
def MakeDir(dirName):
    """Create directory ``dirName``, including missing parents, if needed.

    Safe to call repeatedly. Using ``exist_ok=True`` avoids the race where
    another process creates the directory between the existence check and
    ``makedirs``.

    Raises:
        FileExistsError: if ``dirName`` exists but is not a directory.
    """
    os.makedirs(dirName, exist_ok=True)
class simpleSet(Dataset):
    """Map-style dataset pairing samples with zero-based integer labels.

    Labels are cast with ``.astype(int)`` and shifted so the smallest label
    becomes 0. For example, ``[3, 5, 4]`` is stored as ``[0, 2, 1]``. Only the
    minimum is subtracted; gaps between label values are preserved.

    Args:
        features: indexable collection of samples; its ``len`` is the
            dataset length.
        labels: array-like supporting ``.astype(int)`` (e.g. a NumPy array).
            Presumably one label per sample -- this is not checked.
    """

    def __init__(self, features, labels):
        self.features = features
        as_int = labels.astype(int)
        offset = np.min(as_int)
        self.labels = as_int - offset

    def __getitem__(self, index):
        sample = self.features[index]
        target = self.labels[index]
        return sample, target

    def __len__(self):
        # Length is defined by the features, not the labels.
        return len(self.features)