From 8849f3a236fc4f2f0ea89b3a75d454f0f776a15c Mon Sep 17 00:00:00 2001
From: Guanhua Zhang
Date: Thu, 10 Oct 2024 14:01:06 +0200
Subject: [PATCH] Clarify data dimensions

---
 README.md |  4 ++--
 main.py   | 14 +++++++-------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 04ca5d9..f35bda9 100644
--- a/README.md
+++ b/README.md
@@ -32,14 +32,14 @@ We enable two ways to use Mouse2Vec on your datasets for downstream tasks.
 
 Execute
 ```shell
-CUDA_VISIBLE_DEVICES=3 python main.py --ssl True --load True --stage 0 --testDataset [Your Dataset]
+CUDA_VISIBLE_DEVICES=3 python main.py --sl True --load True --stage 0 --testDataset [Your Dataset]
 ```
 
 2. Finetune both Mouse2Vec and the classifier
 
 Execute
 ```shell
-CUDA_VISIBLE_DEVICES=3 python main.py --ssl True --load True --stage 1 --testDataset [Your Dataset]
+CUDA_VISIBLE_DEVICES=3 python main.py --sl True --load True --stage 1 --testDataset [Your Dataset]
 ```
 
 # Citation

diff --git a/main.py b/main.py
index e85e8a9..8bc7904 100644
--- a/main.py
+++ b/main.py
@@ -1,8 +1,4 @@
-import numpy as np
-import random, pdb, os, copy
 import argparse
-import pandas as pd
-import pickle as pkl
 import torch
 
 from utils import MakeDir, set_seed, str2bool
@@ -50,9 +46,14 @@ if __name__ == '__main__':
     if opt.ssl:
         '''
         Load dataset for pretraining the Autoencoder
-        taskdata: on-screen coordinates (x,y), their magnitude and phase
+        taskdata: on-screen coordinates (x,y), their magnitude and phase. Shape (N, 2, 200)
+            - N: the number of sliding windows generated from the mouse data
+            - 2: X or Y
+            - 200: 100 coordinate samples (20Hz * 5s) + 50 magnitude values + 50 phase values (each half of the 100 data points), see the DFT function in utils.py
+        taskclicks: indicates whether each mouse data point is a click (1) or a move (0) event. Shape (N, 100)
+            - 100: 20Hz * 5s
         '''
-        taskdata, taskclicks = loadPretrainDataset(opt)
+        taskdata, taskclicks = loadPretrainDataset(opt)  # Plug in the loader of your dataset
         model = Model(opt).to(opt.device)
         self_supervised_learning(model, taskdata, taskclicks, opt, modelfile=aemodelfile)
 
@@ -60,7 +61,6 @@ if __name__ == '__main__':
         model = torch.load(aemodelfile, map_location=opt.device)
     else:
         model = Model(opt).to(opt.device)
-    pdb.set_trace()
 
     if opt.sl:
         '''
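
For reference, below is a minimal, hypothetical sketch of what a custom loader plugged in at the `loadPretrainDataset(opt)` call could return, matching the shapes documented by this patch: `taskdata` of shape (N, 2, 200) and `taskclicks` of shape (N, 100). The 20Hz rate, 5s windows, 50% stride, and the `numpy.fft.rfft`-based magnitude/phase layout are assumptions for illustration; the repository's actual transform is the DFT function in `utils.py`, and `windows_to_pretrain_arrays` is an invented name, not part of the codebase.

```python
# Hypothetical sketch only: illustrates the documented array shapes,
# not the repository's actual loadPretrainDataset / DFT implementation.
import numpy as np

WIN = 100  # assumed window length: 20Hz * 5s


def windows_to_pretrain_arrays(xy, clicks, stride=50):
    """
    xy:     (T, 2) on-screen (x, y) coordinates, assumed sampled at ~20Hz
    clicks: (T,)   1 for click events, 0 for move events
    Returns taskdata (N, 2, 200) and taskclicks (N, 100).
    """
    taskdata, taskclicks = [], []
    for s in range(0, len(xy) - WIN + 1, stride):
        win = xy[s:s + WIN]                      # (100, 2) coordinate window
        chans = []
        for c in range(2):                       # X channel, then Y channel
            sig = win[:, c]
            spec = np.fft.rfft(sig)[:WIN // 2]   # keep 50 complex frequency bins
            # 100 raw samples + 50 magnitudes + 50 phases = 200 values
            chans.append(np.concatenate([sig, np.abs(spec), np.angle(spec)]))
        taskdata.append(np.stack(chans))         # (2, 200)
        taskclicks.append(clicks[s:s + WIN])     # (100,) click/move labels
    return np.stack(taskdata), np.stack(taskclicks)


if __name__ == "__main__":
    T = 1000                                      # 50s of synthetic 20Hz data
    xy = np.random.rand(T, 2)
    clicks = (np.random.rand(T) < 0.05).astype(np.float32)
    taskdata, taskclicks = windows_to_pretrain_arrays(xy, clicks)
    print(taskdata.shape, taskclicks.shape)       # (19, 2, 200) (19, 100)
```

Any loader that yields arrays with these two shapes should be compatible with the commented `loadPretrainDataset(opt)` call site in `main.py`.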