Clarify data dimensions
This commit is contained in:
parent
dd4a0e2d2e
commit
8849f3a236
2 changed files with 9 additions and 9 deletions
|
@ -32,14 +32,14 @@ We enable two ways to use Mouse2Vec on your datasets for downstream tasks.
|
||||||
|
|
||||||
Execute
|
Execute
|
||||||
```shell
|
```shell
|
||||||
CUDA_VISIBLE_DEVICES=3 python main.py --ssl True --load True --stage 0 --testDataset [Your Dataset]
|
CUDA_VISIBLE_DEVICES=3 python main.py --sl True --load True --stage 0 --testDataset [Your Dataset]
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Finetune both Mouse2Vec and the classifier <br>
|
2. Finetune both Mouse2Vec and the classifier <br>
|
||||||
|
|
||||||
Execute
|
Execute
|
||||||
```shell
|
```shell
|
||||||
CUDA_VISIBLE_DEVICES=3 python main.py --ssl True --load True --stage 1 --testDataset [Your Dataset]
|
CUDA_VISIBLE_DEVICES=3 python main.py --sl True --load True --stage 1 --testDataset [Your Dataset]
|
||||||
```
|
```
|
||||||
|
|
||||||
# Citation
|
# Citation
|
||||||
|
|
14
main.py
14
main.py
|
@ -1,8 +1,4 @@
|
||||||
import numpy as np
|
|
||||||
import random, pdb, os, copy
|
|
||||||
import argparse
|
import argparse
|
||||||
import pandas as pd
|
|
||||||
import pickle as pkl
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from utils import MakeDir, set_seed, str2bool
|
from utils import MakeDir, set_seed, str2bool
|
||||||
|
@ -50,9 +46,14 @@ if __name__ == '__main__':
|
||||||
if opt.ssl:
|
if opt.ssl:
|
||||||
'''
|
'''
|
||||||
Load dataset for pretraining the Autoencoder
|
Load dataset for pretraining the Autoencoder
|
||||||
taskdata: on-screen coordinates (x,y), their magnitude and phase
|
taskdata: on-screen coordinates (x,y), their magnitude and phase. Shape (N, 2, 200)
|
||||||
|
- N: the number of sliding windows generated from the mouse data
|
||||||
|
- 2: X or Y
|
||||||
|
- 200: 100 (20Hz * 5s) + 50 (half of 100 data points) + 50 (half of 100 data points), see the DFT function in utils.py
|
||||||
|
taskclicks: indicate if each mouse data point is a click (1) or move (0) event. Shape (N, 100)
|
||||||
|
- 100: 200Hz * 5s
|
||||||
'''
|
'''
|
||||||
taskdata, taskclicks = loadPretrainDataset(opt)
|
taskdata, taskclicks = loadPretrainDataset(opt) # Plug the loader of your dataset
|
||||||
model = Model(opt).to(opt.device)
|
model = Model(opt).to(opt.device)
|
||||||
self_supervised_learning(model, taskdata, taskclicks, opt, modelfile=aemodelfile)
|
self_supervised_learning(model, taskdata, taskclicks, opt, modelfile=aemodelfile)
|
||||||
|
|
||||||
|
@ -60,7 +61,6 @@ if __name__ == '__main__':
|
||||||
model = torch.load(aemodelfile, map_location=opt.device)
|
model = torch.load(aemodelfile, map_location=opt.device)
|
||||||
else:
|
else:
|
||||||
model = Model(opt).to(opt.device)
|
model = Model(opt).to(opt.device)
|
||||||
pdb.set_trace()
|
|
||||||
|
|
||||||
if opt.sl:
|
if opt.sl:
|
||||||
'''
|
'''
|
||||||
|
|
Loading…
Reference in a new issue