Clarify data dimensions
This commit is contained in:
parent
dd4a0e2d2e
commit
8849f3a236
2 changed files with 9 additions and 9 deletions
|
@ -32,14 +32,14 @@ We enable two ways to use Mouse2Vec on your datasets for downstream tasks.
|
|||
|
||||
Execute
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=3 python main.py --ssl True --load True --stage 0 --testDataset [Your Dataset]
|
||||
CUDA_VISIBLE_DEVICES=3 python main.py --sl True --load True --stage 0 --testDataset [Your Dataset]
|
||||
```
|
||||
|
||||
2. Finetune both Mouse2Vec and the classifier <br>
|
||||
|
||||
Execute
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=3 python main.py --ssl True --load True --stage 1 --testDataset [Your Dataset]
|
||||
CUDA_VISIBLE_DEVICES=3 python main.py --sl True --load True --stage 1 --testDataset [Your Dataset]
|
||||
```
|
||||
|
||||
# Citation
|
||||
|
|
14
main.py
14
main.py
|
@ -1,8 +1,4 @@
|
|||
import numpy as np
|
||||
import random, pdb, os, copy
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import pickle as pkl
|
||||
import torch
|
||||
|
||||
from utils import MakeDir, set_seed, str2bool
|
||||
|
@ -50,9 +46,14 @@ if __name__ == '__main__':
|
|||
if opt.ssl:
|
||||
'''
|
||||
Load dataset for pretraining the Autoencoder
|
||||
taskdata: on-screen coordinates (x,y), their magnitude and phase
|
||||
taskdata: on-screen coordinates (x,y), their magnitude and phase. Shape (N, 2, 200)
|
||||
- N: the number of sliding windows generated from the mouse data
|
||||
- 2: X or Y
|
||||
- 200: 100 (20Hz * 5s) + 50 (half of 100 data points) + 50 (half of 100 data points), see the DFT function in utils.py
|
||||
taskclicks: indicate if each mouse data point is a click (1) or move (0) event. Shape (N, 100)
|
||||
- 100: 200Hz * 5s
|
||||
'''
|
||||
taskdata, taskclicks = loadPretrainDataset(opt)
|
||||
taskdata, taskclicks = loadPretrainDataset(opt) # Plug the loader of your dataset
|
||||
model = Model(opt).to(opt.device)
|
||||
self_supervised_learning(model, taskdata, taskclicks, opt, modelfile=aemodelfile)
|
||||
|
||||
|
@ -60,7 +61,6 @@ if __name__ == '__main__':
|
|||
model = torch.load(aemodelfile, map_location=opt.device)
|
||||
else:
|
||||
model = Model(opt).to(opt.device)
|
||||
pdb.set_trace()
|
||||
|
||||
if opt.sl:
|
||||
'''
|
||||
|
|
Loading…
Reference in a new issue