Clarify data dimensions

This commit is contained in:
Guanhua Zhang 2024-10-10 14:01:06 +02:00
parent dd4a0e2d2e
commit 8849f3a236
2 changed files with 9 additions and 9 deletions

View file

@ -32,14 +32,14 @@ We enable two ways to use Mouse2Vec on your datasets for downstream tasks.
Execute Execute
```shell ```shell
CUDA_VISIBLE_DEVICES=3 python main.py --ssl True --load True --stage 0 --testDataset [Your Dataset] CUDA_VISIBLE_DEVICES=3 python main.py --sl True --load True --stage 0 --testDataset [Your Dataset]
``` ```
2. Finetune both Mouse2Vec and the classifier <br> 2. Finetune both Mouse2Vec and the classifier <br>
Execute Execute
```shell ```shell
CUDA_VISIBLE_DEVICES=3 python main.py --ssl True --load True --stage 1 --testDataset [Your Dataset] CUDA_VISIBLE_DEVICES=3 python main.py --sl True --load True --stage 1 --testDataset [Your Dataset]
``` ```
# Citation # Citation

14
main.py
View file

@ -1,8 +1,4 @@
import numpy as np
import random, pdb, os, copy
import argparse import argparse
import pandas as pd
import pickle as pkl
import torch import torch
from utils import MakeDir, set_seed, str2bool from utils import MakeDir, set_seed, str2bool
@ -50,9 +46,14 @@ if __name__ == '__main__':
if opt.ssl: if opt.ssl:
''' '''
Load dataset for pretraining the Autoencoder Load dataset for pretraining the Autoencoder
taskdata: on-screen coordinates (x,y), their magnitude and phase taskdata: on-screen coordinates (x,y), their magnitude and phase. Shape (N, 2, 200)
- N: the number of sliding windows generated from the mouse data
- 2: X or Y
- 200: 100 (20Hz * 5s) + 50 (half of 100 data points) + 50 (half of 100 data points), see the DFT function in utils.py
taskclicks: indicate if each mouse data point is a click (1) or move (0) event. Shape (N, 100)
- 100: 200Hz * 5s
''' '''
taskdata, taskclicks = loadPretrainDataset(opt) taskdata, taskclicks = loadPretrainDataset(opt) # Plug the loader of your dataset
model = Model(opt).to(opt.device) model = Model(opt).to(opt.device)
self_supervised_learning(model, taskdata, taskclicks, opt, modelfile=aemodelfile) self_supervised_learning(model, taskdata, taskclicks, opt, modelfile=aemodelfile)
@ -60,7 +61,6 @@ if __name__ == '__main__':
model = torch.load(aemodelfile, map_location=opt.device) model = torch.load(aemodelfile, map_location=opt.device)
else: else:
model = Model(opt).to(opt.device) model = Model(opt).to(opt.device)
pdb.set_trace()
if opt.sl: if opt.sl:
''' '''