import copy
import pdb

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from sklearn.metrics import f1_score
from torch import nn
from torch.utils.data import DataLoader

from base_models import MLP, resnet1d18, SSLDataSet


def _pool_patches(features, batch):
    """Average per-patch CNN features over their last axis and regroup them by sample.

    Assumes ``features`` is ``(batch * n_patches, channels, length)`` -- the layout
    produced by ``to_patch`` followed by a ``cnn_extractor`` (TODO confirm against
    ``resnet1d18``).  Returns a ``(batch, n_patches, channels)`` token tensor.
    """
    pooled = F.adaptive_avg_pool1d(features, 1)
    return rearrange(pooled, '(b n) c 1 -> b n c', b=batch)


class cnn_extractor(nn.Module):
    """Thin wrapper around ``base_models.resnet1d18`` used to embed individual patches."""

    def __init__(self, dim, input_plane):
        super().__init__()
        self.cnn = resnet1d18(input_channels=dim, inplanes=input_plane)

    def forward(self, x):
        return self.cnn(x)


class PreNorm(nn.Module):
    """Apply LayerNorm to the input before calling the wrapped module ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a ``(b, n, dim)`` sequence."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # A single head whose width already equals ``dim`` needs no output projection.
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm (attention + feed-forward) blocks with residuals."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class TFR(nn.Module):
    def __init__(self, seq_len, patch_len, dim, depth, heads, mlp_dim, channels=12,
                 dim_head=64, dropout=0., emb_dropout=0.):
        '''
        The encoder of CRT.

        The input's last axis is split into a temporal half, a magnitude quarter and a
        phase quarter; each is cut into ``patch_len`` patches, embedded by its own CNN,
        prefixed with a CLS token and fed jointly to a Transformer.
        '''
        super().__init__()
        self.patch_len = patch_len
        self.seq_len = seq_len
        assert seq_len % (4 * patch_len) == 0, \
            'The seq_len should be 4 * n * patch_len, or there must be patch with both magnitude and phase data.'
        num_patches = seq_len // patch_len
        self.to_patch = nn.Sequential(
            Rearrange('b c (n p1) -> b n c p1', p1=patch_len),
            Rearrange('b n c p1 -> (b n) c p1'),
        )
        # One position per patch plus the three per-modality CLS tokens.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 3, dim))
        self.modal_embedding = nn.Parameter(torch.randn(3, 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 3, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.cnn1 = cnn_extractor(dim=channels, input_plane=dim // 8)  # For temporal data
        self.cnn2 = cnn_extractor(dim=channels, input_plane=dim // 8)  # For magnitude data
        self.cnn3 = cnn_extractor(dim=channels, input_plane=dim // 8)  # For phase data

    def forward(self, x, opt):
        """Encode ``x`` of shape ``(batch, channels, time_steps)`` into ``(batch, dim)``.

        ``opt`` is kept for interface compatibility; divisibility is checked against
        ``self.patch_len`` because that is what ``self.to_patch`` actually uses.
        Returns the mean of the three modality CLS-token outputs.
        """
        batch, _, time_steps = x.shape
        t_part = x[:, :, :time_steps // 2]
        m_part = x[:, :, time_steps // 2: time_steps * 3 // 4]
        p_part = x[:, :, -time_steps // 4:]
        assert all(part.shape[-1] % self.patch_len == 0 for part in (t_part, m_part, p_part)), \
            'Each temporal/magnitude/phase segment must be a multiple of patch_len.'

        t_tokens = _pool_patches(self.cnn1(self.to_patch(t_part)), batch)
        m_tokens = _pool_patches(self.cnn2(self.to_patch(m_part)), batch)
        p_tokens = _pool_patches(self.cnn3(self.to_patch(p_part)), batch)

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=batch)
        x = torch.cat((cls_tokens[:, 0:1], t_tokens,
                       cls_tokens[:, 1:2], m_tokens,
                       cls_tokens[:, 2:3], p_tokens), dim=1)

        # Positions of the three CLS tokens in the concatenated sequence.
        n_tokens = x.shape[1]
        t_token_idx = 0
        m_token_idx = 1 + t_tokens.shape[1]
        p_token_idx = m_token_idx + 1 + m_tokens.shape[1]

        # Each modality's embedding covers its CLS token and its patch tokens, along
        # the token axis (the previous ``x[:m_token_idx]`` indexed the batch axis).
        modal = torch.cat((
            self.modal_embedding[0].expand(m_token_idx, -1),
            self.modal_embedding[1].expand(p_token_idx - m_token_idx, -1),
            self.modal_embedding[2].expand(n_tokens - p_token_idx, -1),
        ), dim=0)
        x = x + modal.unsqueeze(0) + self.pos_embedding[:, :n_tokens]
        x = self.dropout(x)
        x = self.transformer(x)
        return (x[:, t_token_idx] + x[:, m_token_idx] + x[:, p_token_idx]) / 3


def TFR_Encoder(seq_len, patch_len, dim, in_dim, depth):
    """Build the CRT encoder with the fixed head/MLP/dropout settings used in this file."""
    return TFR(seq_len=seq_len, patch_len=patch_len, dim=dim, depth=depth,
               heads=8, mlp_dim=dim, dropout=0.2, emb_dropout=0.1, channels=in_dim)


class CRT(nn.Module):
    """Masked cross-reconstruction pre-training wrapper around a ``TFR`` encoder.

    ``decoder_dim`` must equal the encoder's ``dim``: encoder outputs and decoder mask
    tokens are concatenated along the token axis without a projection.
    """

    def __init__(
        self,
        encoder,
        decoder_dim,
        decoder_depth=2,
        decoder_heads=8,
        decoder_dim_head=64,
        patch_len=20,
        in_dim=12,
    ):
        super().__init__()
        self.encoder = encoder
        # shape[-2] is num_patches + 3 (patch positions plus the three CLS slots).
        num_positions = encoder.pos_embedding.shape[-2]
        self.to_patch = encoder.to_patch
        pixel_values_per_patch = in_dim * patch_len

        # decoder parameters
        self.modal_embedding = self.encoder.modal_embedding
        self.mask_token = nn.Parameter(torch.randn(3, decoder_dim))
        self.decoder = Transformer(dim=decoder_dim, depth=decoder_depth, heads=decoder_heads,
                                   dim_head=decoder_dim_head, mlp_dim=decoder_dim)
        self.decoder_pos_emb = nn.Embedding(num_positions, decoder_dim)
        self.to_pixels = nn.ModuleList([nn.Linear(decoder_dim, pixel_values_per_patch) for _ in range(3)])
        self.projs = nn.ModuleList([nn.Linear(decoder_dim, decoder_dim) for _ in range(2)])
        self.to_clicks = nn.Linear(decoder_dim, 2 * patch_len)

    def IDC_loss(self, tokens, encoded_tokens):
        """Log of summed exp-cosine-similarities between each token and every *other* encoded token.

        Both inputs are ``(B, T, D)`` with the same ``T``; the diagonal (a token paired
        with its own encoding) is masked out.  Returns the mean over the batch.
        """
        B, T, D = tokens.shape
        tokens, encoded_tokens = F.normalize(tokens, dim=-1), F.normalize(encoded_tokens, dim=-1)
        cross_mul = torch.exp(torch.matmul(tokens, encoded_tokens.transpose(2, 1)))
        mask = (1 - torch.eye(T, device=tokens.device)).unsqueeze(0)
        cross_mul = cross_mul * mask
        return torch.log(cross_mul.sum(-1).sum(-1)).mean(-1)

    def forward(self, x, clicks, mask_ratio=0.75, beta=1e-4, beta2=1e-4):
        """Return reconstruction loss + ``beta`` * IDC loss + ``beta2`` * click loss.

        ``x`` is patched exactly like in ``TFR``: patches ``[0, n/2)`` are temporal,
        ``[n/2, 3n/4)`` magnitude and ``[3n/4, n)`` phase (``n`` is a multiple of 4
        per ``TFR``'s constructor check).  ``clicks`` is assumed to be ``(b, time_steps)``
        binary labels aligned with ``x``'s last axis (TODO confirm); only its temporal
        patches are supervised.
        """
        device = x.device
        patches = self.to_patch[0](x)  # (b, n, c, patch_len)
        batch, num_patches = patches.shape[:2]
        click_patches = self.to_patch[0](clicks.unsqueeze(1).float())  # (b, n, 1, patch_len)

        n_time, n_freq = num_patches // 2, num_patches // 4
        m_start, p_start = num_patches // 2, num_patches * 3 // 4

        # Mask the same fraction of temporal patches and, with one shared draw, of
        # magnitude and phase patches.
        num_masked = int(mask_ratio * num_patches)
        perm_t = torch.randperm(n_time, device=device)
        masked_t = perm_t[: num_masked // 2].sort()[0]
        unmasked_t = perm_t[num_masked // 2:].sort()[0]
        perm_f = torch.randperm(n_freq, device=device)
        masked_f = perm_f[: num_masked // 4].sort()[0]
        unmasked_f = perm_f[num_masked // 4:].sort()[0]
        n_mt, n_ut = masked_t.numel(), unmasked_t.numel()
        n_mf, n_uf = masked_f.numel(), unmasked_f.numel()  # per frequency modality

        time_patches = patches[:, :m_start]
        mag_patches = patches[:, m_start:p_start]
        phase_patches = patches[:, p_start:]
        t_tokens = _pool_patches(self.encoder.cnn1(self.to_patch[1](time_patches[:, unmasked_t])), batch)
        m_tokens = _pool_patches(self.encoder.cnn2(self.to_patch[1](mag_patches[:, unmasked_f])), batch)
        p_tokens = _pool_patches(self.encoder.cnn3(self.to_patch[1](phase_patches[:, unmasked_f])), batch)
        ori_tokens = torch.cat((t_tokens, m_tokens, p_tokens), dim=1)

        cls_tokens = repeat(self.encoder.cls_token, '() n d -> b n d', b=batch)
        tokens = torch.cat((cls_tokens[:, 0:1], t_tokens,
                            cls_tokens[:, 1:2], m_tokens,
                            cls_tokens[:, 2:3], p_tokens), dim=1)

        # In encoder.pos_embedding the CLS slots are 0, n/2+1 and 3n/4+2, each followed
        # by the positions of its modality's patches.
        m_cls, p_cls = m_start + 1, p_start + 2
        pos = self.encoder.pos_embedding
        pos_embedding = torch.cat((
            pos[:, 0:1], pos[:, unmasked_t + 1],
            pos[:, m_cls:m_cls + 1], pos[:, unmasked_f + m_cls + 1],
            pos[:, p_cls:p_cls + 1], pos[:, unmasked_f + p_cls + 1],
        ), dim=1)
        modal_embedding = torch.cat((
            repeat(self.modal_embedding[0], '1 d -> 1 n d', n=n_ut + 1),
            repeat(self.modal_embedding[1], '1 d -> 1 n d', n=n_uf + 1),
            repeat(self.modal_embedding[2], '1 d -> 1 n d', n=n_uf + 1),
        ), dim=1)
        tokens = tokens + pos_embedding + modal_embedding
        encoded_tokens = self.encoder.transformer(tokens)  # unmasked tokens + CLS tokens
        n_enc = encoded_tokens.shape[1]

        # Patch-token spans (CLS tokens excluded) in the encoded sequence.
        enc_t = slice(1, 1 + n_ut)
        enc_m = slice(2 + n_ut, 2 + n_ut + n_uf)
        enc_p = slice(3 + n_ut + n_uf, n_enc)
        encoded_patches = torch.cat((encoded_tokens[:, enc_t], encoded_tokens[:, enc_m],
                                     encoded_tokens[:, enc_p]), dim=1)
        idc_loss = self.IDC_loss(self.projs[0](ori_tokens), self.projs[1](encoded_patches))

        mask_tokens = torch.cat((
            repeat(self.mask_token[0], 'd -> b n d', b=batch, n=n_mt),
            repeat(self.mask_token[1], 'd -> b n d', b=batch, n=n_mf),
            repeat(self.mask_token[2], 'd -> b n d', b=batch, n=n_mf),
        ), dim=1)
        mask_tokens = mask_tokens + self.decoder_pos_emb(
            torch.cat((masked_t, masked_f + m_start, masked_f + p_start)))
        decoder_tokens = torch.cat((encoded_tokens, mask_tokens), dim=1)
        # Positive start index: ``[:, -k:]`` would select everything when k == 0.
        decoded_masks = self.decoder(decoder_tokens)[:, n_enc:]

        # NOTE(review): as in the original, unmasked patches are reconstructed from the
        # encoder output (the decoder's *input*) and only masked ones from the decoder
        # output -- confirm this is intended.
        t_repr = torch.cat((encoded_tokens[:, enc_t], decoded_masks[:, :n_mt]), dim=1)
        m_repr = torch.cat((encoded_tokens[:, enc_m], decoded_masks[:, n_mt:n_mt + n_mf]), dim=1)
        p_repr = torch.cat((encoded_tokens[:, enc_p], decoded_masks[:, n_mt + n_mf:]), dim=1)
        pred_pixel_values = torch.cat((self.to_pixels[0](t_repr), self.to_pixels[1](m_repr),
                                       self.to_pixels[2](p_repr)), dim=1)

        # Targets must follow the prediction order: unmasked then masked per modality.
        t_order = torch.cat((unmasked_t, masked_t))
        f_order = torch.cat((unmasked_f, masked_f))
        target_order = torch.cat((t_order, f_order + m_start, f_order + p_start))
        recon_loss = F.mse_loss(pred_pixel_values,
                                rearrange(patches[:, target_order], 'b n c p -> b n (c p)'))

        # Two logits per time step of every temporal patch, in the same t_order.
        click_pred = rearrange(self.to_clicks(t_repr), 'b n (c p) -> (b n c) p', p=2)
        clicks_gt = rearrange(click_patches[:, t_order].squeeze(2), 'b n c -> (b n c)')
        clicks_oh = F.one_hot(clicks_gt.to(torch.int64), num_classes=2)
        # Clamp avoids an infinite weight (and NaN loss) when a batch has no positive clicks.
        pos_weight = (clicks == 0).sum() / (clicks == 1).sum().clamp(min=1)
        click_pred_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(click_pred, clicks_oh.float())
        return recon_loss + beta * idc_loss + beta2 * click_pred_loss


class Model(nn.Module):
    """Pre-training model built from ``opt`` (seq_len, patch_len, dim, in_dim, depth, beta2)."""

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.encoder = TFR_Encoder(seq_len=opt.seq_len, patch_len=opt.patch_len, dim=opt.dim,
                                   in_dim=opt.in_dim, depth=opt.depth)
        self.crt = CRT(encoder=self.encoder, decoder_dim=opt.dim, in_dim=opt.in_dim,
                       patch_len=opt.patch_len)

    def forward(self, x, clicks, ratio=0.5):
        return self.crt(x, clicks, mask_ratio=ratio, beta2=self.opt.beta2)


def self_supervised_learning(model, X, clicks, opt, modelfile, min_ratio=0.3, max_ratio=0.8):
    """Pre-train ``model`` on ``(X, clicks)`` and save the whole model object to ``modelfile``.

    The mask ratio grows with the epoch index (``epoch / AEepochs``), clamped to
    ``[min_ratio, max_ratio]``.  The learning rate is reduced when the mean epoch
    loss plateaus.  The model is moved to CPU before saving.
    """
    optimizer = optim.Adam(model.parameters(), opt.AElearningRate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20)
    model.to(opt.device)
    model.train()
    dataset = SSLDataSet(X, clicks)
    dataloader = DataLoader(dataset, batch_size=opt.AEbatchSize, shuffle=True)
    losses = []  # running history across all epochs (used for progress printing)
    for epoch in range(opt.AEepochs):
        ratio = max(min_ratio, min(max_ratio, epoch / opt.AEepochs))
        epoch_losses = []
        for idx, batch in enumerate(dataloader):
            x, batch_clicks = (t.to(opt.device) for t in batch)
            loss = model(x, batch_clicks, ratio=ratio)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(float(loss))
            epoch_losses.append(float(loss))
            if idx % 20 == 0:
                print(idx, '/', len(dataloader), sum(losses) / len(losses))
        # The plateau scheduler must see a loss, not the epoch index.
        if epoch_losses:
            scheduler.step(sum(epoch_losses) / len(epoch_losses))
    torch.save(model.to('cpu'), modelfile)


class encoder_classifier(nn.Module):
    """A deep copy of a pre-trained encoder followed by a 3-layer MLP classification head."""

    def __init__(self, AE, opt):
        super().__init__()
        self.opt = opt
        hidden_channels = 64
        self.inputdim = opt.dim
        self.encoder = copy.deepcopy(AE.encoder)
        self.classifier = nn.Sequential(
            nn.Linear(self.inputdim, hidden_channels),
            nn.ReLU(),
            nn.Dropout(p=opt.Headdropout),
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(p=opt.Headdropout),
            nn.Linear(hidden_channels, opt.num_class),
        )
        # initialize the classifier
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.encoder(x, self.opt)
        return self.classifier(features)


def weight_init_layer(m):
    """Re-initialise ``m`` via its own ``reset_parameters`` if it has one."""
    if hasattr(m, 'reset_parameters'):
        m.reset_parameters()


def weight_init_whole(m):
    """Recursively call ``weight_init_layer`` on every leaf module of ``m``."""
    children = list(m.children())
    if len(children) == 0:  # No more children layers
        m.apply(weight_init_layer)
    else:
        for c in children:
            weight_init_whole(c)


def finetuning(AE, train_set, valid_set, opt, modelfile, multi_label=True):
    """Fine-tune a classifier on top of ``AE.encoder`` and return the trained model.

    Stage 0 optimises only the classifier head; any other stage optimises the whole
    model at half the head learning rate.  The learning rate is reduced when the
    best train macro-F1 (evaluated every 10 steps) stops improving.
    ``multi_label=False`` uses BCE on one-hot 2-class targets; otherwise cross-entropy.

    NOTE(review): ``valid_set`` and ``modelfile`` are not used here.
    """
    stage_info = {0: 'Train the classifier only', 1: 'Finetune the whole model'}
    print('Stage %d: ' % (opt.stage), stage_info[opt.stage])
    train_loader = DataLoader(train_set, batch_size=opt.HeadbatchSize, shuffle=True)
    model = encoder_classifier(AE, opt)
    model.to(opt.device)
    model.train()
    best_res = 0
    step = 0
    if opt.stage == 0:
        optimizer = optim.Adam(model.classifier.parameters(), lr=opt.HeadlearningRate)
    else:
        optimizer = optim.Adam(model.parameters(), lr=opt.HeadlearningRate / 2)
    # Higher F1 is better, so the plateau scheduler must run in 'max' mode.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.9, patience=20)
    for _epoch in range(opt.Headepochs):
        for batch in train_loader:
            step += 1
            x, y, _clicks = (t.to(opt.device) for t in batch)
            pred = model(x)
            # Flatten to (b,) so batch size 1 and (b, 1)-shaped labels both work.
            labels = y.reshape(-1).long()
            if not multi_label:
                loss = nn.BCEWithLogitsLoss()(pred, F.one_hot(labels, num_classes=2).float())
            else:
                loss = nn.CrossEntropyLoss()(pred, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % 10 == 0:
                _, trainf1 = test(model, train_set, opt.HeadbatchSize, multi_label)
                best_res = max(best_res, trainf1)
        scheduler.step(best_res)
    return model


def test(model, dataset, batch_size, multi_label):
    """Return ``[top-1 accuracy, macro F1]`` of ``model`` on ``dataset``.

    Labels are read from ``dataset.label``.  The model's train/eval mode is restored
    afterwards.  ``multi_label`` is accepted for interface compatibility and unused.
    """
    def topAcc(GTlabel, pred_proba, top):
        # Fraction of rows whose true label is among the ``top`` highest scores.
        top_k = np.argsort(-pred_proba, axis=1)[:, :top]
        hits = (top_k == np.asarray(GTlabel).reshape(-1, 1)).any(axis=1)
        return float(hits.mean())

    was_training = model.training
    model.eval()
    device = model.opt.device
    model.to(device)
    testloader = DataLoader(dataset, batch_size=batch_size)
    batch_preds = []
    with torch.no_grad():
        for batch in testloader:
            x = batch[0].to(device)
            batch_preds.append(model(x).cpu().numpy())
    pred_prob = np.concatenate(batch_preds, axis=0)
    acc = topAcc(dataset.label.squeeze(), pred_prob, 1)
    pred = np.argmax(pred_prob, axis=1)
    f1 = f1_score(dataset.label.squeeze(), pred, average='macro')
    model.train(was_training)
    return [acc, f1]