Guanhua Zhang 2024-05-07 17:01:12 +02:00
commit b429f9807d
7 changed files with 1186 additions and 0 deletions

438
CRT.py Normal file

@@ -0,0 +1,438 @@
import pdb, copy
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from base_models import MLP, resnet1d18, SSLDataSet
from sklearn.metrics import f1_score
class cnn_extractor(nn.Module):
def __init__(self, dim, input_plane):
super(cnn_extractor, self).__init__()
self.cnn = resnet1d18(input_channels=dim, inplanes=input_plane)
def forward(self, x):
x = self.cnn(x)
return x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim=-1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class TFR(nn.Module):
def __init__(self, seq_len, patch_len, dim, depth, heads, mlp_dim, channels=12,
dim_head=64, dropout=0., emb_dropout=0.):
'''
The encoder of CRT
'''
super().__init__()
self.patch_len = patch_len
self.seq_len = seq_len
        assert seq_len % (4 * patch_len) == 0, \
            'seq_len must be a multiple of 4 * patch_len; otherwise some patch would contain both magnitude and phase data.'
num_patches = seq_len // patch_len
patch_dim = channels * patch_len
self.to_patch = nn.Sequential(Rearrange('b c (n p1) -> b n c p1', p1=patch_len),
Rearrange('b n c p1 -> (b n) c p1'))
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 3, dim))
self.modal_embedding = nn.Parameter(torch.randn(3, 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 3, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.cnn1 = cnn_extractor(dim=channels, input_plane=dim // 8) # For temporal data
self.cnn2 = cnn_extractor(dim=channels, input_plane=dim // 8) # For magnitude data
self.cnn3 = cnn_extractor(dim=channels, input_plane=dim // 8) # For phase data
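        # resnet1d18 doubles its channel width over four stages, so inplanes = dim // 8
        # makes each CNN extractor output dim channels per patch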
def forward(self, x, opt):
batch, _, time_steps = x.shape
t, m, p = x[:, :, :time_steps // 2], x[:, :, time_steps // 2: time_steps * 3 // 4], x[:, :, -time_steps // 4:]
assert t.shape[-1] % opt.patch_len == 0 == m.shape[-1] % opt.patch_len == p.shape[-1] % opt.patch_len
t, m, p = self.to_patch(t), self.to_patch(m), self.to_patch(p)
patch2seq = nn.Sequential(nn.AdaptiveAvgPool1d(1),
Rearrange('(b n) c 1 -> b n c', b=batch))
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=batch)
x = torch.cat((cls_tokens[:, 0:1, :], patch2seq(self.cnn1(t)),
cls_tokens[:, 1:2, :], patch2seq(self.cnn2(m)),
cls_tokens[:, 2:3, :], patch2seq(self.cnn3(p))), dim=1)
b, t, c = x.shape
time_steps = t - 3
t_token_idx, m_token_idx, p_token_idx = 0, time_steps // 2 + 1, time_steps * 3 // 4 + 2
        # add modality embeddings along the token dimension (temporal / magnitude / phase segments)
        x[:, :m_token_idx] += self.modal_embedding[:1]
        x[:, m_token_idx: p_token_idx] += self.modal_embedding[1:2]
        x[:, p_token_idx:] += self.modal_embedding[2:]
x += self.pos_embedding[:, : t]
x = self.dropout(x)
x = self.transformer(x)
t_token, m_token, p_token = x[:, t_token_idx], x[:, m_token_idx], x[:, p_token_idx]
avg = (t_token + m_token + p_token) / 3
return avg
def TFR_Encoder(seq_len, patch_len, dim, in_dim, depth):
vit = TFR(seq_len=seq_len,
patch_len=patch_len,
#num_classes=num_class,
dim=dim,
depth=depth,
heads=8,
mlp_dim=dim,
dropout=0.2,
emb_dropout=0.1,
channels=in_dim)
return vit
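# A minimal usage sketch (hypothetical values, for illustration only): the encoder expects
# the raw temporal signal in the first half of the input and its DFT magnitude and phase in
# the remaining two quarters, so seq_len must be a multiple of 4 * patch_len.
#
#   encoder = TFR_Encoder(seq_len=200, patch_len=5, dim=128, in_dim=2, depth=6)
#   x = torch.randn(8, 2, 200)      # (batch, channels, seq_len): time + magnitude + phase
#   emb = encoder(x, opt)           # opt.patch_len must equal 5 here; emb has shape (batch, dim)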
class CRT(nn.Module):
def __init__(
self,
encoder,
decoder_dim,
decoder_depth=2,
decoder_heads=8,
decoder_dim_head=64,
patch_len = 20,
in_dim=12
):
super().__init__()
self.encoder = encoder
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
self.to_patch = encoder.to_patch
pixel_values_per_patch = in_dim * patch_len
# decoder parameters
self.modal_embedding = self.encoder.modal_embedding
self.mask_token = nn.Parameter(torch.randn(3, decoder_dim))
self.decoder = Transformer(dim=decoder_dim,
depth=decoder_depth,
heads=decoder_heads,
dim_head=decoder_dim_head,
mlp_dim=decoder_dim)
self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
self.to_pixels = nn.ModuleList([nn.Linear(decoder_dim, pixel_values_per_patch) for i in range(3)])
self.projs = nn.ModuleList([nn.Linear(decoder_dim, decoder_dim) for i in range(2)])
self.to_clicks = nn.Linear(decoder_dim, 2 * patch_len)
def IDC_loss(self, tokens, encoded_tokens):
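        # Instance discrimination (IDC) loss: L2-normalise both token sets, then penalise the
        # exp-similarity between non-corresponding positions (the diagonal is masked out), keeping
        # the CNN patch tokens and their transformer-encoded counterparts discriminative.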
B, T, D = tokens.shape
tokens, encoded_tokens = F.normalize(tokens, dim=-1), F.normalize(encoded_tokens, dim=-1)
encoded_tokens = encoded_tokens.transpose(2, 1)
cross_mul = torch.exp(torch.matmul(tokens, encoded_tokens))
mask = (1 - torch.eye(T)).unsqueeze(0).to(tokens.device)
cross_mul = cross_mul * mask
return torch.log(cross_mul.sum(-1).sum(-1)).mean(-1)
def forward(self, x, clicks, mask_ratio=0.75, beta = 1e-4, beta2 = 1e-4):
device = x.device
patches = self.to_patch[0](x)
batch, num_patches, c, length = patches.shape
clickpatches = self.to_patch[0](clicks.unsqueeze(1).float())
num_masked = int(mask_ratio * num_patches)
rand_indices1 = torch.randperm(num_patches // 2, device=device)
masked_indices1 = rand_indices1[: num_masked // 2].sort()[0]
unmasked_indices1 = rand_indices1[num_masked // 2:].sort()[0]
rand_indices2 = torch.randperm(num_patches // 4, device=device)
masked_indices2, unmasked_indices2 = rand_indices2[: num_masked // 4].sort()[0], rand_indices2[num_masked // 4:].sort()[0]
rand_indices = torch.cat((masked_indices1, unmasked_indices1,
masked_indices2 + num_patches // 2, unmasked_indices2 + num_patches // 2,
masked_indices2 + num_patches // 4 * 3, unmasked_indices2 + num_patches // 4 * 3))
masked_num_t, masked_num_f = masked_indices1.shape[0], 2 * masked_indices2.shape[0]
unmasked_num_t, unmasked_num_f = unmasked_indices1.shape[0], 2 * unmasked_indices2.shape[0]
tpatches = patches[:, : num_patches // 2, :, :]
mpatches, ppatches = patches[:, num_patches // 2: num_patches * 3 // 4, :, :], patches[:, -num_patches // 4:, :, :]
unmasked_tpatches = tpatches[:, unmasked_indices1, :, :]
unmasked_mpatches, unmasked_ppatches = mpatches[:, unmasked_indices2, :, :], ppatches[:, unmasked_indices2, :, :]
t_tokens, m_tokens, p_tokens = self.to_patch[1](unmasked_tpatches), self.to_patch[1](unmasked_mpatches), self.to_patch[1](unmasked_ppatches)
t_tokens, m_tokens, p_tokens = self.encoder.cnn1(t_tokens), self.encoder.cnn2(m_tokens), self.encoder.cnn3(p_tokens)
Flat = nn.Sequential(nn.AdaptiveAvgPool1d(1),
Rearrange('(b n) c 1 -> b n c', b=batch))
t_tokens, m_tokens, p_tokens = Flat(t_tokens), Flat(m_tokens), Flat(p_tokens)
ori_tokens = torch.cat((t_tokens, m_tokens, p_tokens), 1).clone()
cls_tokens = repeat(self.encoder.cls_token, '() n d -> b n d', b=batch)
tokens = torch.cat((cls_tokens[:, 0:1, :], t_tokens,
cls_tokens[:, 1:2, :], m_tokens,
cls_tokens[:, 2:3, :], p_tokens), dim=1)
t_idx, m_idx, p_idx = num_patches // 2 - 1, num_patches * 3 // 4 - 1, num_patches - 1
pos_embedding = torch.cat((self.encoder.pos_embedding[:, 0:1, :], self.encoder.pos_embedding[:, unmasked_indices1 + 1, :],
self.encoder.pos_embedding[:, t_idx + 2: t_idx + 3],
self.encoder.pos_embedding[:, unmasked_indices2 + t_idx + 3, :],
self.encoder.pos_embedding[:, m_idx + 3: m_idx + 4],
self.encoder.pos_embedding[:, unmasked_indices2 + m_idx + 4, :]), dim=1)
modal_embedding = torch.cat((repeat(self.modal_embedding[0], '1 d -> 1 n d', n=unmasked_num_t + 1),
repeat(self.modal_embedding[1], '1 d -> 1 n d', n=unmasked_num_f // 2 + 1),
repeat(self.modal_embedding[2], '1 d -> 1 n d', n=unmasked_num_f // 2 + 1)), dim=1)
tokens = tokens + pos_embedding + modal_embedding
encoded_tokens = self.encoder.transformer(tokens) # tokens: unmasked tokens + CLS tokens
t_idx, m_idx, p_idx = unmasked_num_t, unmasked_num_f // 2 + unmasked_num_t + 1, -1
idc_loss = self.IDC_loss(self.projs[0](ori_tokens), self.projs[1](torch.cat(([encoded_tokens[:, 1: t_idx+1], encoded_tokens[:, t_idx+2: m_idx+1], encoded_tokens[:, m_idx+2: ]]), dim=1)))
decoder_tokens = encoded_tokens
mask_tokens1 = repeat(self.mask_token[0], 'd -> b n d', b=batch, n=masked_num_t)
mask_tokens2 = repeat(self.mask_token[1], 'd -> b n d', b=batch, n=masked_num_f // 2)
mask_tokens3 = repeat(self.mask_token[2], 'd -> b n d', b=batch, n=masked_num_f // 2)
mask_tokens = torch.cat((mask_tokens1, mask_tokens2, mask_tokens3), dim=1)
decoder_pos_emb = self.decoder_pos_emb(torch.cat(
(masked_indices1, masked_indices2 + num_patches // 2, masked_indices2 + num_patches * 3 // 4)))
mask_tokens = mask_tokens + decoder_pos_emb
decoder_tokens = torch.cat((decoder_tokens, mask_tokens), dim=1)
decoded_tokens = self.decoder(decoder_tokens)
mask_tokens = decoded_tokens[:, -mask_tokens.shape[1]:]
pred_pixel_values_t = self.to_pixels[0](torch.cat((decoder_tokens[:, 1: t_idx + 1], mask_tokens[:, : masked_num_t]), 1))
pred_pixel_values_m = self.to_pixels[1](torch.cat((decoder_tokens[:, t_idx+2: m_idx+1], mask_tokens[:, masked_num_t: masked_num_f // 2 + masked_num_t]), 1))
pred_pixel_values_p = self.to_pixels[2](torch.cat((decoder_tokens[:, m_idx+2: -mask_tokens.shape[1]], mask_tokens[:, -masked_num_f // 2:]), 1))
pred_pixel_values = torch.cat((pred_pixel_values_t, pred_pixel_values_m, pred_pixel_values_p), dim=1)
recon_loss = F.mse_loss(pred_pixel_values, rearrange(patches[:,rand_indices], 'b n c p -> b n (c p)'))
rmvcls_embedding = torch.cat((decoder_tokens[:, 1: t_idx + 1], mask_tokens[:, : masked_num_t]), 1)
click_pred = self.to_clicks(rmvcls_embedding)
click_pred = rearrange(click_pred, 'b n (c p) -> b n c p', p=2)
click_pred = rearrange(click_pred, 'b n c p -> (b n c) p')
clicksGT = clickpatches[:, rand_indices1].squeeze()
clicksGT = rearrange(clicksGT, 'b n c -> (b n c)')
clicksOH = F.one_hot(clicksGT.to(torch.int64), num_classes=2)
pos_weight = (clicks==0).sum()/(clicks==1).sum()
click_pred_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(click_pred, clicksOH.float())
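        # total self-supervised objective: masked-patch reconstruction + beta * IDC loss + beta2 * click-prediction loss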
return recon_loss + beta * idc_loss + beta2 * click_pred_loss
class Model(nn.Module):
def __init__(self, opt):
super(Model, self).__init__()
self.opt = opt
self.encoder = TFR_Encoder(seq_len=opt.seq_len,
patch_len=opt.patch_len,
dim=opt.dim,
in_dim=opt.in_dim,
depth=opt.depth)
self.crt = CRT(encoder=self.encoder,
decoder_dim=opt.dim,
in_dim=opt.in_dim,
patch_len=opt.patch_len)
def forward(self, x, clicks, ratio = 0.5):
return self.crt(x, clicks, mask_ratio=ratio, beta2=self.opt.beta2)
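# A minimal pretraining sketch (hypothetical shapes; assumes one binary click label per raw
# time step, i.e. seq_len // 2 values per sample):
#
#   model = Model(opt)
#   x = torch.randn(4, opt.in_dim, opt.seq_len)          # time + magnitude + phase
#   clicks = torch.randint(0, 2, (4, opt.seq_len // 2))  # click labels for the time-domain half
#   loss = model(x, clicks, ratio=0.5)                   # reconstruction + IDC + click-prediction loss
#   loss.backward()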
def self_supervised_learning(model, X, clicks, opt, modelfile, min_ratio=0.3, max_ratio=0.8):
optimizer = optim.Adam(model.parameters(), opt.AElearningRate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20)
model.to(opt.device)
model.train()
dataset = SSLDataSet(X, clicks)
dataloader = DataLoader(dataset, batch_size=opt.AEbatchSize, shuffle=True)
losses = []
    for epoch in range(opt.AEepochs):
for idx, batch in enumerate(dataloader):
x, clicks = tuple(t.to(opt.device) for t in batch)
            # the mask ratio ramps from min_ratio to max_ratio over the course of training
            loss = model(x, clicks, ratio=max(min_ratio, min(max_ratio, epoch / opt.AEepochs)))
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(float(loss))
if idx % 20 == 0:
print(idx,'/',len(dataloader), sum(losses) / len(losses))
        # ReduceLROnPlateau expects the monitored metric (here the running mean loss), not the epoch index
        scheduler.step(sum(losses) / len(losses))
torch.save(model.to('cpu'), modelfile)
class encoder_classifier(nn.Module):
def __init__(self, AE, opt):
super().__init__()
self.opt = opt
hidden_channels = 64
self.inputdim = opt.dim
self.encoder = copy.deepcopy(AE.encoder)
self.classifier = nn.Sequential(
nn.Linear(self.inputdim, hidden_channels),
nn.ReLU(),
nn.Dropout(p = opt.Headdropout),
nn.Linear(hidden_channels, hidden_channels),
nn.ReLU(),
nn.Dropout(p = opt.Headdropout),
nn.Linear(hidden_channels, opt.num_class)
)
# initialize the classifier
for m in self.classifier.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
def forward(self, x):
features = self.encoder(x, self.opt)
out = self.classifier(features)
return out
def weight_init_layer(m):
if hasattr(m, 'reset_parameters'):
m.reset_parameters()
def weight_init_whole(m):
children = list(m.children())
if len(children)==0: # No more children layers
m.apply(weight_init_layer)
else: # Get children layers
for c in children:
weight_init_whole(c)
def finetuning(AE, train_set, valid_set, opt, modelfile, multi_label=True):
stage_info = {0: 'Train the classifier only', 1: 'Finetune the whole model'}
print('Stage %d: '%(opt.stage), stage_info[opt.stage])
train_loader = DataLoader(train_set, batch_size=opt.HeadbatchSize, shuffle=True)
model = encoder_classifier(AE, opt)
model.train()
best_res = 0
step = 0
if opt.stage == 0:
optimizer = optim.Adam(model.classifier.parameters(), lr=opt.HeadlearningRate)
else:
lr = opt.HeadlearningRate/2
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20)
for epoch in range(opt.Headepochs):
for batch_idx, batch in enumerate(train_loader):
step += 1
x, y, clicks = tuple(t.to(opt.device) for t in batch)
pred = model.to(opt.device)(x)
if not multi_label:
pos_weight = (y==0.).sum()/y.sum()
y = F.one_hot(y, num_classes=2)
#loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(pred, y.squeeze().float())
loss = nn.BCEWithLogitsLoss()(pred, y.squeeze().float())
else:
if pred.shape[0]==1:
loss = nn.CrossEntropyLoss()(pred, y.long())
else:
loss = nn.CrossEntropyLoss()(pred, y.squeeze().long())
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % 10 == 0:
[trainacc, trainf1] = test(model, train_set, opt.HeadbatchSize, multi_label)
if trainf1 > best_res:
best_res = trainf1
scheduler.step(best_res)
def test(model, dataset, batch_size, multi_label):
def topAcc(GTlabel, pred_proba, top):
count = 0
for i in range(GTlabel.shape[0]):
if (GTlabel[i]) in np.argsort(-pred_proba[i])[:top]:
count+=1.0
return count/GTlabel.shape[0]
model.eval()
model.to(model.opt.device)
testloader = DataLoader(dataset, batch_size=batch_size)
pred_prob = None
with torch.no_grad():
for batch in testloader:
x, y, clicks = tuple(t.to(model.opt.device) for t in batch)
pred = model(x)
if pred_prob is None:
pred_prob = pred.cpu().detach().numpy()
else:
pred_prob = np.concatenate([pred_prob, pred.cpu().detach().numpy()], axis=0)
acc = topAcc(dataset.label.squeeze(), pred_prob, 1)
pred = np.argmax(pred_prob, axis=1)
f1 = f1_score(dataset.label.squeeze(), pred, average='macro')
model.train()
return [acc, f1]

69
README.md Normal file

@@ -0,0 +1,69 @@
<div align="center">
<h1> Mouse2Vec: Learning Reusable Semantic Representations of Mouse Behaviour </h1>
**[Guanhua Zhang][4], &nbsp; [Zhiming Hu][5], &nbsp; [Mihai Bâce][3], &nbsp; [Andreas Bulling][6]** <br>
**ACM CHI 2024**, Honolulu, Hawaii <br>
**[[Project][2]]** **[[Paper][7]]** </div>
----------------
# Setup
We recommend setting up a virtual environment using Anaconda. <br>
1. Create a conda environment and install dependencies
```shell
conda env create --name mouse2vec --file=env.yaml
conda activate mouse2vec
```
2. Clone our repository to download our code and a pretrained model
```shell
git clone this_repo.git
```
# Run the code
Our code supports training on GPUs or CPUs and prioritises GPUs if available (line 45 in main.py). You can also assign a particular GPU via `CUDA_VISIBLE_DEVICES` (e.g., the following commands use GPU no. 3).
## Train Mouse2Vec Autoencoder
Execute
```shell
CUDA_VISIBLE_DEVICES=3 python main.py --ssl True
```
## Use Mouse2Vec for Downstream Tasks
We enable two ways to use Mouse2Vec on your datasets for downstream tasks.
1. Use the (frozen) pretrained model and only train an MLP-based classifier <br>
Execute
```shell
CUDA_VISIBLE_DEVICES=3 python main.py --ssl True --load True --stage 0 --testDataset [Your Dataset]
```
2. Finetune both Mouse2Vec and the classifier <br>
Execute
```shell
CUDA_VISIBLE_DEVICES=3 python main.py --ssl True --load True --stage 1 --testDataset [Your Dataset]
```
# Citation
If you find our code useful or use it in your own projects, please cite our paper:
```
@inproceedings{zhang24_chi,
title = {Mouse2Vec: Learning Reusable Semantic Representations of Mouse Behaviour},
author = {Zhang, Guanhua and Hu, Zhiming and B{\^a}ce, Mihai and Bulling, Andreas},
year = {2024},
pages = {1--17},
booktitle = {Proc. ACM SIGCHI Conference on Human Factors in Computing Systems (CHI)},
doi = {10.1145/3613904.3642141}
}
```
# Acknowledgements
Our work relied on the codebase of [Cross Reconstruction Transformer][1]. Thanks to the authors for sharing their code.
[1]: https://github.com/BobZwr/Cross-Reconstruction-Transformer
[2]: https://perceptualui.org/publications/zhang24_chi/
[3]: https://scholar.google.com/citations?user=ku-t0MMAAAAJ&hl=en&oi=ao
[4]: https://scholar.google.com/citations?user=NqkK0GwAAAAJ&hl=en
[5]: https://scholar.google.com/citations?hl=en&user=OLB_xBEAAAAJ
[6]: https://www.perceptualui.org/people/bulling/
[7]: https://perceptualui.org/publications/zhang24_chi.pdf

274
base_models.py Normal file

@@ -0,0 +1,274 @@
import numpy as np, pdb
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
class SSLDataSet(Dataset):
def __init__(self, data, clicks):
super(SSLDataSet, self).__init__()
self.data = data
self.clicks = clicks
def __getitem__(self, idx):
return torch.tensor(self.data[idx], dtype=torch.float), torch.tensor(self.clicks[idx], dtype=torch.float)
def __len__(self):
return self.data.shape[0]
class FTDataSet(Dataset):
def __init__(self, data, label, clicks, actions=None, multi_label=False):
super(FTDataSet, self).__init__()
self.data = data
self.label = label
self.clicks = clicks
self.multi_label = multi_label
self.actions = actions
def __getitem__(self, index):
if self.multi_label:
if self.actions is None:
return (torch.tensor(self.data[index], dtype=torch.float), torch.tensor(self.label[index], dtype=torch.float), torch.tensor(self.clicks[index], dtype=torch.float))
else:
return (torch.tensor(self.data[index], dtype=torch.float), torch.tensor(self.label[index], dtype=torch.float), torch.tensor(self.clicks[index], dtype=torch.float), torch.tensor(self.actions[index], dtype=torch.float))
else:
return (torch.tensor(self.data[index], dtype=torch.float), torch.tensor(self.label[index], dtype=torch.long), torch.tensor(self.clicks[index], dtype=torch.float))
def __len__(self):
return self.data.shape[0]
# Resnet 1d
def conv(in_planes, out_planes, stride=1, kernel_size=3):
"convolution with padding 自动使用zeros进行padding"
return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=(kernel_size - 1) // 2, bias=False)
class ZeroPad1d(nn.Module):
def __init__(self, pad_left, pad_right):
super().__init__()
self.pad_left = pad_left
self.pad_right = pad_right
def forward(self, x):
return F.pad(x, (self.pad_left, self.pad_right))
class BasicBlock1d(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super().__init__()
self.conv1 = conv(inplanes, planes, stride=stride, kernel_size=3)
self.bn1 = nn.BatchNorm1d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv(planes, planes, kernel_size=3)
self.bn2 = nn.BatchNorm1d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck1d(nn.Module):
"""Bottleneck for ResNet52 ..."""
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super().__init__()
kernel_size = 3
self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm1d(planes)
self.conv2 = nn.Conv1d(planes, planes, kernel_size=kernel_size, stride=stride,
padding=(kernel_size - 1) // 2, bias=False)
self.bn2 = nn.BatchNorm1d(planes)
self.conv3 = nn.Conv1d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm1d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet1d(nn.Module):
'''1d adaptation of the torchvision resnet'''
def __init__(self, block, layers, kernel_size=3, input_channels=12, inplanes=64,
fix_feature_dim=False, kernel_size_stem=None, stride_stem=2, pooling_stem=True,
stride=2):
super(ResNet1d, self).__init__()
self.inplanes = inplanes
layers_tmp = []
if kernel_size_stem is None:
kernel_size_stem = kernel_size[0] if isinstance(kernel_size, list) else kernel_size
# conv-bn-relu (basic feature extraction)
layers_tmp.append(nn.Conv1d(input_channels, inplanes,
kernel_size=kernel_size_stem,
stride=stride_stem,
padding=(kernel_size_stem - 1) // 2, bias=False))
layers_tmp.append(nn.BatchNorm1d(inplanes))
layers_tmp.append(nn.ReLU(inplace=True))
if pooling_stem is True:
layers_tmp.append(nn.MaxPool1d(kernel_size=3, stride=2, padding=1))
for i, l in enumerate(layers):
if i == 0:
layers_tmp.append(self._make_block(block, inplanes, layers[0]))
else:
layers_tmp.append(
self._make_block(block, inplanes if fix_feature_dim else (2 ** i) * inplanes, layers[i],
stride=stride))
self.feature_extractor = nn.Sequential(*layers_tmp)
def _make_block(self, block, planes, blocks, stride=1, kernel_size=3):
down_sample = None
        # a downsampling path is needed whenever the stride or channel width changes
if stride != 1 or self.inplanes != planes * block.expansion:
down_sample = nn.Sequential(
nn.Conv1d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm1d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, down_sample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
return self.feature_extractor(x)
def resnet1d14(inplanes, input_channels):
return ResNet1d(BasicBlock1d, [2,2,2], inplanes=inplanes, input_channels=input_channels)
def resnet1d18(**kwargs):
return ResNet1d(BasicBlock1d, [2, 2, 2, 2], **kwargs)
# MLP
class MLP(nn.Module):
def __init__(self, in_channels, hidden_channels, n_classes, bn = True):
super(MLP, self).__init__()
self.in_channels = in_channels
self.n_classes = n_classes
self.hidden_channels = hidden_channels
self.fc1 = nn.Linear(self.in_channels, self.hidden_channels)
self.fc2 = nn.Linear(self.hidden_channels, self.n_classes)
self.ac = nn.ReLU()
self.bn = nn.BatchNorm1d(hidden_channels)
self.ln = nn.LayerNorm(hidden_channels)
self.fc3 = nn.Linear(self.hidden_channels, self.hidden_channels)
def forward(self, x):
hidden = self.fc1(x)
hidden = self.bn(hidden)
hidden = self.ac(hidden)
'''hidden = self.fc3(hidden)
hidden = self.ln2(hidden)
hidden = self.ac(hidden)'''
out = self.fc2(hidden)
return out
# Time-steps features -> aggregated features
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, tensor):
b = tensor.size(0)
return tensor.reshape(b, -1)
class AdaptiveConcatPool1d(nn.Module):
"Layer that concats `AdaptiveAvgPool1d` and `AdaptiveMaxPool1d`."
def __init__(self, sz=None):
"Output will be 2*sz or 2 if sz is None"
super().__init__()
sz = sz or 1
self.ap, self.mp = nn.AdaptiveAvgPool1d(sz), nn.AdaptiveMaxPool1d(sz)
def forward(self, x):
"""x is shaped of B, C, T"""
return torch.cat([self.mp(x), self.ap(x), x[..., -1:]], 1)
def bn_drop_lin(n_in, n_out, bn, p, actn):
"`n_in`->bn->dropout->linear(`n_in`,`n_out`)->`actn`"
layers = list()
if bn:
layers.append(nn.BatchNorm1d(n_in))
if p > 0.:
layers.append(nn.Dropout(p=p))
layers.append(nn.Linear(n_in, n_out))
if actn is not None:
layers.append(actn)
return layers
def create_head1d(nf: int, nc: int, lin_ftrs=[512, ], dropout=0.5, bn: bool = True, act="relu"):
lin_ftrs = [3 * nf] + lin_ftrs + [nc]
activations = [nn.ReLU(inplace=True) if act == "relu" else nn.ELU(inplace=True)] * (len(lin_ftrs) - 2) + [None]
layers = [AdaptiveConcatPool1d(), Flatten()]
for ni, no, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], activations):
layers += bn_drop_lin(ni, no, bn, dropout, actn)
layers += [nn.Sigmoid()]
return nn.Sequential(*layers)
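# A minimal usage sketch (hypothetical values): a classification head on top of 128-channel,
# variable-length features.
#
#   head = create_head1d(nf=128, nc=5)    # AdaptiveConcatPool1d -> Flatten -> 384 -> 512 -> 5 -> Sigmoid
#   out = head(torch.randn(4, 128, 50))   # out has shape (batch, nc)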

276
env.yaml Normal file

@@ -0,0 +1,276 @@
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=4.5=1_gnu
- alembic=1.8.1=py39h06a4308_0
- argon2-cffi=21.3.0=pyhd3eb1b0_0
- argon2-cffi-bindings=21.2.0=py39h7f8727e_0
- asttokens=2.0.5=pyhd3eb1b0_0
- attrs=22.1.0=py39h06a4308_0
- autopage=0.5.0=pyhd8ed1ab_0
- backcall=0.2.0=pyhd3eb1b0_0
- blas=1.0=mkl
- bleach=4.1.0=pyhd3eb1b0_0
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2023.01.10=h06a4308_0
- certifi=2022.12.7=py39h06a4308_0
- cffi=1.14.6=py39h400218f_0
- cliff=3.10.0=pyhd8ed1ab_0
- cmaes=0.8.2=pyh44b312d_0
- cmd2=2.3.3=py39hf3d152e_1
- colorlog=5.0.1=py39h06a4308_1
- comm=0.1.3=pyhd8ed1ab_0
- cudatoolkit=10.2.89=hfd86e86_1
- decorator=5.1.1=pyhd3eb1b0_0
- defusedxml=0.7.1=pyhd3eb1b0_0
- entrypoints=0.4=py39h06a4308_0
- executing=0.8.3=pyhd3eb1b0_0
- ffmpeg=4.3=hf484d3e_0
- freetype=2.11.0=h70c0345_0
- giflib=5.2.1=h7b6447c_0
- gmp=6.2.1=h2531618_2
- gnutls=3.6.15=he1e5248_0
- intel-openmp=2021.4.0=h06a4308_3561
- ipython=8.10.0=py39h06a4308_0
- ipython_genutils=0.2.0=pyhd3eb1b0_1
- jedi=0.18.1=py39h06a4308_1
- jinja2=3.1.2=py39h06a4308_0
- jpeg=9d=h7f8727e_0
- jsonschema=4.17.3=py39h06a4308_0
- jupyter_client=7.1.2=pyhd3eb1b0_0
- jupyter_core=5.3.0=py39h06a4308_0
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.35.1=h7274673_9
- libblas=3.9.0=12_linux64_mkl
- libcblas=3.9.0=12_linux64_mkl
- libffi=3.3=he6710b0_2
- libgcc-ng=9.3.0=h5101ec6_17
- libgfortran-ng=7.5.0=ha8ba4b0_17
- libgfortran4=7.5.0=ha8ba4b0_17
- libgomp=9.3.0=h5101ec6_17
- libiconv=1.15=h63c8f33_5
- libidn2=2.3.2=h7f8727e_0
- liblapack=3.9.0=12_linux64_mkl
- libllvm10=10.0.1=hbcb73fb_5
- libpng=1.6.37=hbc83047_0
- libsodium=1.0.18=h7b6447c_0
- libstdcxx-ng=9.3.0=hd4cf53a_17
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.2.0=h85742a9_0
- libunistring=0.9.10=h27cfd23_0
- libuv=1.40.0=h7b6447c_0
- libwebp=1.2.0=h89dd481_0
- libwebp-base=1.2.0=h27cfd23_0
- llvmlite=0.36.0=py39h612dafd_4
- lz4-c=1.9.3=h295c915_1
- mako=1.2.3=py39h06a4308_0
- markupsafe=2.1.1=py39h7f8727e_0
- matplotlib-inline=0.1.6=py39h06a4308_0
- mistune=0.8.4=py39h27cfd23_1000
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py39h7f8727e_0
- mkl_fft=1.3.1=py39hd3c417c_0
- mkl_random=1.2.2=py39h51133e4_0
- nbconvert=5.5.0=py_0
- nbformat=5.7.0=py39h06a4308_0
- ncurses=6.3=h7f8727e_2
- nest-asyncio=1.5.6=py39h06a4308_0
- nettle=3.7.3=hbbd107a_1
- notebook=6.4.6=pyha770c72_0
- numba=0.53.1=py39ha9443f7_0
- numpy=1.21.2=py39h20f2e39_0
- numpy-base=1.21.2=py39h79a1101_0
- olefile=0.46=pyhd3eb1b0_0
- openh264=2.1.0=hd408876_0
- openssl=1.1.1t=h7f8727e_0
- optuna=2.10.0=pyhd8ed1ab_0
- packaging=23.0=py39h06a4308_0
- pandoc=2.12=h06a4308_3
- pandocfilters=1.5.0=pyhd3eb1b0_0
- parso=0.8.3=pyhd3eb1b0_0
- pbr=5.6.0=pyhd3eb1b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=8.4.0=py39h5aabda8_0
- pip=21.2.4=py39h06a4308_0
- platformdirs=2.5.2=py39h06a4308_0
- plotly=4.14.3=pyhd3eb1b0_0
- prometheus_client=0.14.1=py39h06a4308_0
- prompt-toolkit=3.0.36=py39h06a4308_0
- psutil=5.8.0=py39h3811e60_1
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pure_eval=0.2.2=pyhd3eb1b0_0
- pycparser=2.21=pyhd3eb1b0_0
- pygments=2.11.2=pyhd3eb1b0_0
- pyparsing=3.0.9=py39h06a4308_0
- pyperclip=1.8.2=pyhd8ed1ab_2
- pyrsistent=0.18.0=py39heee7806_0
- python=3.9.7=h12debd9_1
- python-dateutil=2.8.2=pyhd3eb1b0_0
- python-fastjsonschema=2.16.2=py39h06a4308_0
- python_abi=3.9=2_cp39
- pytorch=1.10.1=py3.9_cuda10.2_cudnn7.6.5_0
- pytorch-mutex=1.0=cuda
- pyts=0.12.0=pyh6c4a22f_0
- readline=8.1=h27cfd23_0
- retrying=1.3.3=pyhd3eb1b0_2
- send2trash=1.8.0=pyhd3eb1b0_1
- setuptools=58.0.4=py39h06a4308_0
- six=1.16.0=pyhd3eb1b0_0
- sqlalchemy=1.3.23=py39h3811e60_0
- sqlite=3.36.0=hc218d9a_0
- stack_data=0.2.0=pyhd3eb1b0_0
- stevedore=3.5.0=py39hf3d152e_2
- tbb=2020.3=hfd86e86_0
- terminado=0.17.1=py39h06a4308_0
- testpath=0.6.0=py39h06a4308_0
- tk=8.6.11=h1ccaba5_0
- torchaudio=0.10.1=py39_cu102
- torchvision=0.11.2=py39_cu102
- tornado=6.1=py39h27cfd23_0
- tqdm=4.65.0=py39hb070fc8_0
- traitlets=5.7.1=py39h06a4308_0
- typing-extensions=3.10.0.2=hd3eb1b0_0
- typing_extensions=3.10.0.2=pyh06a4308_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- webencodings=0.5.1=py39h06a4308_1
- wheel=0.37.0=pyhd3eb1b0_1
- xz=5.2.5=h7b6447c_0
- yaml=0.2.5=h7b6447c_0
- zeromq=4.3.4=h2531618_0
- zlib=1.2.11=h7f8727e_4
- zstd=1.4.9=haebb681_0
- pip:
- absl-py==1.0.0
- aiohttp==3.8.1
- aiosignal==1.2.0
- altair==4.1.0
- appdirs==1.4.4
- astor==0.8.1
- async-timeout==4.0.2
- base58==2.1.1
- blinker==1.4
- cachetools==4.2.4
- charset-normalizer==2.0.9
- click==7.1.2
- configparser==5.2.0
- cycler==0.11.0
- cython==0.29.35
- datasets==1.17.0
- debugpy==1.5.1
- deprecated==1.2.13
- dictset==0.3.1.2
- dill==0.3.4
- docker-pycreds==0.4.0
- dtaidistance==2.3.10
- einops==0.6.0
- fastdtw==0.3.4
- filelock==3.4.0
- fonttools==4.28.5
- frozenlist==1.2.0
- fsspec==2021.11.1
- gitdb==4.0.9
- gitpython==3.1.24
- google-auth==2.3.3
- google-auth-oauthlib==0.4.6
- grpcio==1.43.0
- h5py==3.6.0
- hdbscan==0.8.29
- hmmlearn==0.2.8
- huggingface-hub==0.2.1
- idna==3.3
- imbalanced-learn==0.9.0
- import-ipynb==0.1.3
- importlib-metadata==6.6.0
- install-jdk==0.3.0
- ipykernel==6.6.0
- ipynb==0.5.1
- ipywidgets==7.6.5
- javabridge==1.0.16
- joblib==1.1.0
- jupyterlab-widgets==1.0.2
- kiwisolver==1.3.2
- markdown==3.3.6
- matplotlib==3.5.1
- msgpack==1.0.3
- multidict==5.2.0
- multiprocess==0.70.12.2
- nltk==3.7
- numexpr==2.8.1
- oauthlib==3.1.1
- opencv-python==4.5.5.62
- pandas==1.5.3
- pathtools==0.1.2
- patsy==0.5.2
- promise==2.3
- protobuf==3.19.1
- pyarrow==6.0.1
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pydeck==0.7.1
- pympler==1.0
- pynndescent==0.5.10
- pystaggrelite3==0.1.3
- python-crfsuite==0.9.9
- python-weka-wrapper==0.3.18
- pytorch-ignite==0.4.8
- pytorch-warmup==0.1.0
- pytz==2021.3
- pytz-deprecation-shim==0.1.0.post0
- pyyaml==6.0
- pyzmq==22.3.0
- ray==1.9.2
- redis==4.1.1
- regex==2021.11.10
- requests==2.26.0
- requests-oauthlib==1.3.0
- rsa==4.8
- sacremoses==0.0.46
- scikit-learn==1.0.1
- scipy==1.10.1
- seaborn==0.11.2
- sentencepiece==0.1.96
- sentry-sdk==1.5.1
- seqeval==1.2.2
- setproctitle==1.3.2
- shortuuid==1.0.8
- simpletransformers==0.63.3
- sklearn-crfsuite==0.3.6
- smmap==5.0.0
- statsmodels==0.13.2
- streamlit==1.3.0
- subprocess32==3.5.4
- tables==3.7.0
- tabulate==0.8.9
- tensorboard==2.11.2
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
- termcolor==1.1.0
- textdistance==4.2.2
- threadpoolctl==3.0.0
- tokenizers==0.10.3
- toml==0.10.2
- toolz==0.11.2
- torch-summary==1.4.5
- torch-tb-profiler==0.4.1
- torchinfo==1.7.0
- torchtext==0.11.1
- transformers==4.14.1
- tzdata==2021.5
- tzlocal==4.1
- umap-learn==0.5.3
- urllib3==1.26.7
- validators==0.18.2
- watchdog==2.1.6
- weka==2.0.2
- werkzeug==2.0.2
- widgetsnbextension==3.5.2
- wrapt==1.13.3
- xxhash==2.0.2
- yarl==1.7.2
- yaspin==2.1.0
- zipp==3.15.0

77
main.py Normal file

@@ -0,0 +1,77 @@
import numpy as np
import random, pdb, os, copy
import argparse
import pandas as pd
import pickle as pkl
import torch
from utils import MakeDir, set_seed, str2bool
from CRT import Model, self_supervised_learning, finetuning
from base_models import FTDataSet
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--ssl", type=str2bool, default=False, help='Self-supervised learning (pretrain the autoencoder)')
parser.add_argument("--sl", type=str2bool, default=True, help='Supervised learning (downstream tasks)')
parser.add_argument("--load", type=str2bool, default=True)
parser.add_argument("--patch_len", type=int, default=5)
parser.add_argument("--dim", type=int, default=128)
parser.add_argument("--depth", type=int, default=6)
parser.add_argument("--dropout", type=float, default=0.1)
parser.add_argument("--Headdropout", type=float, default=0.1)
parser.add_argument("--beta2", type=float, default=1)
parser.add_argument("--AElearningRate", type=float, default=1e-4)
parser.add_argument("--HeadlearningRate", type=float, default=1e-3)
parser.add_argument("--AEbatchSize", type=int, default=512)
parser.add_argument("--HeadbatchSize", type=int, default=32)
parser.add_argument("--AEepochs", type=int, default=100)
parser.add_argument("--Headepochs", type=int, default=50)
parser.add_argument("--AEwin", type=int, default=5)
parser.add_argument("--Headwin", type=int, default=5)
parser.add_argument("--slid", type=int, default=1)
parser.add_argument("--timeWinFreq", type=int, default=20)
parser.add_argument("--pretrainDataset", type=str, default='Buffalo_EMAKI')
parser.add_argument("--testDataset", type=str)
parser.add_argument("--stage", type=int, choices=[0,1], default=0, help='0: Train the classifier only; 1: Finetune the whole model')
parser.add_argument("--seed", type=int, default=0)
opt = parser.parse_args()
opt.in_dim = 2 # x,y
opt.seq_len = opt.AEwin * opt.timeWinFreq *2
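    # the sequence is twice the raw time window because DFT magnitude and phase are appended to it (see utils.DFT)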
opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
set_seed(opt.seed)
aemodelfile = 'pretrained_model.pkl'
if opt.ssl:
'''
Load dataset for pretraining the Autoencoder
taskdata: on-screen coordinates (x,y), their magnitude and phase
'''
taskdata, taskclicks = loadPretrainDataset(opt)
model = Model(opt).to(opt.device)
self_supervised_learning(model, taskdata, taskclicks, opt, modelfile=aemodelfile)
if opt.load:
model = torch.load(aemodelfile, map_location=opt.device)
else:
model = Model(opt).to(opt.device)
if opt.sl:
'''
Load dataset for downstream tasks
_X: on-screen coordinates (x,y), their magnitude and phase
_y: labels
_clicks: clicks
'''
train_X, train_y, train_clicks, valid_X, valid_y, valid_clicks = loadDownstreamDataset(opt)
TrainSet = FTDataSet(train_X, train_y, train_clicks, multi_label=True) # Binary or multi-class
ValidSet = FTDataSet(valid_X, valid_y, valid_clicks, multi_label=True)
finetuning(model, TrainSet, ValidSet, opt, aemodelfile.split('.pkl')[0]+'downstream.pkl', multi_label=True)

BIN
pretrained_model.pkl Normal file

Binary file not shown.

52
utils.py Normal file

@@ -0,0 +1,52 @@
import argparse
import pdb, os, random, copy
import pandas as pd
import numpy as np
import pickle as pkl
import torch
from torch.utils.data import DataLoader,Dataset
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Unsupported value encountered.')
def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed) # Numpy module.
random.seed(seed) # Python random module.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def DFT(data, clicks=None):
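    # Appends the DFT magnitude and phase (each globally min-max normalised) to the raw sequence
    # along the time axis, doubling its length. The stacking below assumes exactly two input
    # channels (x, y); clicks are returned unchanged.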
for idx, channel in enumerate(range(data.shape[-1])):
freq = np.fft.fft(data[:,:,channel])
freq = freq[:, :data.shape[1]//2] # symmetric
mag = np.abs(freq)
phase = np.angle(freq)
mag = (mag - np.min(mag)) / (np.max(mag) - np.min(mag))
phase = (phase - np.min(phase)) / (np.max(phase) - np.min(phase))
if idx==0:
mags, phases = mag, phase
else:
mags, phases = np.stack([mags, mag], axis=-1), np.stack([phases, phase], axis=-1)
data = np.concatenate([data, mags, phases], axis=1)
return data, clicks
def MakeDir(dirName):
if not os.path.exists(dirName):
os.makedirs(dirName)
class simpleSet(Dataset):
def __init__(self, features, labels):
self.features = features
labels = labels.astype(int)
self.labels = labels - np.min(labels)
def __getitem__(self, index):
return self.features[index], self.labels[index]
def __len__(self):
return len(self.features)