init
This commit is contained in:
commit
b429f9807d
7 changed files with 1186 additions and 0 deletions
438
CRT.py
Normal file
438
CRT.py
Normal file
|
@ -0,0 +1,438 @@
|
|||
import pdb, copy
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
from base_models import MLP, resnet1d18, SSLDataSet
|
||||
from sklearn.metrics import f1_score
|
||||
|
||||
class cnn_extractor(nn.Module):
|
||||
def __init__(self, dim, input_plane):
|
||||
super(cnn_extractor, self).__init__()
|
||||
self.cnn = resnet1d18(input_channels=dim, inplanes=input_plane)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.cnn(x)
|
||||
return x
|
||||
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim=-1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
|
||||
class TFR(nn.Module):
|
||||
def __init__(self, seq_len, patch_len, dim, depth, heads, mlp_dim, channels=12,
|
||||
dim_head=64, dropout=0., emb_dropout=0.):
|
||||
'''
|
||||
The encoder of CRT
|
||||
'''
|
||||
super().__init__()
|
||||
self.patch_len = patch_len
|
||||
self.seq_len = seq_len
|
||||
|
||||
assert seq_len % (4 * patch_len) == 0, \
|
||||
'The seq_len should be 4 * n * patch_len, or there must be patch with both magnitude and phase data.'
|
||||
num_patches = seq_len // patch_len
|
||||
patch_dim = channels * patch_len
|
||||
|
||||
self.to_patch = nn.Sequential(Rearrange('b c (n p1) -> b n c p1', p1=patch_len),
|
||||
Rearrange('b n c p1 -> (b n) c p1'))
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 3, dim))
|
||||
self.modal_embedding = nn.Parameter(torch.randn(3, 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 3, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.cnn1 = cnn_extractor(dim=channels, input_plane=dim // 8) # For temporal data
|
||||
self.cnn2 = cnn_extractor(dim=channels, input_plane=dim // 8) # For magnitude data
|
||||
self.cnn3 = cnn_extractor(dim=channels, input_plane=dim // 8) # For phase data
|
||||
|
||||
def forward(self, x, opt):
|
||||
batch, _, time_steps = x.shape
|
||||
t, m, p = x[:, :, :time_steps // 2], x[:, :, time_steps // 2: time_steps * 3 // 4], x[:, :, -time_steps // 4:]
|
||||
assert t.shape[-1] % opt.patch_len == 0 == m.shape[-1] % opt.patch_len == p.shape[-1] % opt.patch_len
|
||||
|
||||
t, m, p = self.to_patch(t), self.to_patch(m), self.to_patch(p)
|
||||
patch2seq = nn.Sequential(nn.AdaptiveAvgPool1d(1),
|
||||
Rearrange('(b n) c 1 -> b n c', b=batch))
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=batch)
|
||||
x = torch.cat((cls_tokens[:, 0:1, :], patch2seq(self.cnn1(t)),
|
||||
cls_tokens[:, 1:2, :], patch2seq(self.cnn2(m)),
|
||||
cls_tokens[:, 2:3, :], patch2seq(self.cnn3(p))), dim=1)
|
||||
|
||||
b, t, c = x.shape
|
||||
time_steps = t - 3
|
||||
t_token_idx, m_token_idx, p_token_idx = 0, time_steps // 2 + 1, time_steps * 3 // 4 + 2
|
||||
x[:m_token_idx] += self.modal_embedding[:1]
|
||||
x[m_token_idx: p_token_idx] += self.modal_embedding[1:2]
|
||||
x[p_token_idx: ] += self.modal_embedding[2:]
|
||||
x += self.pos_embedding[:, : t]
|
||||
x = self.dropout(x)
|
||||
x = self.transformer(x)
|
||||
t_token, m_token, p_token = x[:, t_token_idx], x[:, m_token_idx], x[:, p_token_idx]
|
||||
avg = (t_token + m_token + p_token) / 3
|
||||
return avg
|
||||
|
||||
|
||||
def TFR_Encoder(seq_len, patch_len, dim, in_dim, depth):
|
||||
vit = TFR(seq_len=seq_len,
|
||||
patch_len=patch_len,
|
||||
#num_classes=num_class,
|
||||
dim=dim,
|
||||
depth=depth,
|
||||
heads=8,
|
||||
mlp_dim=dim,
|
||||
dropout=0.2,
|
||||
emb_dropout=0.1,
|
||||
channels=in_dim)
|
||||
return vit
|
||||
|
||||
class CRT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder,
|
||||
decoder_dim,
|
||||
decoder_depth=2,
|
||||
decoder_heads=8,
|
||||
decoder_dim_head=64,
|
||||
patch_len = 20,
|
||||
in_dim=12
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
|
||||
self.to_patch = encoder.to_patch
|
||||
pixel_values_per_patch = in_dim * patch_len
|
||||
|
||||
# decoder parameters
|
||||
self.modal_embedding = self.encoder.modal_embedding
|
||||
self.mask_token = nn.Parameter(torch.randn(3, decoder_dim))
|
||||
self.decoder = Transformer(dim=decoder_dim,
|
||||
depth=decoder_depth,
|
||||
heads=decoder_heads,
|
||||
dim_head=decoder_dim_head,
|
||||
mlp_dim=decoder_dim)
|
||||
self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
|
||||
self.to_pixels = nn.ModuleList([nn.Linear(decoder_dim, pixel_values_per_patch) for i in range(3)])
|
||||
self.projs = nn.ModuleList([nn.Linear(decoder_dim, decoder_dim) for i in range(2)])
|
||||
|
||||
self.to_clicks = nn.Linear(decoder_dim, 2 * patch_len)
|
||||
|
||||
def IDC_loss(self, tokens, encoded_tokens):
|
||||
B, T, D = tokens.shape
|
||||
tokens, encoded_tokens = F.normalize(tokens, dim=-1), F.normalize(encoded_tokens, dim=-1)
|
||||
encoded_tokens = encoded_tokens.transpose(2, 1)
|
||||
cross_mul = torch.exp(torch.matmul(tokens, encoded_tokens))
|
||||
mask = (1 - torch.eye(T)).unsqueeze(0).to(tokens.device)
|
||||
cross_mul = cross_mul * mask
|
||||
return torch.log(cross_mul.sum(-1).sum(-1)).mean(-1)
|
||||
|
||||
def forward(self, x, clicks, mask_ratio=0.75, beta = 1e-4, beta2 = 1e-4):
|
||||
device = x.device
|
||||
patches = self.to_patch[0](x)
|
||||
batch, num_patches, c, length = patches.shape
|
||||
clickpatches = self.to_patch[0](clicks.unsqueeze(1).float())
|
||||
|
||||
num_masked = int(mask_ratio * num_patches)
|
||||
rand_indices1 = torch.randperm(num_patches // 2, device=device)
|
||||
masked_indices1 = rand_indices1[: num_masked // 2].sort()[0]
|
||||
unmasked_indices1 = rand_indices1[num_masked // 2:].sort()[0]
|
||||
rand_indices2 = torch.randperm(num_patches // 4, device=device)
|
||||
masked_indices2, unmasked_indices2 = rand_indices2[: num_masked // 4].sort()[0], rand_indices2[num_masked // 4:].sort()[0]
|
||||
rand_indices = torch.cat((masked_indices1, unmasked_indices1,
|
||||
masked_indices2 + num_patches // 2, unmasked_indices2 + num_patches // 2,
|
||||
masked_indices2 + num_patches // 4 * 3, unmasked_indices2 + num_patches // 4 * 3))
|
||||
|
||||
masked_num_t, masked_num_f = masked_indices1.shape[0], 2 * masked_indices2.shape[0]
|
||||
unmasked_num_t, unmasked_num_f = unmasked_indices1.shape[0], 2 * unmasked_indices2.shape[0]
|
||||
|
||||
tpatches = patches[:, : num_patches // 2, :, :]
|
||||
mpatches, ppatches = patches[:, num_patches // 2: num_patches * 3 // 4, :, :], patches[:, -num_patches // 4:, :, :]
|
||||
|
||||
unmasked_tpatches = tpatches[:, unmasked_indices1, :, :]
|
||||
unmasked_mpatches, unmasked_ppatches = mpatches[:, unmasked_indices2, :, :], ppatches[:, unmasked_indices2, :, :]
|
||||
t_tokens, m_tokens, p_tokens = self.to_patch[1](unmasked_tpatches), self.to_patch[1](unmasked_mpatches), self.to_patch[1](unmasked_ppatches)
|
||||
t_tokens, m_tokens, p_tokens = self.encoder.cnn1(t_tokens), self.encoder.cnn2(m_tokens), self.encoder.cnn3(p_tokens)
|
||||
Flat = nn.Sequential(nn.AdaptiveAvgPool1d(1),
|
||||
Rearrange('(b n) c 1 -> b n c', b=batch))
|
||||
t_tokens, m_tokens, p_tokens = Flat(t_tokens), Flat(m_tokens), Flat(p_tokens)
|
||||
ori_tokens = torch.cat((t_tokens, m_tokens, p_tokens), 1).clone()
|
||||
|
||||
cls_tokens = repeat(self.encoder.cls_token, '() n d -> b n d', b=batch)
|
||||
tokens = torch.cat((cls_tokens[:, 0:1, :], t_tokens,
|
||||
cls_tokens[:, 1:2, :], m_tokens,
|
||||
cls_tokens[:, 2:3, :], p_tokens), dim=1)
|
||||
|
||||
t_idx, m_idx, p_idx = num_patches // 2 - 1, num_patches * 3 // 4 - 1, num_patches - 1
|
||||
pos_embedding = torch.cat((self.encoder.pos_embedding[:, 0:1, :], self.encoder.pos_embedding[:, unmasked_indices1 + 1, :],
|
||||
self.encoder.pos_embedding[:, t_idx + 2: t_idx + 3],
|
||||
self.encoder.pos_embedding[:, unmasked_indices2 + t_idx + 3, :],
|
||||
self.encoder.pos_embedding[:, m_idx + 3: m_idx + 4],
|
||||
self.encoder.pos_embedding[:, unmasked_indices2 + m_idx + 4, :]), dim=1)
|
||||
|
||||
modal_embedding = torch.cat((repeat(self.modal_embedding[0], '1 d -> 1 n d', n=unmasked_num_t + 1),
|
||||
repeat(self.modal_embedding[1], '1 d -> 1 n d', n=unmasked_num_f // 2 + 1),
|
||||
repeat(self.modal_embedding[2], '1 d -> 1 n d', n=unmasked_num_f // 2 + 1)), dim=1)
|
||||
|
||||
tokens = tokens + pos_embedding + modal_embedding
|
||||
|
||||
encoded_tokens = self.encoder.transformer(tokens) # tokens: unmasked tokens + CLS tokens
|
||||
|
||||
t_idx, m_idx, p_idx = unmasked_num_t, unmasked_num_f // 2 + unmasked_num_t + 1, -1
|
||||
|
||||
idc_loss = self.IDC_loss(self.projs[0](ori_tokens), self.projs[1](torch.cat(([encoded_tokens[:, 1: t_idx+1], encoded_tokens[:, t_idx+2: m_idx+1], encoded_tokens[:, m_idx+2: ]]), dim=1)))
|
||||
|
||||
decoder_tokens = encoded_tokens
|
||||
|
||||
mask_tokens1 = repeat(self.mask_token[0], 'd -> b n d', b=batch, n=masked_num_t)
|
||||
mask_tokens2 = repeat(self.mask_token[1], 'd -> b n d', b=batch, n=masked_num_f // 2)
|
||||
mask_tokens3 = repeat(self.mask_token[2], 'd -> b n d', b=batch, n=masked_num_f // 2)
|
||||
mask_tokens = torch.cat((mask_tokens1, mask_tokens2, mask_tokens3), dim=1)
|
||||
|
||||
decoder_pos_emb = self.decoder_pos_emb(torch.cat(
|
||||
(masked_indices1, masked_indices2 + num_patches // 2, masked_indices2 + num_patches * 3 // 4)))
|
||||
|
||||
mask_tokens = mask_tokens + decoder_pos_emb
|
||||
decoder_tokens = torch.cat((decoder_tokens, mask_tokens), dim=1)
|
||||
decoded_tokens = self.decoder(decoder_tokens)
|
||||
|
||||
mask_tokens = decoded_tokens[:, -mask_tokens.shape[1]:]
|
||||
|
||||
pred_pixel_values_t = self.to_pixels[0](torch.cat((decoder_tokens[:, 1: t_idx + 1], mask_tokens[:, : masked_num_t]), 1))
|
||||
pred_pixel_values_m = self.to_pixels[1](torch.cat((decoder_tokens[:, t_idx+2: m_idx+1], mask_tokens[:, masked_num_t: masked_num_f // 2 + masked_num_t]), 1))
|
||||
pred_pixel_values_p = self.to_pixels[2](torch.cat((decoder_tokens[:, m_idx+2: -mask_tokens.shape[1]], mask_tokens[:, -masked_num_f // 2:]), 1))
|
||||
pred_pixel_values = torch.cat((pred_pixel_values_t, pred_pixel_values_m, pred_pixel_values_p), dim=1)
|
||||
|
||||
recon_loss = F.mse_loss(pred_pixel_values, rearrange(patches[:,rand_indices], 'b n c p -> b n (c p)'))
|
||||
|
||||
rmvcls_embedding = torch.cat((decoder_tokens[:, 1: t_idx + 1], mask_tokens[:, : masked_num_t]), 1)
|
||||
click_pred = self.to_clicks(rmvcls_embedding)
|
||||
click_pred = rearrange(click_pred, 'b n (c p) -> b n c p', p=2)
|
||||
click_pred = rearrange(click_pred, 'b n c p -> (b n c) p')
|
||||
clicksGT = clickpatches[:, rand_indices1].squeeze()
|
||||
clicksGT = rearrange(clicksGT, 'b n c -> (b n c)')
|
||||
clicksOH = F.one_hot(clicksGT.to(torch.int64), num_classes=2)
|
||||
pos_weight = (clicks==0).sum()/(clicks==1).sum()
|
||||
click_pred_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(click_pred, clicksOH.float())
|
||||
|
||||
return recon_loss + beta * idc_loss + beta2 * click_pred_loss
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, opt):
|
||||
super(Model, self).__init__()
|
||||
self.opt = opt
|
||||
self.encoder = TFR_Encoder(seq_len=opt.seq_len,
|
||||
patch_len=opt.patch_len,
|
||||
dim=opt.dim,
|
||||
in_dim=opt.in_dim,
|
||||
depth=opt.depth)
|
||||
self.crt = CRT(encoder=self.encoder,
|
||||
decoder_dim=opt.dim,
|
||||
in_dim=opt.in_dim,
|
||||
patch_len=opt.patch_len)
|
||||
|
||||
def forward(self, x, clicks, ratio = 0.5):
|
||||
return self.crt(x, clicks, mask_ratio=ratio, beta2=self.opt.beta2)
|
||||
|
||||
def self_supervised_learning(model, X, clicks, opt, modelfile, min_ratio=0.3, max_ratio=0.8):
|
||||
optimizer = optim.Adam(model.parameters(), opt.AElearningRate)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20)
|
||||
|
||||
model.to(opt.device)
|
||||
model.train()
|
||||
|
||||
dataset = SSLDataSet(X, clicks)
|
||||
dataloader = DataLoader(dataset, batch_size=opt.AEbatchSize, shuffle=True)
|
||||
|
||||
losses = []
|
||||
for _ in range(opt.AEepochs):
|
||||
for idx, batch in enumerate(dataloader):
|
||||
x, clicks = tuple(t.to(opt.device) for t in batch)
|
||||
loss = model.to(opt.device)(x, clicks, ratio=max(min_ratio, min(max_ratio, _ / opt.AEepochs)))
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
losses.append(float(loss))
|
||||
if idx % 20 == 0:
|
||||
print(idx,'/',len(dataloader), sum(losses) / len(losses))
|
||||
scheduler.step(_)
|
||||
torch.save(model.to('cpu'), modelfile)
|
||||
|
||||
class encoder_classifier(nn.Module):
|
||||
def __init__(self, AE, opt):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
hidden_channels = 64
|
||||
self.inputdim = opt.dim
|
||||
self.encoder = copy.deepcopy(AE.encoder)
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(self.inputdim, hidden_channels),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(p = opt.Headdropout),
|
||||
nn.Linear(hidden_channels, hidden_channels),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(p = opt.Headdropout),
|
||||
nn.Linear(hidden_channels, opt.num_class)
|
||||
)
|
||||
# initialize the classifier
|
||||
for m in self.classifier.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_normal_(m.weight)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
def forward(self, x):
|
||||
features = self.encoder(x, self.opt)
|
||||
out = self.classifier(features)
|
||||
return out
|
||||
|
||||
def weight_init_layer(m):
|
||||
if hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
def weight_init_whole(m):
|
||||
children = list(m.children())
|
||||
if len(children)==0: # No more children layers
|
||||
m.apply(weight_init_layer)
|
||||
else: # Get children layers
|
||||
for c in children:
|
||||
weight_init_whole(c)
|
||||
|
||||
def finetuning(AE, train_set, valid_set, opt, modelfile, multi_label=True):
|
||||
stage_info = {0: 'Train the classifier only', 1: 'Finetune the whole model'}
|
||||
print('Stage %d: '%(opt.stage), stage_info[opt.stage])
|
||||
|
||||
train_loader = DataLoader(train_set, batch_size=opt.HeadbatchSize, shuffle=True)
|
||||
model = encoder_classifier(AE, opt)
|
||||
model.train()
|
||||
|
||||
best_res = 0
|
||||
step = 0
|
||||
if opt.stage == 0:
|
||||
optimizer = optim.Adam(model.classifier.parameters(), lr=opt.HeadlearningRate)
|
||||
else:
|
||||
lr = opt.HeadlearningRate/2
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20)
|
||||
for epoch in range(opt.Headepochs):
|
||||
for batch_idx, batch in enumerate(train_loader):
|
||||
step += 1
|
||||
x, y, clicks = tuple(t.to(opt.device) for t in batch)
|
||||
pred = model.to(opt.device)(x)
|
||||
if not multi_label:
|
||||
pos_weight = (y==0.).sum()/y.sum()
|
||||
y = F.one_hot(y, num_classes=2)
|
||||
#loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(pred, y.squeeze().float())
|
||||
loss = nn.BCEWithLogitsLoss()(pred, y.squeeze().float())
|
||||
else:
|
||||
if pred.shape[0]==1:
|
||||
loss = nn.CrossEntropyLoss()(pred, y.long())
|
||||
else:
|
||||
loss = nn.CrossEntropyLoss()(pred, y.squeeze().long())
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if step % 10 == 0:
|
||||
[trainacc, trainf1] = test(model, train_set, opt.HeadbatchSize, multi_label)
|
||||
if trainf1 > best_res:
|
||||
best_res = trainf1
|
||||
scheduler.step(best_res)
|
||||
|
||||
def test(model, dataset, batch_size, multi_label):
|
||||
def topAcc(GTlabel, pred_proba, top):
|
||||
count = 0
|
||||
for i in range(GTlabel.shape[0]):
|
||||
if (GTlabel[i]) in np.argsort(-pred_proba[i])[:top]:
|
||||
count+=1.0
|
||||
return count/GTlabel.shape[0]
|
||||
|
||||
model.eval()
|
||||
model.to(model.opt.device)
|
||||
testloader = DataLoader(dataset, batch_size=batch_size)
|
||||
pred_prob = None
|
||||
with torch.no_grad():
|
||||
for batch in testloader:
|
||||
x, y, clicks = tuple(t.to(model.opt.device) for t in batch)
|
||||
pred = model(x)
|
||||
if pred_prob is None:
|
||||
pred_prob = pred.cpu().detach().numpy()
|
||||
else:
|
||||
pred_prob = np.concatenate([pred_prob, pred.cpu().detach().numpy()], axis=0)
|
||||
acc = topAcc(dataset.label.squeeze(), pred_prob, 1)
|
||||
pred = np.argmax(pred_prob, axis=1)
|
||||
f1 = f1_score(dataset.label.squeeze(), pred, average='macro')
|
||||
|
||||
model.train()
|
||||
return [acc, f1]
|
69
README.md
Normal file
69
README.md
Normal file
|
@ -0,0 +1,69 @@
|
|||
<div align="center">
|
||||
<h1> Mouse2Vec: Learning Reusable Semantic Representations of Mouse Behaviour </h1>
|
||||
|
||||
**[Guanhua Zhang][4], [Zhiming Hu][5], [Mihai Bâce][3], [Andreas Bulling][6]** <br>
|
||||
**ACM CHI 2024**, Honolulu, Hawaii <br>
|
||||
**[[Project][2]]** **[[Paper][7]]** </div>
|
||||
----------------
|
||||
|
||||
# Setup
|
||||
We recommend to setup a virtual environment using Anaconda. <br>
|
||||
1. Create a conda environment and install dependencies
|
||||
```shell
|
||||
conda env create --name mouse2vec --file=env.yaml
|
||||
conda activate mouse2vec
|
||||
```
|
||||
2. Clone our repository to download our code and a pretrained model
|
||||
```shell
|
||||
git clone this_repo.git
|
||||
```
|
||||
|
||||
# Run the code
|
||||
Our code supports training using GPUs or CPUs. It will prioritise GPUs if available (line 45 in main.py). You can also assign a particular card via CUDA_VISIBLE_DEVICES (e.g., the following commands use GPU card no.3).
|
||||
## Train Mouse2Vec Autoencoder
|
||||
<br>
|
||||
|
||||
Execute
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=3 python main.py --ssl True
|
||||
```
|
||||
## Use Mouse2Vec for Downstream Tasks
|
||||
|
||||
We enable two ways to use Mouse2Vec on your datasets for downstream tasks.
|
||||
1. Use the (frozen) pretrained model and only train a MLP-based classifier <br>
|
||||
|
||||
Execute
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=3 python main.py --ssl True --load True --stage 0 --testDataset [Your Dataset]
|
||||
```
|
||||
|
||||
2. Finetune both Mouse2Vec and the classifier <br>
|
||||
|
||||
Execute
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=3 python main.py --ssl True --load True --stage 1 --testDataset [Your Dataset]
|
||||
```
|
||||
|
||||
# Citation
|
||||
If you find our code useful or use it in your own projects, please cite our paper:
|
||||
```
|
||||
@inproceedings{zhang24_chi,
|
||||
title = {Mouse2Vec: Learning Reusable Semantic Representations of Mouse Behaviour},
|
||||
author = {Zhang, Guanhua and Hu, Zhiming and B{\^a}ce, Mihai and Bulling, Andreas},
|
||||
year = {2024},
|
||||
pages = {1--17},
|
||||
booktitle = {Proc. ACM SIGCHI Conference on Human Factors in Computing Systems (CHI)},
|
||||
doi = {10.1145/3613904.3642141}
|
||||
}
|
||||
```
|
||||
|
||||
# Acknowledgements
|
||||
Our work relied on the codebase of [Cross Reconstruction Transformer][1]. Thanks to the authors for sharing their code.
|
||||
|
||||
[1]: https://github.com/BobZwr/Cross-Reconstruction-Transformer
|
||||
[2]: https://perceptualui.org/publications/zhang24_chi/
|
||||
[3]: https://scholar.google.com/citations?user=ku-t0MMAAAAJ&hl=en&oi=ao
|
||||
[4]: https://scholar.google.com/citations?user=NqkK0GwAAAAJ&hl=en
|
||||
[5]: https://scholar.google.com/citations?hl=en&user=OLB_xBEAAAAJ
|
||||
[6]: https://www.perceptualui.org/people/bulling/
|
||||
[7]: https://perceptualui.org/publications/zhang24_chi.pdf
|
274
base_models.py
Normal file
274
base_models.py
Normal file
|
@ -0,0 +1,274 @@
|
|||
import numpy as np, pdb
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class SSLDataSet(Dataset):
|
||||
def __init__(self, data, clicks):
|
||||
super(SSLDataSet, self).__init__()
|
||||
self.data = data
|
||||
self.clicks = clicks
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return torch.tensor(self.data[idx], dtype=torch.float), torch.tensor(self.clicks[idx], dtype=torch.float)
|
||||
|
||||
def __len__(self):
|
||||
return self.data.shape[0]
|
||||
|
||||
class FTDataSet(Dataset):
|
||||
def __init__(self, data, label, clicks, actions=None, multi_label=False):
|
||||
super(FTDataSet, self).__init__()
|
||||
self.data = data
|
||||
self.label = label
|
||||
self.clicks = clicks
|
||||
self.multi_label = multi_label
|
||||
self.actions = actions
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.multi_label:
|
||||
if self.actions is None:
|
||||
return (torch.tensor(self.data[index], dtype=torch.float), torch.tensor(self.label[index], dtype=torch.float), torch.tensor(self.clicks[index], dtype=torch.float))
|
||||
else:
|
||||
return (torch.tensor(self.data[index], dtype=torch.float), torch.tensor(self.label[index], dtype=torch.float), torch.tensor(self.clicks[index], dtype=torch.float), torch.tensor(self.actions[index], dtype=torch.float))
|
||||
else:
|
||||
return (torch.tensor(self.data[index], dtype=torch.float), torch.tensor(self.label[index], dtype=torch.long), torch.tensor(self.clicks[index], dtype=torch.float))
|
||||
|
||||
def __len__(self):
|
||||
return self.data.shape[0]
|
||||
|
||||
# Resnet 1d
|
||||
|
||||
def conv(in_planes, out_planes, stride=1, kernel_size=3):
|
||||
"convolution with padding 自动使用zeros进行padding"
|
||||
return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
||||
padding=(kernel_size - 1) // 2, bias=False)
|
||||
|
||||
class ZeroPad1d(nn.Module):
|
||||
def __init__(self, pad_left, pad_right):
|
||||
super().__init__()
|
||||
self.pad_left = pad_left
|
||||
self.pad_right = pad_right
|
||||
|
||||
def forward(self, x):
|
||||
return F.pad(x, (self.pad_left, self.pad_right))
|
||||
|
||||
|
||||
class BasicBlock1d(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = conv(inplanes, planes, stride=stride, kernel_size=3)
|
||||
self.bn1 = nn.BatchNorm1d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv2 = conv(planes, planes, kernel_size=3)
|
||||
self.bn2 = nn.BatchNorm1d(planes)
|
||||
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck1d(nn.Module):
|
||||
"""Bottleneck for ResNet52 ..."""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super().__init__()
|
||||
kernel_size = 3
|
||||
self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm1d(planes)
|
||||
self.conv2 = nn.Conv1d(planes, planes, kernel_size=kernel_size, stride=stride,
|
||||
padding=(kernel_size - 1) // 2, bias=False)
|
||||
self.bn2 = nn.BatchNorm1d(planes)
|
||||
self.conv3 = nn.Conv1d(planes, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm1d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
class ResNet1d(nn.Module):
|
||||
'''1d adaptation of the torchvision resnet'''
|
||||
|
||||
def __init__(self, block, layers, kernel_size=3, input_channels=12, inplanes=64,
|
||||
fix_feature_dim=False, kernel_size_stem=None, stride_stem=2, pooling_stem=True,
|
||||
stride=2):
|
||||
super(ResNet1d, self).__init__()
|
||||
|
||||
self.inplanes = inplanes
|
||||
layers_tmp = []
|
||||
if kernel_size_stem is None:
|
||||
kernel_size_stem = kernel_size[0] if isinstance(kernel_size, list) else kernel_size
|
||||
|
||||
# conv-bn-relu (basic feature extraction)
|
||||
layers_tmp.append(nn.Conv1d(input_channels, inplanes,
|
||||
kernel_size=kernel_size_stem,
|
||||
stride=stride_stem,
|
||||
padding=(kernel_size_stem - 1) // 2, bias=False))
|
||||
layers_tmp.append(nn.BatchNorm1d(inplanes))
|
||||
layers_tmp.append(nn.ReLU(inplace=True))
|
||||
|
||||
if pooling_stem is True:
|
||||
layers_tmp.append(nn.MaxPool1d(kernel_size=3, stride=2, padding=1))
|
||||
|
||||
for i, l in enumerate(layers):
|
||||
if i == 0:
|
||||
layers_tmp.append(self._make_block(block, inplanes, layers[0]))
|
||||
else:
|
||||
layers_tmp.append(
|
||||
self._make_block(block, inplanes if fix_feature_dim else (2 ** i) * inplanes, layers[i],
|
||||
stride=stride))
|
||||
|
||||
self.feature_extractor = nn.Sequential(*layers_tmp)
|
||||
|
||||
def _make_block(self, block, planes, blocks, stride=1, kernel_size=3):
|
||||
down_sample = None
|
||||
|
||||
# 注定会进行下采样
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
down_sample = nn.Sequential(
|
||||
nn.Conv1d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm1d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, down_sample))
|
||||
self.inplanes = planes * block.expansion
|
||||
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.feature_extractor(x)
|
||||
|
||||
def resnet1d14(inplanes, input_channels):
|
||||
return ResNet1d(BasicBlock1d, [2,2,2], inplanes=inplanes, input_channels=input_channels)
|
||||
|
||||
def resnet1d18(**kwargs):
|
||||
return ResNet1d(BasicBlock1d, [2, 2, 2, 2], **kwargs)
|
||||
|
||||
# MLP
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_channels, hidden_channels, n_classes, bn = True):
|
||||
super(MLP, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.n_classes = n_classes
|
||||
self.hidden_channels = hidden_channels
|
||||
self.fc1 = nn.Linear(self.in_channels, self.hidden_channels)
|
||||
self.fc2 = nn.Linear(self.hidden_channels, self.n_classes)
|
||||
self.ac = nn.ReLU()
|
||||
self.bn = nn.BatchNorm1d(hidden_channels)
|
||||
self.ln = nn.LayerNorm(hidden_channels)
|
||||
self.fc3 = nn.Linear(self.hidden_channels, self.hidden_channels)
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.fc1(x)
|
||||
hidden = self.bn(hidden)
|
||||
hidden = self.ac(hidden)
|
||||
|
||||
'''hidden = self.fc3(hidden)
|
||||
hidden = self.ln2(hidden)
|
||||
hidden = self.ac(hidden)'''
|
||||
|
||||
out = self.fc2(hidden)
|
||||
|
||||
return out
|
||||
|
||||
# Time-steps features -> aggregated features
|
||||
class Flatten(nn.Module):
|
||||
def __init__(self):
|
||||
super(Flatten, self).__init__()
|
||||
|
||||
def forward(self, tensor):
|
||||
b = tensor.size(0)
|
||||
return tensor.reshape(b, -1)
|
||||
|
||||
class AdaptiveConcatPool1d(nn.Module):
|
||||
"Layer that concats `AdaptiveAvgPool1d` and `AdaptiveMaxPool1d`."
|
||||
|
||||
def __init__(self, sz=None):
|
||||
"Output will be 2*sz or 2 if sz is None"
|
||||
super().__init__()
|
||||
sz = sz or 1
|
||||
self.ap, self.mp = nn.AdaptiveAvgPool1d(sz), nn.AdaptiveMaxPool1d(sz)
|
||||
|
||||
def forward(self, x):
|
||||
"""x is shaped of B, C, T"""
|
||||
return torch.cat([self.mp(x), self.ap(x), x[..., -1:]], 1)
|
||||
|
||||
def bn_drop_lin(n_in, n_out, bn, p, actn):
|
||||
"`n_in`->bn->dropout->linear(`n_in`,`n_out`)->`actn`"
|
||||
layers = list()
|
||||
|
||||
if bn:
|
||||
layers.append(nn.BatchNorm1d(n_in))
|
||||
|
||||
if p > 0.:
|
||||
layers.append(nn.Dropout(p=p))
|
||||
|
||||
layers.append(nn.Linear(n_in, n_out))
|
||||
|
||||
if actn is not None:
|
||||
layers.append(actn)
|
||||
|
||||
return layers
|
||||
|
||||
def create_head1d(nf: int, nc: int, lin_ftrs=[512, ], dropout=0.5, bn: bool = True, act="relu"):
|
||||
lin_ftrs = [3 * nf] + lin_ftrs + [nc]
|
||||
|
||||
activations = [nn.ReLU(inplace=True) if act == "relu" else nn.ELU(inplace=True)] * (len(lin_ftrs) - 2) + [None]
|
||||
layers = [AdaptiveConcatPool1d(), Flatten()]
|
||||
|
||||
for ni, no, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], activations):
|
||||
layers += bn_drop_lin(ni, no, bn, dropout, actn)
|
||||
|
||||
layers += [nn.Sigmoid()]
|
||||
|
||||
return nn.Sequential(*layers)
|
276
env.yaml
Normal file
276
env.yaml
Normal file
|
@ -0,0 +1,276 @@
|
|||
channels:
|
||||
- pytorch
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=main
|
||||
- _openmp_mutex=4.5=1_gnu
|
||||
- alembic=1.8.1=py39h06a4308_0
|
||||
- argon2-cffi=21.3.0=pyhd3eb1b0_0
|
||||
- argon2-cffi-bindings=21.2.0=py39h7f8727e_0
|
||||
- asttokens=2.0.5=pyhd3eb1b0_0
|
||||
- attrs=22.1.0=py39h06a4308_0
|
||||
- autopage=0.5.0=pyhd8ed1ab_0
|
||||
- backcall=0.2.0=pyhd3eb1b0_0
|
||||
- blas=1.0=mkl
|
||||
- bleach=4.1.0=pyhd3eb1b0_0
|
||||
- bzip2=1.0.8=h7b6447c_0
|
||||
- ca-certificates=2023.01.10=h06a4308_0
|
||||
- certifi=2022.12.7=py39h06a4308_0
|
||||
- cffi=1.14.6=py39h400218f_0
|
||||
- cliff=3.10.0=pyhd8ed1ab_0
|
||||
- cmaes=0.8.2=pyh44b312d_0
|
||||
- cmd2=2.3.3=py39hf3d152e_1
|
||||
- colorlog=5.0.1=py39h06a4308_1
|
||||
- comm=0.1.3=pyhd8ed1ab_0
|
||||
- cudatoolkit=10.2.89=hfd86e86_1
|
||||
- decorator=5.1.1=pyhd3eb1b0_0
|
||||
- defusedxml=0.7.1=pyhd3eb1b0_0
|
||||
- entrypoints=0.4=py39h06a4308_0
|
||||
- executing=0.8.3=pyhd3eb1b0_0
|
||||
- ffmpeg=4.3=hf484d3e_0
|
||||
- freetype=2.11.0=h70c0345_0
|
||||
- giflib=5.2.1=h7b6447c_0
|
||||
- gmp=6.2.1=h2531618_2
|
||||
- gnutls=3.6.15=he1e5248_0
|
||||
- intel-openmp=2021.4.0=h06a4308_3561
|
||||
- ipython=8.10.0=py39h06a4308_0
|
||||
- ipython_genutils=0.2.0=pyhd3eb1b0_1
|
||||
- jedi=0.18.1=py39h06a4308_1
|
||||
- jinja2=3.1.2=py39h06a4308_0
|
||||
- jpeg=9d=h7f8727e_0
|
||||
- jsonschema=4.17.3=py39h06a4308_0
|
||||
- jupyter_client=7.1.2=pyhd3eb1b0_0
|
||||
- jupyter_core=5.3.0=py39h06a4308_0
|
||||
- lame=3.100=h7b6447c_0
|
||||
- lcms2=2.12=h3be6417_0
|
||||
- ld_impl_linux-64=2.35.1=h7274673_9
|
||||
- libblas=3.9.0=12_linux64_mkl
|
||||
- libcblas=3.9.0=12_linux64_mkl
|
||||
- libffi=3.3=he6710b0_2
|
||||
- libgcc-ng=9.3.0=h5101ec6_17
|
||||
- libgfortran-ng=7.5.0=ha8ba4b0_17
|
||||
- libgfortran4=7.5.0=ha8ba4b0_17
|
||||
- libgomp=9.3.0=h5101ec6_17
|
||||
- libiconv=1.15=h63c8f33_5
|
||||
- libidn2=2.3.2=h7f8727e_0
|
||||
- liblapack=3.9.0=12_linux64_mkl
|
||||
- libllvm10=10.0.1=hbcb73fb_5
|
||||
- libpng=1.6.37=hbc83047_0
|
||||
- libsodium=1.0.18=h7b6447c_0
|
||||
- libstdcxx-ng=9.3.0=hd4cf53a_17
|
||||
- libtasn1=4.16.0=h27cfd23_0
|
||||
- libtiff=4.2.0=h85742a9_0
|
||||
- libunistring=0.9.10=h27cfd23_0
|
||||
- libuv=1.40.0=h7b6447c_0
|
||||
- libwebp=1.2.0=h89dd481_0
|
||||
- libwebp-base=1.2.0=h27cfd23_0
|
||||
- llvmlite=0.36.0=py39h612dafd_4
|
||||
- lz4-c=1.9.3=h295c915_1
|
||||
- mako=1.2.3=py39h06a4308_0
|
||||
- markupsafe=2.1.1=py39h7f8727e_0
|
||||
- matplotlib-inline=0.1.6=py39h06a4308_0
|
||||
- mistune=0.8.4=py39h27cfd23_1000
|
||||
- mkl=2021.4.0=h06a4308_640
|
||||
- mkl-service=2.4.0=py39h7f8727e_0
|
||||
- mkl_fft=1.3.1=py39hd3c417c_0
|
||||
- mkl_random=1.2.2=py39h51133e4_0
|
||||
- nbconvert=5.5.0=py_0
|
||||
- nbformat=5.7.0=py39h06a4308_0
|
||||
- ncurses=6.3=h7f8727e_2
|
||||
- nest-asyncio=1.5.6=py39h06a4308_0
|
||||
- nettle=3.7.3=hbbd107a_1
|
||||
- notebook=6.4.6=pyha770c72_0
|
||||
- numba=0.53.1=py39ha9443f7_0
|
||||
- numpy=1.21.2=py39h20f2e39_0
|
||||
- numpy-base=1.21.2=py39h79a1101_0
|
||||
- olefile=0.46=pyhd3eb1b0_0
|
||||
- openh264=2.1.0=hd408876_0
|
||||
- openssl=1.1.1t=h7f8727e_0
|
||||
- optuna=2.10.0=pyhd8ed1ab_0
|
||||
- packaging=23.0=py39h06a4308_0
|
||||
- pandoc=2.12=h06a4308_3
|
||||
- pandocfilters=1.5.0=pyhd3eb1b0_0
|
||||
- parso=0.8.3=pyhd3eb1b0_0
|
||||
- pbr=5.6.0=pyhd3eb1b0_0
|
||||
- pexpect=4.8.0=pyhd3eb1b0_3
|
||||
- pickleshare=0.7.5=pyhd3eb1b0_1003
|
||||
- pillow=8.4.0=py39h5aabda8_0
|
||||
- pip=21.2.4=py39h06a4308_0
|
||||
- platformdirs=2.5.2=py39h06a4308_0
|
||||
- plotly=4.14.3=pyhd3eb1b0_0
|
||||
- prometheus_client=0.14.1=py39h06a4308_0
|
||||
- prompt-toolkit=3.0.36=py39h06a4308_0
|
||||
- psutil=5.8.0=py39h3811e60_1
|
||||
- ptyprocess=0.7.0=pyhd3eb1b0_2
|
||||
- pure_eval=0.2.2=pyhd3eb1b0_0
|
||||
- pycparser=2.21=pyhd3eb1b0_0
|
||||
- pygments=2.11.2=pyhd3eb1b0_0
|
||||
- pyparsing=3.0.9=py39h06a4308_0
|
||||
- pyperclip=1.8.2=pyhd8ed1ab_2
|
||||
- pyrsistent=0.18.0=py39heee7806_0
|
||||
- python=3.9.7=h12debd9_1
|
||||
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
||||
- python-fastjsonschema=2.16.2=py39h06a4308_0
|
||||
- python_abi=3.9=2_cp39
|
||||
- pytorch=1.10.1=py3.9_cuda10.2_cudnn7.6.5_0
|
||||
- pytorch-mutex=1.0=cuda
|
||||
- pyts=0.12.0=pyh6c4a22f_0
|
||||
- readline=8.1=h27cfd23_0
|
||||
- retrying=1.3.3=pyhd3eb1b0_2
|
||||
- send2trash=1.8.0=pyhd3eb1b0_1
|
||||
- setuptools=58.0.4=py39h06a4308_0
|
||||
- six=1.16.0=pyhd3eb1b0_0
|
||||
- sqlalchemy=1.3.23=py39h3811e60_0
|
||||
- sqlite=3.36.0=hc218d9a_0
|
||||
- stack_data=0.2.0=pyhd3eb1b0_0
|
||||
- stevedore=3.5.0=py39hf3d152e_2
|
||||
- tbb=2020.3=hfd86e86_0
|
||||
- terminado=0.17.1=py39h06a4308_0
|
||||
- testpath=0.6.0=py39h06a4308_0
|
||||
- tk=8.6.11=h1ccaba5_0
|
||||
- torchaudio=0.10.1=py39_cu102
|
||||
- torchvision=0.11.2=py39_cu102
|
||||
- tornado=6.1=py39h27cfd23_0
|
||||
- tqdm=4.65.0=py39hb070fc8_0
|
||||
- traitlets=5.7.1=py39h06a4308_0
|
||||
- typing-extensions=3.10.0.2=hd3eb1b0_0
|
||||
- typing_extensions=3.10.0.2=pyh06a4308_0
|
||||
- wcwidth=0.2.5=pyhd3eb1b0_0
|
||||
- webencodings=0.5.1=py39h06a4308_1
|
||||
- wheel=0.37.0=pyhd3eb1b0_1
|
||||
- xz=5.2.5=h7b6447c_0
|
||||
- yaml=0.2.5=h7b6447c_0
|
||||
- zeromq=4.3.4=h2531618_0
|
||||
- zlib=1.2.11=h7f8727e_4
|
||||
- zstd=1.4.9=haebb681_0
|
||||
- pip:
|
||||
- absl-py==1.0.0
|
||||
- aiohttp==3.8.1
|
||||
- aiosignal==1.2.0
|
||||
- altair==4.1.0
|
||||
- appdirs==1.4.4
|
||||
- astor==0.8.1
|
||||
- async-timeout==4.0.2
|
||||
- base58==2.1.1
|
||||
- blinker==1.4
|
||||
- cachetools==4.2.4
|
||||
- charset-normalizer==2.0.9
|
||||
- click==7.1.2
|
||||
- configparser==5.2.0
|
||||
- cycler==0.11.0
|
||||
- cython==0.29.35
|
||||
- datasets==1.17.0
|
||||
- debugpy==1.5.1
|
||||
- deprecated==1.2.13
|
||||
- dictset==0.3.1.2
|
||||
- dill==0.3.4
|
||||
- docker-pycreds==0.4.0
|
||||
- dtaidistance==2.3.10
|
||||
- einops==0.6.0
|
||||
- fastdtw==0.3.4
|
||||
- filelock==3.4.0
|
||||
- fonttools==4.28.5
|
||||
- frozenlist==1.2.0
|
||||
- fsspec==2021.11.1
|
||||
- gitdb==4.0.9
|
||||
- gitpython==3.1.24
|
||||
- google-auth==2.3.3
|
||||
- google-auth-oauthlib==0.4.6
|
||||
- grpcio==1.43.0
|
||||
- h5py==3.6.0
|
||||
- hdbscan==0.8.29
|
||||
- hmmlearn==0.2.8
|
||||
- huggingface-hub==0.2.1
|
||||
- idna==3.3
|
||||
- imbalanced-learn==0.9.0
|
||||
- import-ipynb==0.1.3
|
||||
- importlib-metadata==6.6.0
|
||||
- install-jdk==0.3.0
|
||||
- ipykernel==6.6.0
|
||||
- ipynb==0.5.1
|
||||
- ipywidgets==7.6.5
|
||||
- javabridge==1.0.16
|
||||
- joblib==1.1.0
|
||||
- jupyterlab-widgets==1.0.2
|
||||
- kiwisolver==1.3.2
|
||||
- markdown==3.3.6
|
||||
- matplotlib==3.5.1
|
||||
- msgpack==1.0.3
|
||||
- multidict==5.2.0
|
||||
- multiprocess==0.70.12.2
|
||||
- nltk==3.7
|
||||
- numexpr==2.8.1
|
||||
- oauthlib==3.1.1
|
||||
- opencv-python==4.5.5.62
|
||||
- pandas==1.5.3
|
||||
- pathtools==0.1.2
|
||||
- patsy==0.5.2
|
||||
- promise==2.3
|
||||
- protobuf==3.19.1
|
||||
- pyarrow==6.0.1
|
||||
- pyasn1==0.4.8
|
||||
- pyasn1-modules==0.2.8
|
||||
- pydeck==0.7.1
|
||||
- pympler==1.0
|
||||
- pynndescent==0.5.10
|
||||
- pystaggrelite3==0.1.3
|
||||
- python-crfsuite==0.9.9
|
||||
- python-weka-wrapper==0.3.18
|
||||
- pytorch-ignite==0.4.8
|
||||
- pytorch-warmup==0.1.0
|
||||
- pytz==2021.3
|
||||
- pytz-deprecation-shim==0.1.0.post0
|
||||
- pyyaml==6.0
|
||||
- pyzmq==22.3.0
|
||||
- ray==1.9.2
|
||||
- redis==4.1.1
|
||||
- regex==2021.11.10
|
||||
- requests==2.26.0
|
||||
- requests-oauthlib==1.3.0
|
||||
- rsa==4.8
|
||||
- sacremoses==0.0.46
|
||||
- scikit-learn==1.0.1
|
||||
- scipy==1.10.1
|
||||
- seaborn==0.11.2
|
||||
- sentencepiece==0.1.96
|
||||
- sentry-sdk==1.5.1
|
||||
- seqeval==1.2.2
|
||||
- setproctitle==1.3.2
|
||||
- shortuuid==1.0.8
|
||||
- simpletransformers==0.63.3
|
||||
- sklearn-crfsuite==0.3.6
|
||||
- smmap==5.0.0
|
||||
- statsmodels==0.13.2
|
||||
- streamlit==1.3.0
|
||||
- subprocess32==3.5.4
|
||||
- tables==3.7.0
|
||||
- tabulate==0.8.9
|
||||
- tensorboard==2.11.2
|
||||
- tensorboard-data-server==0.6.1
|
||||
- tensorboard-plugin-wit==1.8.1
|
||||
- termcolor==1.1.0
|
||||
- textdistance==4.2.2
|
||||
- threadpoolctl==3.0.0
|
||||
- tokenizers==0.10.3
|
||||
- toml==0.10.2
|
||||
- toolz==0.11.2
|
||||
- torch-summary==1.4.5
|
||||
- torch-tb-profiler==0.4.1
|
||||
- torchinfo==1.7.0
|
||||
- torchtext==0.11.1
|
||||
- transformers==4.14.1
|
||||
- tzdata==2021.5
|
||||
- tzlocal==4.1
|
||||
- umap-learn==0.5.3
|
||||
- urllib3==1.26.7
|
||||
- validators==0.18.2
|
||||
- watchdog==2.1.6
|
||||
- weka==2.0.2
|
||||
- werkzeug==2.0.2
|
||||
- widgetsnbextension==3.5.2
|
||||
- wrapt==1.13.3
|
||||
- xxhash==2.0.2
|
||||
- yarl==1.7.2
|
||||
- yaspin==2.1.0
|
||||
- zipp==3.15.0
|
77
main.py
Normal file
77
main.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
import numpy as np
|
||||
import random, pdb, os, copy
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import pickle as pkl
|
||||
import torch
|
||||
|
||||
from utils import MakeDir, set_seed, str2bool
|
||||
from CRT import Model, self_supervised_learning, finetuning
|
||||
from base_models import FTDataSet
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ssl", type=str2bool, default=False, help='Self-supervised learning (pretrain the autoencoder)')
|
||||
parser.add_argument("--sl", type=str2bool, default=True, help='Supervised learning (downstream tasks)')
|
||||
parser.add_argument("--load", type=str2bool, default=True)
|
||||
|
||||
parser.add_argument("--patch_len", type=int, default=5)
|
||||
parser.add_argument("--dim", type=int, default=128)
|
||||
parser.add_argument("--depth", type=int, default=6)
|
||||
parser.add_argument("--dropout", type=float, default=0.1)
|
||||
parser.add_argument("--Headdropout", type=float, default=0.1)
|
||||
parser.add_argument("--beta2", type=float, default=1)
|
||||
|
||||
parser.add_argument("--AElearningRate", type=float, default=1e-4)
|
||||
parser.add_argument("--HeadlearningRate", type=float, default=1e-3)
|
||||
parser.add_argument("--AEbatchSize", type=int, default=512)
|
||||
parser.add_argument("--HeadbatchSize", type=int, default=32)
|
||||
parser.add_argument("--AEepochs", type=int, default=100)
|
||||
parser.add_argument("--Headepochs", type=int, default=50)
|
||||
|
||||
parser.add_argument("--AEwin", type=int, default=5)
|
||||
parser.add_argument("--Headwin", type=int, default=5)
|
||||
parser.add_argument("--slid", type=int, default=1)
|
||||
parser.add_argument("--timeWinFreq", type=int, default=20)
|
||||
|
||||
parser.add_argument("--pretrainDataset", type=str, default='Buffalo_EMAKI')
|
||||
parser.add_argument("--testDataset", type=str)
|
||||
|
||||
parser.add_argument("--stage", type=int, choices=[0,1], default=0, help='0: Train the classifier only; 1: Finetune the whole model')
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
opt = parser.parse_args()
|
||||
opt.in_dim = 2 # x,y
|
||||
opt.seq_len = opt.AEwin * opt.timeWinFreq *2
|
||||
opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
set_seed(opt.seed)
|
||||
aemodelfile = 'pretrained_model.pkl'
|
||||
|
||||
if opt.ssl:
|
||||
'''
|
||||
Load dataset for pretraining the Autoencoder
|
||||
taskdata: on-screen coordinates (x,y), their magnitude and phase
|
||||
'''
|
||||
taskdata, taskclicks = loadPretrainDataset(opt)
|
||||
model = Model(opt).to(opt.device)
|
||||
self_supervised_learning(model, taskdata, taskclicks, opt, modelfile=aemodelfile)
|
||||
|
||||
if opt.load:
|
||||
model = torch.load(aemodelfile, map_location=opt.device)
|
||||
else:
|
||||
model = Model(opt).to(opt.device)
|
||||
pdb.set_trace()
|
||||
|
||||
if opt.sl:
|
||||
'''
|
||||
Load dataset for downstream tasks
|
||||
_X: on-screen coordinates (x,y), their magnitude and phase
|
||||
_y: labels
|
||||
_clicks: clicks
|
||||
'''
|
||||
train_X, train_y, train_clicks, valid_X, valid_y, valid_clicks = loadDownstreamDataset(opt)
|
||||
|
||||
TrainSet = FTDataSet(train_X, train_y, train_clicks, multi_label=True) # Binary or multi-class
|
||||
ValidSet = FTDataSet(valid_X, valid_y, valid_clicks, multi_label=True)
|
||||
|
||||
finetuning(model, TrainSet, ValidSet, opt, aemodelfile.split('.pkl')[0]+'downstream.pkl', multi_label=True)
|
BIN
pretrained_model.pkl
Normal file
BIN
pretrained_model.pkl
Normal file
Binary file not shown.
52
utils.py
Normal file
52
utils.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
import pdb, os, random, copy
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pickle as pkl
|
||||
import torch
|
||||
from torch.utils.data import DataLoader,Dataset
|
||||
|
||||
def str2bool(v):
|
||||
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
||||
return True
|
||||
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError('Unsupported value encountered.')
|
||||
|
||||
def set_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed) # Numpy module.
|
||||
random.seed(seed) # Python random module.
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
def DFT(data, clicks=None):
|
||||
for idx, channel in enumerate(range(data.shape[-1])):
|
||||
freq = np.fft.fft(data[:,:,channel])
|
||||
freq = freq[:, :data.shape[1]//2] # symmetric
|
||||
mag = np.abs(freq)
|
||||
phase = np.angle(freq)
|
||||
mag = (mag - np.min(mag)) / (np.max(mag) - np.min(mag))
|
||||
phase = (phase - np.min(phase)) / (np.max(phase) - np.min(phase))
|
||||
if idx==0:
|
||||
mags, phases = mag, phase
|
||||
else:
|
||||
mags, phases = np.stack([mags, mag], axis=-1), np.stack([phases, phase], axis=-1)
|
||||
data = np.concatenate([data, mags, phases], axis=1)
|
||||
return data, clicks
|
||||
|
||||
def MakeDir(dirName):
|
||||
if not os.path.exists(dirName):
|
||||
os.makedirs(dirName)
|
||||
|
||||
class simpleSet(Dataset):
|
||||
def __init__(self, features, labels):
|
||||
self.features = features
|
||||
labels = labels.astype(int)
|
||||
self.labels = labels - np.min(labels)
|
||||
def __getitem__(self, index):
|
||||
return self.features[index], self.labels[index]
|
||||
def __len__(self):
|
||||
return len(self.features)
|
Loading…
Reference in a new issue