# Mouse2Vec/CRT.py
# 2024-05-07 17:01:12 +02:00
#
# 438 lines
# No EOL
# 19 KiB
# Python

import pdb, copy
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from base_models import MLP, resnet1d18, SSLDataSet
from sklearn.metrics import f1_score
class cnn_extractor(nn.Module):
    """Wrap the ``resnet1d18`` backbone as a per-patch feature extractor.

    Args:
        dim: passed to ``resnet1d18`` as ``input_channels``.
        input_plane: passed to ``resnet1d18`` as ``inplanes``.
    """

    def __init__(self, dim, input_plane):
        super().__init__()
        backbone = resnet1d18(input_channels=dim, inplanes=input_plane)
        self.cnn = backbone

    def forward(self, x):
        """Return whatever the wrapped backbone produces for ``x``."""
        return self.cnn(x)
class PreNorm(nn.Module):
    """Layer-normalise the input, then hand it to a wrapped callable.

    Args:
        dim: size of the last axis, given to ``nn.LayerNorm``.
        fn: callable (typically a module) applied to the normalised input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(LayerNorm(x), **kwargs)``."""
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the intermediate layer.
        dropout: drop probability used by both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        # Layer order is kept so the sub-module indices (net.0, net.3) stay the same.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to the last axis of ``x``."""
        return self.net(x)
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: feature size of the input tokens.
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: drop probability applied after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        # A single head whose width already equals `dim` needs no output projection.
        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )

    def forward(self, x):
        """Attend over axis 1 of ``x`` (laid out as ``b n d``) and project back to ``dim``."""
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)
        mixed = rearrange(torch.matmul(weights, v), 'b h n d -> b n (h d)')
        return self.to_out(mixed)
class Transformer(nn.Module):
    """Stack of pre-norm blocks, each an attention layer followed by an MLP.

    Args:
        dim: token feature size.
        depth: number of (attention, feed-forward) pairs.
        heads: attention heads per layer.
        dim_head: feature size per attention head.
        mlp_dim: hidden width of each feed-forward layer.
        dropout: drop probability shared by attention and feed-forward layers.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        blocks = [
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
            ])
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Run ``x`` through every block, adding a residual connection after each sub-layer."""
        for attention, feed_forward in self.layers:
            x = attention(x) + x
            x = feed_forward(x) + x
        return x
class TFR(nn.Module):
    """The encoder of CRT.

    The input is split along its last axis into three segments: the first
    half, the third quarter and the last quarter. The code labels these
    temporal, magnitude and phase data. Each segment is cut into patches of
    ``patch_len`` samples, every patch is embedded by a segment-specific CNN,
    and the patch tokens (plus one class token per segment) are processed by
    a Transformer.

    Args:
        seq_len: total length of the last input axis; must be a multiple of
            ``4 * patch_len``.
        patch_len: number of samples per patch.
        dim: token feature size used by the embeddings and the Transformer.
        depth: number of Transformer blocks.
        heads: attention heads per block.
        mlp_dim: hidden width of the Transformer feed-forward layers.
        channels: number of input channels given to each CNN.
        dim_head: feature size per attention head.
        dropout: drop probability inside the Transformer.
        emb_dropout: drop probability applied to the embedded tokens.
    """

    def __init__(self, seq_len, patch_len, dim, depth, heads, mlp_dim, channels=12,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        self.patch_len = patch_len
        self.seq_len = seq_len
        assert seq_len % (4 * patch_len) == 0, \
            'The seq_len should be 4 * n * patch_len, or there must be patch with both magnitude and phase data.'
        num_patches = seq_len // patch_len
        # Step 0 splits the sequence into patches; step 1 folds the patch axis
        # into the batch axis so a CNN sees one patch per sample.
        self.to_patch = nn.Sequential(Rearrange('b c (n p1) -> b n c p1', p1=patch_len),
                                      Rearrange('b n c p1 -> (b n) c p1'))
        # One positional slot per patch plus one per class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 3, dim))
        self.modal_embedding = nn.Parameter(torch.randn(3, 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 3, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.cnn1 = cnn_extractor(dim=channels, input_plane=dim // 8)  # For temporal data
        self.cnn2 = cnn_extractor(dim=channels, input_plane=dim // 8)  # For magnitude data
        self.cnn3 = cnn_extractor(dim=channels, input_plane=dim // 8)  # For phase data

    def forward(self, x, opt):
        """Encode ``x`` and return the mean of the three class tokens.

        Args:
            x: 3-D tensor whose axes are unpacked as (batch, channels, time).
            opt: object exposing ``patch_len``; used only to validate that
                each segment divides evenly into patches.

        Returns:
            Tensor indexed as (batch, feature): the average of the temporal,
            magnitude and phase class tokens after the Transformer.
        """
        batch, _, time_steps = x.shape
        t = x[:, :, :time_steps // 2]
        m = x[:, :, time_steps // 2: time_steps * 3 // 4]
        p = x[:, :, -time_steps // 4:]
        assert t.shape[-1] % opt.patch_len == 0 == m.shape[-1] % opt.patch_len == p.shape[-1] % opt.patch_len
        t, m, p = self.to_patch(t), self.to_patch(m), self.to_patch(p)
        # Pool each patch's CNN feature map to one vector and restore the patch axis.
        patch2seq = nn.Sequential(nn.AdaptiveAvgPool1d(1),
                                  Rearrange('(b n) c 1 -> b n c', b=batch))
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=batch)
        x = torch.cat((cls_tokens[:, 0:1, :], patch2seq(self.cnn1(t)),
                       cls_tokens[:, 1:2, :], patch2seq(self.cnn2(m)),
                       cls_tokens[:, 2:3, :], patch2seq(self.cnn3(p))), dim=1)
        n_tokens = x.shape[1]
        n_patches = n_tokens - 3
        # Positions of the three class tokens inside the token sequence.
        t_token_idx, m_token_idx, p_token_idx = 0, n_patches // 2 + 1, n_patches * 3 // 4 + 2
        # Modal embeddings are applied per token (axis 1). The previous code
        # sliced axis 0, i.e. the batch axis, so which embedding a sample got
        # depended on its position in the batch instead of the token's modality.
        modal = torch.cat((
            self.modal_embedding[0].expand(m_token_idx, -1),
            self.modal_embedding[1].expand(p_token_idx - m_token_idx, -1),
            self.modal_embedding[2].expand(n_tokens - p_token_idx, -1),
        ), dim=0)
        x = x + modal + self.pos_embedding[:, :n_tokens]
        x = self.dropout(x)
        x = self.transformer(x)
        t_token, m_token, p_token = x[:, t_token_idx], x[:, m_token_idx], x[:, p_token_idx]
        return (t_token + m_token + p_token) / 3
def TFR_Encoder(seq_len, patch_len, dim, in_dim, depth):
    """Build a ``TFR`` encoder with fixed head count, MLP width and dropout rates.

    Args:
        seq_len: total input length given to ``TFR``.
        patch_len: samples per patch.
        dim: token feature size; also used as the feed-forward width.
        in_dim: number of input channels.
        depth: number of Transformer blocks.

    Returns:
        The constructed ``TFR`` instance.
    """
    fixed_settings = dict(heads=8, mlp_dim=dim, dropout=0.2, emb_dropout=0.1)
    return TFR(seq_len=seq_len,
               patch_len=patch_len,
               dim=dim,
               depth=depth,
               channels=in_dim,
               **fixed_settings)
class CRT(nn.Module):
    """Masked-reconstruction pre-training model built around a ``TFR`` encoder.

    A random subset of patches is dropped, the remaining patches are encoded,
    and a small Transformer decoder fills in mask tokens. The training loss
    combines patch reconstruction, an instance-discrimination term
    (``IDC_loss``) and a per-sample click classification term.

    Args:
        encoder: a ``TFR``-like module exposing ``pos_embedding``,
            ``modal_embedding``, ``cls_token``, ``to_patch``, ``transformer``
            and ``cnn1``/``cnn2``/``cnn3``.
        decoder_dim: token feature size of the decoder. Encoder tokens are fed
            to the decoder unchanged, so it has to match the encoder's size.
        decoder_depth: number of decoder blocks.
        decoder_heads: attention heads per decoder block.
        decoder_dim_head: feature size per decoder attention head.
        patch_len: samples per patch.
        in_dim: number of input channels.
    """

    def __init__(
            self,
            encoder,
            decoder_dim,
            decoder_depth=2,
            decoder_heads=8,
            decoder_dim_head=64,
            patch_len=20,
            in_dim=12
    ):
        super().__init__()
        self.encoder = encoder
        # Number of positional slots in the encoder (patches + 3 class tokens).
        num_patches = encoder.pos_embedding.shape[-2]
        self.to_patch = encoder.to_patch
        pixel_values_per_patch = in_dim * patch_len
        # decoder parameters
        self.modal_embedding = self.encoder.modal_embedding
        self.mask_token = nn.Parameter(torch.randn(3, decoder_dim))
        self.decoder = Transformer(dim=decoder_dim,
                                   depth=decoder_depth,
                                   heads=decoder_heads,
                                   dim_head=decoder_dim_head,
                                   mlp_dim=decoder_dim)
        self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
        # One reconstruction head per modality.
        self.to_pixels = nn.ModuleList([nn.Linear(decoder_dim, pixel_values_per_patch) for _ in range(3)])
        self.projs = nn.ModuleList([nn.Linear(decoder_dim, decoder_dim) for _ in range(2)])
        # Two logits per sample of a patch.
        self.to_clicks = nn.Linear(decoder_dim, 2 * patch_len)

    def IDC_loss(self, tokens, encoded_tokens):
        """Return the log of the summed exponentiated cosine similarities
        between different token positions, averaged over the batch.

        Both inputs are 3-D; their last axis is L2-normalised and the diagonal
        (same-position pairs) is excluded from the sum.
        """
        _, num_tokens, _ = tokens.shape
        tokens = F.normalize(tokens, dim=-1)
        encoded_tokens = F.normalize(encoded_tokens, dim=-1).transpose(2, 1)
        cross_mul = torch.exp(torch.matmul(tokens, encoded_tokens))
        off_diagonal = (1 - torch.eye(num_tokens)).unsqueeze(0).to(tokens.device)
        cross_mul = cross_mul * off_diagonal
        return torch.log(cross_mul.sum(-1).sum(-1)).mean(-1)

    def forward(self, x, clicks, mask_ratio=0.75, beta=1e-4, beta2=1e-4):
        """Compute the pre-training loss for one batch.

        Args:
            x: tensor laid out as ``b c (n p1)``; its patch axis is split into
                a first half, a third quarter and a last quarter.
            clicks: 2-D tensor of 0/1 labels per sample; it is cut into
                patches the same way as ``x`` and indexed with the first-half
                patch indices.
            mask_ratio: fraction of patches to mask.
            beta: weight of the IDC loss.
            beta2: weight of the click-prediction loss.

        Returns:
            Scalar tensor ``recon_loss + beta * idc_loss + beta2 * click_loss``.
        """
        device = x.device
        patches = self.to_patch[0](x)
        batch, num_patches, _, _ = patches.shape
        clickpatches = self.to_patch[0](clicks.unsqueeze(1).float())

        # Offsets of the second and third segment on the patch axis.
        m_offset, p_offset = num_patches // 2, num_patches * 3 // 4

        # ---- choose masked / unmasked patches --------------------------------
        num_masked = int(mask_ratio * num_patches)
        rand_indices1 = torch.randperm(num_patches // 2, device=device)
        masked_indices1 = rand_indices1[: num_masked // 2].sort()[0]
        unmasked_indices1 = rand_indices1[num_masked // 2:].sort()[0]
        # The second and third segment share one set of masked positions.
        rand_indices2 = torch.randperm(num_patches // 4, device=device)
        masked_indices2 = rand_indices2[: num_masked // 4].sort()[0]
        unmasked_indices2 = rand_indices2[num_masked // 4:].sort()[0]

        masked_num_t, masked_num_f = masked_indices1.shape[0], masked_indices2.shape[0]
        unmasked_num_t, unmasked_num_f = unmasked_indices1.shape[0], unmasked_indices2.shape[0]

        # Order of the first-segment tokens as they are predicted below:
        # unmasked first, then masked.
        t_order = torch.cat((unmasked_indices1, masked_indices1))
        # Reconstruction targets follow the same (unmasked, masked) order per
        # segment as the predictions. The previous code used (masked, unmasked),
        # which compared each prediction with a different patch.
        target_indices = torch.cat((t_order,
                                    unmasked_indices2 + m_offset, masked_indices2 + m_offset,
                                    unmasked_indices2 + p_offset, masked_indices2 + p_offset))

        # ---- encode the unmasked patches -------------------------------------
        tpatches = patches[:, :m_offset]
        mpatches = patches[:, m_offset:p_offset]
        ppatches = patches[:, p_offset:]
        t_tokens = self.encoder.cnn1(self.to_patch[1](tpatches[:, unmasked_indices1]))
        m_tokens = self.encoder.cnn2(self.to_patch[1](mpatches[:, unmasked_indices2]))
        p_tokens = self.encoder.cnn3(self.to_patch[1](ppatches[:, unmasked_indices2]))
        flatten = nn.Sequential(nn.AdaptiveAvgPool1d(1),
                                Rearrange('(b n) c 1 -> b n c', b=batch))
        t_tokens, m_tokens, p_tokens = flatten(t_tokens), flatten(m_tokens), flatten(p_tokens)
        ori_tokens = torch.cat((t_tokens, m_tokens, p_tokens), 1).clone()

        cls_tokens = repeat(self.encoder.cls_token, '() n d -> b n d', b=batch)
        tokens = torch.cat((cls_tokens[:, 0:1, :], t_tokens,
                            cls_tokens[:, 1:2, :], m_tokens,
                            cls_tokens[:, 2:3, :], p_tokens), dim=1)

        # Positional slots in the full sequence: class tokens sit at 0,
        # m_offset + 1 and p_offset + 2, each followed by its segment's patches.
        enc_pos = self.encoder.pos_embedding
        cls_m_pos, cls_p_pos = m_offset + 1, p_offset + 2
        pos_embedding = torch.cat((enc_pos[:, 0:1, :],
                                   enc_pos[:, unmasked_indices1 + 1, :],
                                   enc_pos[:, cls_m_pos: cls_m_pos + 1],
                                   enc_pos[:, unmasked_indices2 + cls_m_pos + 1, :],
                                   enc_pos[:, cls_p_pos: cls_p_pos + 1],
                                   enc_pos[:, unmasked_indices2 + cls_p_pos + 1, :]), dim=1)
        modal_embedding = torch.cat((repeat(self.modal_embedding[0], '1 d -> 1 n d', n=unmasked_num_t + 1),
                                     repeat(self.modal_embedding[1], '1 d -> 1 n d', n=unmasked_num_f + 1),
                                     repeat(self.modal_embedding[2], '1 d -> 1 n d', n=unmasked_num_f + 1)), dim=1)
        tokens = tokens + pos_embedding + modal_embedding
        encoded_tokens = self.encoder.transformer(tokens)  # unmasked tokens + CLS tokens

        # Encoded patch tokens per segment, with the class tokens skipped.
        m_start = unmasked_num_t + 2
        p_start = m_start + unmasked_num_f + 1
        enc_t = encoded_tokens[:, 1: 1 + unmasked_num_t]
        enc_m = encoded_tokens[:, m_start: m_start + unmasked_num_f]
        enc_p = encoded_tokens[:, p_start:]

        idc_loss = self.IDC_loss(self.projs[0](ori_tokens),
                                 self.projs[1](torch.cat((enc_t, enc_m, enc_p), dim=1)))

        # ---- decode the masked positions -------------------------------------
        mask_tokens = torch.cat((repeat(self.mask_token[0], 'd -> b n d', b=batch, n=masked_num_t),
                                 repeat(self.mask_token[1], 'd -> b n d', b=batch, n=masked_num_f),
                                 repeat(self.mask_token[2], 'd -> b n d', b=batch, n=masked_num_f)), dim=1)
        decoder_pos_emb = self.decoder_pos_emb(torch.cat(
            (masked_indices1, masked_indices2 + m_offset, masked_indices2 + p_offset)))
        mask_tokens = mask_tokens + decoder_pos_emb
        decoded_tokens = self.decoder(torch.cat((encoded_tokens, mask_tokens), dim=1))
        # Explicit start offsets keep the slices valid even when no patch is
        # masked (a negative-zero slice would select the wrong range).
        decoded_masks = decoded_tokens[:, encoded_tokens.shape[1]:]
        dec_t = decoded_masks[:, : masked_num_t]
        dec_m = decoded_masks[:, masked_num_t: masked_num_t + masked_num_f]
        dec_p = decoded_masks[:, masked_num_t + masked_num_f:]

        # NOTE(review): as in the original code, predictions for unmasked
        # patches come from the encoder outputs, not from the decoder outputs.
        t_embedding = torch.cat((enc_t, dec_t), 1)
        pred_pixel_values = torch.cat((self.to_pixels[0](t_embedding),
                                       self.to_pixels[1](torch.cat((enc_m, dec_m), 1)),
                                       self.to_pixels[2](torch.cat((enc_p, dec_p), 1))), dim=1)
        recon_loss = F.mse_loss(pred_pixel_values,
                                rearrange(patches[:, target_indices], 'b n c p -> b n (c p)'))

        # ---- click prediction on the first segment ---------------------------
        click_pred = self.to_clicks(t_embedding)
        click_pred = rearrange(click_pred, 'b n (c p) -> (b n c) p', p=2)
        # Gather the ground truth in the same patch order as the predictions
        # and drop only the singleton channel axis, so batch size 1 still works.
        clicksGT = clickpatches[:, t_order].squeeze(2).reshape(-1)
        clicksOH = F.one_hot(clicksGT.to(torch.int64), num_classes=2)
        # Clamp the positive count so a batch without clicks does not yield an
        # infinite weight (and a NaN loss).
        pos_weight = (clicks == 0).sum() / (clicks == 1).sum().clamp(min=1)
        click_pred_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(click_pred, clicksOH.float())
        return recon_loss + beta * idc_loss + beta2 * click_pred_loss
class Model(nn.Module):
    """Bundle a ``TFR`` encoder with the ``CRT`` pre-training wrapper.

    Args:
        opt: options object; ``seq_len``, ``patch_len``, ``dim``, ``in_dim``
            and ``depth`` configure the modules, ``beta2`` is read on every
            forward pass.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        encoder_settings = dict(seq_len=opt.seq_len,
                                patch_len=opt.patch_len,
                                dim=opt.dim,
                                in_dim=opt.in_dim,
                                depth=opt.depth)
        self.encoder = TFR_Encoder(**encoder_settings)
        # The CRT wrapper holds the very same encoder instance.
        self.crt = CRT(encoder=self.encoder,
                       decoder_dim=opt.dim,
                       in_dim=opt.in_dim,
                       patch_len=opt.patch_len)

    def forward(self, x, clicks, ratio = 0.5):
        """Return the CRT pre-training loss for ``x``/``clicks`` at mask ratio ``ratio``."""
        loss = self.crt(x, clicks, mask_ratio=ratio, beta2=self.opt.beta2)
        return loss
def self_supervised_learning(model, X, clicks, opt, modelfile, min_ratio=0.3, max_ratio=0.8):
    """Pre-train ``model`` with its own loss and pickle it to ``modelfile``.

    The mask ratio handed to the model grows with the epoch number
    (``epoch / opt.AEepochs``) and is clipped to ``[min_ratio, max_ratio]``.

    Args:
        model: module whose call ``model(x, clicks, ratio=...)`` returns a
            scalar loss.
        X: signal data given to ``SSLDataSet``.
        clicks: click labels given to ``SSLDataSet``.
        opt: options object providing ``AElearningRate``, ``AEbatchSize``,
            ``AEepochs`` and ``device``.
        modelfile: destination for ``torch.save``.
        min_ratio: lower bound of the mask ratio.
        max_ratio: upper bound of the mask ratio.

    Side effects:
        Prints the running mean loss every 20 batches, moves ``model`` to the
        CPU at the end and writes the whole module to ``modelfile``.
    """
    optimizer = optim.Adam(model.parameters(), opt.AElearningRate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20)
    model.to(opt.device)
    model.train()
    dataset = SSLDataSet(X, clicks)
    dataloader = DataLoader(dataset, batch_size=opt.AEbatchSize, shuffle=True)
    losses = []
    for epoch in range(opt.AEepochs):
        ratio = max(min_ratio, min(max_ratio, epoch / opt.AEepochs))
        epoch_losses = []
        for idx, batch in enumerate(dataloader):
            x, batch_clicks = tuple(t.to(opt.device) for t in batch)
            loss = model(x, batch_clicks, ratio=ratio)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            value = float(loss)
            epoch_losses.append(value)
            losses.append(value)
            if idx % 20 == 0:
                print(idx,'/',len(dataloader), sum(losses) / len(losses))
        # ReduceLROnPlateau expects the monitored metric. The previous code
        # passed the epoch index, which only ever increases, so in 'min' mode
        # the learning rate decayed on a timer regardless of the loss.
        if epoch_losses:
            scheduler.step(sum(epoch_losses) / len(epoch_losses))
    torch.save(model.to('cpu'), modelfile)
class encoder_classifier(nn.Module):
    """Classifier made of a copied encoder and a three-layer MLP head.

    Args:
        AE: object whose ``encoder`` attribute is deep-copied, so fine-tuning
            does not touch the original.
        opt: options object providing ``dim`` (encoder feature size),
            ``Headdropout`` and ``num_class``; it is also forwarded to the
            encoder on every call.
    """

    def __init__(self, AE, opt):
        super().__init__()
        self.opt = opt
        self.inputdim = opt.dim
        self.encoder = copy.deepcopy(AE.encoder)
        width = 64
        head_layers = [
            nn.Linear(self.inputdim, width),
            nn.ReLU(),
            nn.Dropout(p=opt.Headdropout),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Dropout(p=opt.Headdropout),
            nn.Linear(width, opt.num_class),
        ]
        self.classifier = nn.Sequential(*head_layers)
        # initialize the classifier: Xavier-normal weights, zero biases
        for layer in head_layers:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.xavier_normal_(layer.weight)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Encode ``x`` and return the head's output."""
        return self.classifier(self.encoder(x, self.opt))
def weight_init_layer(m):
    """Call ``m.reset_parameters()`` when ``m`` has such an attribute; otherwise do nothing."""
    if not hasattr(m, 'reset_parameters'):
        return
    m.reset_parameters()
def weight_init_whole(m):
    """Walk the module tree under ``m`` and reset every leaf module.

    A module without children is handed to ``weight_init_layer`` via
    ``apply``; a module with children is only descended into.
    """
    sub_modules = list(m.children())
    if not sub_modules:  # leaf: nothing further to descend into
        m.apply(weight_init_layer)
        return
    for sub_module in sub_modules:
        weight_init_whole(sub_module)
def finetuning(AE, train_set, valid_set, opt, modelfile, multi_label=True):
    """Train a classifier on top of a copy of the pre-trained encoder.

    ``opt.stage`` selects what is optimised: 0 trains only the classifier
    head at ``opt.HeadlearningRate``; any other value trains the whole model
    at half that rate.

    Args:
        AE: pre-trained model; its ``encoder`` is copied by
            ``encoder_classifier``.
        train_set: dataset yielding ``(x, y, clicks)`` batches; also used for
            the periodic F1 evaluation.
        valid_set: currently unused.
        opt: options object providing ``stage``, ``HeadbatchSize``,
            ``HeadlearningRate``, ``Headepochs`` and ``device`` (plus whatever
            ``encoder_classifier`` reads).
        modelfile: where the fine-tuned model is pickled when training ends;
            a falsy value skips saving.
        multi_label: when true, labels are class indices trained with
            cross-entropy; when false, labels are one-hot encoded over two
            classes and trained with ``BCEWithLogitsLoss``.

    Returns:
        The fine-tuned ``encoder_classifier``. Previously the model was
        discarded and ``modelfile`` was ignored.
    """
    stage_info = {0: 'Train the classifier only', 1: 'Finetune the whole model'}
    print('Stage %d: '%(opt.stage), stage_info[opt.stage])
    train_loader = DataLoader(train_set, batch_size=opt.HeadbatchSize, shuffle=True)
    model = encoder_classifier(AE, opt)
    model.to(opt.device)
    model.train()
    best_res = 0
    step = 0
    if opt.stage == 0:
        optimizer = optim.Adam(model.classifier.parameters(), lr=opt.HeadlearningRate)
    else:
        optimizer = optim.Adam(model.parameters(), lr=opt.HeadlearningRate / 2)
    # The monitored value is an F1 score (higher is better), so the scheduler
    # must run in 'max' mode; in 'min' mode it never registers an improvement.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.9, patience=20)
    bce_loss = nn.BCEWithLogitsLoss()
    ce_loss = nn.CrossEntropyLoss()
    for epoch in range(opt.Headepochs):
        for batch in train_loader:
            step += 1
            x, y, _ = tuple(t.to(opt.device) for t in batch)
            pred = model(x)
            # Flatten to one label per sample; unlike squeeze(), this keeps the
            # batch axis when the batch holds a single sample.
            labels = y.reshape(-1).long()
            if not multi_label:
                loss = bce_loss(pred, F.one_hot(labels, num_classes=2).float())
            else:
                loss = ce_loss(pred, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % 10 == 0:
                # test() switches the model back to train mode before returning.
                _, trainf1 = test(model, train_set, opt.HeadbatchSize, multi_label)
                if trainf1 > best_res:
                    best_res = trainf1
                scheduler.step(best_res)
    if modelfile:
        # Saved as a whole module on the CPU, like the pre-training checkpoint.
        torch.save(model.to('cpu'), modelfile)
    return model
def test(model, dataset, batch_size, multi_label):
    """Evaluate ``model`` on ``dataset`` and return ``[accuracy, macro_f1]``.

    Args:
        model: module exposing ``opt.device``; called as ``model(x)``.
        dataset: yields ``(x, y, clicks)`` batches and exposes the ground
            truth as ``dataset.label``.
        batch_size: evaluation batch size.
        multi_label: not used by this function.

    The model is put in eval mode for the pass and switched back to train
    mode before returning.
    """
    def topAcc(GTlabel, pred_proba, top):
        # Fraction of samples whose label is among the `top` highest scores.
        total = GTlabel.shape[0]
        hits = sum(
            1.0
            for i in range(total)
            if (GTlabel[i]) in np.argsort(-pred_proba[i])[:top]
        )
        return hits / total

    model.eval()
    device = model.opt.device
    model.to(device)
    loader = DataLoader(dataset, batch_size=batch_size)
    outputs = []
    with torch.no_grad():
        for batch in loader:
            x, y, clicks = tuple(t.to(device) for t in batch)
            outputs.append(model(x).cpu().detach().numpy())
    pred_prob = np.concatenate(outputs, axis=0) if outputs else None
    labels = dataset.label.squeeze()
    acc = topAcc(labels, pred_prob, 1)
    predicted_class = np.argmax(pred_prob, axis=1)
    f1 = f1_score(labels, predicted_class, average='macro')
    model.train()
    return [acc, f1]