init
This commit is contained in:
commit
b429f9807d
7 changed files with 1186 additions and 0 deletions
438
CRT.py
Normal file
438
CRT.py
Normal file
|
@ -0,0 +1,438 @@
|
|||
import pdb, copy
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
from base_models import MLP, resnet1d18, SSLDataSet
|
||||
from sklearn.metrics import f1_score
|
||||
|
||||
class cnn_extractor(nn.Module):
|
||||
def __init__(self, dim, input_plane):
|
||||
super(cnn_extractor, self).__init__()
|
||||
self.cnn = resnet1d18(input_channels=dim, inplanes=input_plane)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.cnn(x)
|
||||
return x
|
||||
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim=-1)
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
|
||||
class TFR(nn.Module):
|
||||
def __init__(self, seq_len, patch_len, dim, depth, heads, mlp_dim, channels=12,
|
||||
dim_head=64, dropout=0., emb_dropout=0.):
|
||||
'''
|
||||
The encoder of CRT
|
||||
'''
|
||||
super().__init__()
|
||||
self.patch_len = patch_len
|
||||
self.seq_len = seq_len
|
||||
|
||||
assert seq_len % (4 * patch_len) == 0, \
|
||||
'The seq_len should be 4 * n * patch_len, or there must be patch with both magnitude and phase data.'
|
||||
num_patches = seq_len // patch_len
|
||||
patch_dim = channels * patch_len
|
||||
|
||||
self.to_patch = nn.Sequential(Rearrange('b c (n p1) -> b n c p1', p1=patch_len),
|
||||
Rearrange('b n c p1 -> (b n) c p1'))
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 3, dim))
|
||||
self.modal_embedding = nn.Parameter(torch.randn(3, 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 3, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.cnn1 = cnn_extractor(dim=channels, input_plane=dim // 8) # For temporal data
|
||||
self.cnn2 = cnn_extractor(dim=channels, input_plane=dim // 8) # For magnitude data
|
||||
self.cnn3 = cnn_extractor(dim=channels, input_plane=dim // 8) # For phase data
|
||||
|
||||
def forward(self, x, opt):
|
||||
batch, _, time_steps = x.shape
|
||||
t, m, p = x[:, :, :time_steps // 2], x[:, :, time_steps // 2: time_steps * 3 // 4], x[:, :, -time_steps // 4:]
|
||||
assert t.shape[-1] % opt.patch_len == 0 == m.shape[-1] % opt.patch_len == p.shape[-1] % opt.patch_len
|
||||
|
||||
t, m, p = self.to_patch(t), self.to_patch(m), self.to_patch(p)
|
||||
patch2seq = nn.Sequential(nn.AdaptiveAvgPool1d(1),
|
||||
Rearrange('(b n) c 1 -> b n c', b=batch))
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=batch)
|
||||
x = torch.cat((cls_tokens[:, 0:1, :], patch2seq(self.cnn1(t)),
|
||||
cls_tokens[:, 1:2, :], patch2seq(self.cnn2(m)),
|
||||
cls_tokens[:, 2:3, :], patch2seq(self.cnn3(p))), dim=1)
|
||||
|
||||
b, t, c = x.shape
|
||||
time_steps = t - 3
|
||||
t_token_idx, m_token_idx, p_token_idx = 0, time_steps // 2 + 1, time_steps * 3 // 4 + 2
|
||||
x[:m_token_idx] += self.modal_embedding[:1]
|
||||
x[m_token_idx: p_token_idx] += self.modal_embedding[1:2]
|
||||
x[p_token_idx: ] += self.modal_embedding[2:]
|
||||
x += self.pos_embedding[:, : t]
|
||||
x = self.dropout(x)
|
||||
x = self.transformer(x)
|
||||
t_token, m_token, p_token = x[:, t_token_idx], x[:, m_token_idx], x[:, p_token_idx]
|
||||
avg = (t_token + m_token + p_token) / 3
|
||||
return avg
|
||||
|
||||
|
||||
def TFR_Encoder(seq_len, patch_len, dim, in_dim, depth):
|
||||
vit = TFR(seq_len=seq_len,
|
||||
patch_len=patch_len,
|
||||
#num_classes=num_class,
|
||||
dim=dim,
|
||||
depth=depth,
|
||||
heads=8,
|
||||
mlp_dim=dim,
|
||||
dropout=0.2,
|
||||
emb_dropout=0.1,
|
||||
channels=in_dim)
|
||||
return vit
|
||||
|
||||
class CRT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder,
|
||||
decoder_dim,
|
||||
decoder_depth=2,
|
||||
decoder_heads=8,
|
||||
decoder_dim_head=64,
|
||||
patch_len = 20,
|
||||
in_dim=12
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
|
||||
self.to_patch = encoder.to_patch
|
||||
pixel_values_per_patch = in_dim * patch_len
|
||||
|
||||
# decoder parameters
|
||||
self.modal_embedding = self.encoder.modal_embedding
|
||||
self.mask_token = nn.Parameter(torch.randn(3, decoder_dim))
|
||||
self.decoder = Transformer(dim=decoder_dim,
|
||||
depth=decoder_depth,
|
||||
heads=decoder_heads,
|
||||
dim_head=decoder_dim_head,
|
||||
mlp_dim=decoder_dim)
|
||||
self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
|
||||
self.to_pixels = nn.ModuleList([nn.Linear(decoder_dim, pixel_values_per_patch) for i in range(3)])
|
||||
self.projs = nn.ModuleList([nn.Linear(decoder_dim, decoder_dim) for i in range(2)])
|
||||
|
||||
self.to_clicks = nn.Linear(decoder_dim, 2 * patch_len)
|
||||
|
||||
def IDC_loss(self, tokens, encoded_tokens):
|
||||
B, T, D = tokens.shape
|
||||
tokens, encoded_tokens = F.normalize(tokens, dim=-1), F.normalize(encoded_tokens, dim=-1)
|
||||
encoded_tokens = encoded_tokens.transpose(2, 1)
|
||||
cross_mul = torch.exp(torch.matmul(tokens, encoded_tokens))
|
||||
mask = (1 - torch.eye(T)).unsqueeze(0).to(tokens.device)
|
||||
cross_mul = cross_mul * mask
|
||||
return torch.log(cross_mul.sum(-1).sum(-1)).mean(-1)
|
||||
|
||||
def forward(self, x, clicks, mask_ratio=0.75, beta = 1e-4, beta2 = 1e-4):
|
||||
device = x.device
|
||||
patches = self.to_patch[0](x)
|
||||
batch, num_patches, c, length = patches.shape
|
||||
clickpatches = self.to_patch[0](clicks.unsqueeze(1).float())
|
||||
|
||||
num_masked = int(mask_ratio * num_patches)
|
||||
rand_indices1 = torch.randperm(num_patches // 2, device=device)
|
||||
masked_indices1 = rand_indices1[: num_masked // 2].sort()[0]
|
||||
unmasked_indices1 = rand_indices1[num_masked // 2:].sort()[0]
|
||||
rand_indices2 = torch.randperm(num_patches // 4, device=device)
|
||||
masked_indices2, unmasked_indices2 = rand_indices2[: num_masked // 4].sort()[0], rand_indices2[num_masked // 4:].sort()[0]
|
||||
rand_indices = torch.cat((masked_indices1, unmasked_indices1,
|
||||
masked_indices2 + num_patches // 2, unmasked_indices2 + num_patches // 2,
|
||||
masked_indices2 + num_patches // 4 * 3, unmasked_indices2 + num_patches // 4 * 3))
|
||||
|
||||
masked_num_t, masked_num_f = masked_indices1.shape[0], 2 * masked_indices2.shape[0]
|
||||
unmasked_num_t, unmasked_num_f = unmasked_indices1.shape[0], 2 * unmasked_indices2.shape[0]
|
||||
|
||||
tpatches = patches[:, : num_patches // 2, :, :]
|
||||
mpatches, ppatches = patches[:, num_patches // 2: num_patches * 3 // 4, :, :], patches[:, -num_patches // 4:, :, :]
|
||||
|
||||
unmasked_tpatches = tpatches[:, unmasked_indices1, :, :]
|
||||
unmasked_mpatches, unmasked_ppatches = mpatches[:, unmasked_indices2, :, :], ppatches[:, unmasked_indices2, :, :]
|
||||
t_tokens, m_tokens, p_tokens = self.to_patch[1](unmasked_tpatches), self.to_patch[1](unmasked_mpatches), self.to_patch[1](unmasked_ppatches)
|
||||
t_tokens, m_tokens, p_tokens = self.encoder.cnn1(t_tokens), self.encoder.cnn2(m_tokens), self.encoder.cnn3(p_tokens)
|
||||
Flat = nn.Sequential(nn.AdaptiveAvgPool1d(1),
|
||||
Rearrange('(b n) c 1 -> b n c', b=batch))
|
||||
t_tokens, m_tokens, p_tokens = Flat(t_tokens), Flat(m_tokens), Flat(p_tokens)
|
||||
ori_tokens = torch.cat((t_tokens, m_tokens, p_tokens), 1).clone()
|
||||
|
||||
cls_tokens = repeat(self.encoder.cls_token, '() n d -> b n d', b=batch)
|
||||
tokens = torch.cat((cls_tokens[:, 0:1, :], t_tokens,
|
||||
cls_tokens[:, 1:2, :], m_tokens,
|
||||
cls_tokens[:, 2:3, :], p_tokens), dim=1)
|
||||
|
||||
t_idx, m_idx, p_idx = num_patches // 2 - 1, num_patches * 3 // 4 - 1, num_patches - 1
|
||||
pos_embedding = torch.cat((self.encoder.pos_embedding[:, 0:1, :], self.encoder.pos_embedding[:, unmasked_indices1 + 1, :],
|
||||
self.encoder.pos_embedding[:, t_idx + 2: t_idx + 3],
|
||||
self.encoder.pos_embedding[:, unmasked_indices2 + t_idx + 3, :],
|
||||
self.encoder.pos_embedding[:, m_idx + 3: m_idx + 4],
|
||||
self.encoder.pos_embedding[:, unmasked_indices2 + m_idx + 4, :]), dim=1)
|
||||
|
||||
modal_embedding = torch.cat((repeat(self.modal_embedding[0], '1 d -> 1 n d', n=unmasked_num_t + 1),
|
||||
repeat(self.modal_embedding[1], '1 d -> 1 n d', n=unmasked_num_f // 2 + 1),
|
||||
repeat(self.modal_embedding[2], '1 d -> 1 n d', n=unmasked_num_f // 2 + 1)), dim=1)
|
||||
|
||||
tokens = tokens + pos_embedding + modal_embedding
|
||||
|
||||
encoded_tokens = self.encoder.transformer(tokens) # tokens: unmasked tokens + CLS tokens
|
||||
|
||||
t_idx, m_idx, p_idx = unmasked_num_t, unmasked_num_f // 2 + unmasked_num_t + 1, -1
|
||||
|
||||
idc_loss = self.IDC_loss(self.projs[0](ori_tokens), self.projs[1](torch.cat(([encoded_tokens[:, 1: t_idx+1], encoded_tokens[:, t_idx+2: m_idx+1], encoded_tokens[:, m_idx+2: ]]), dim=1)))
|
||||
|
||||
decoder_tokens = encoded_tokens
|
||||
|
||||
mask_tokens1 = repeat(self.mask_token[0], 'd -> b n d', b=batch, n=masked_num_t)
|
||||
mask_tokens2 = repeat(self.mask_token[1], 'd -> b n d', b=batch, n=masked_num_f // 2)
|
||||
mask_tokens3 = repeat(self.mask_token[2], 'd -> b n d', b=batch, n=masked_num_f // 2)
|
||||
mask_tokens = torch.cat((mask_tokens1, mask_tokens2, mask_tokens3), dim=1)
|
||||
|
||||
decoder_pos_emb = self.decoder_pos_emb(torch.cat(
|
||||
(masked_indices1, masked_indices2 + num_patches // 2, masked_indices2 + num_patches * 3 // 4)))
|
||||
|
||||
mask_tokens = mask_tokens + decoder_pos_emb
|
||||
decoder_tokens = torch.cat((decoder_tokens, mask_tokens), dim=1)
|
||||
decoded_tokens = self.decoder(decoder_tokens)
|
||||
|
||||
mask_tokens = decoded_tokens[:, -mask_tokens.shape[1]:]
|
||||
|
||||
pred_pixel_values_t = self.to_pixels[0](torch.cat((decoder_tokens[:, 1: t_idx + 1], mask_tokens[:, : masked_num_t]), 1))
|
||||
pred_pixel_values_m = self.to_pixels[1](torch.cat((decoder_tokens[:, t_idx+2: m_idx+1], mask_tokens[:, masked_num_t: masked_num_f // 2 + masked_num_t]), 1))
|
||||
pred_pixel_values_p = self.to_pixels[2](torch.cat((decoder_tokens[:, m_idx+2: -mask_tokens.shape[1]], mask_tokens[:, -masked_num_f // 2:]), 1))
|
||||
pred_pixel_values = torch.cat((pred_pixel_values_t, pred_pixel_values_m, pred_pixel_values_p), dim=1)
|
||||
|
||||
recon_loss = F.mse_loss(pred_pixel_values, rearrange(patches[:,rand_indices], 'b n c p -> b n (c p)'))
|
||||
|
||||
rmvcls_embedding = torch.cat((decoder_tokens[:, 1: t_idx + 1], mask_tokens[:, : masked_num_t]), 1)
|
||||
click_pred = self.to_clicks(rmvcls_embedding)
|
||||
click_pred = rearrange(click_pred, 'b n (c p) -> b n c p', p=2)
|
||||
click_pred = rearrange(click_pred, 'b n c p -> (b n c) p')
|
||||
clicksGT = clickpatches[:, rand_indices1].squeeze()
|
||||
clicksGT = rearrange(clicksGT, 'b n c -> (b n c)')
|
||||
clicksOH = F.one_hot(clicksGT.to(torch.int64), num_classes=2)
|
||||
pos_weight = (clicks==0).sum()/(clicks==1).sum()
|
||||
click_pred_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(click_pred, clicksOH.float())
|
||||
|
||||
return recon_loss + beta * idc_loss + beta2 * click_pred_loss
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, opt):
|
||||
super(Model, self).__init__()
|
||||
self.opt = opt
|
||||
self.encoder = TFR_Encoder(seq_len=opt.seq_len,
|
||||
patch_len=opt.patch_len,
|
||||
dim=opt.dim,
|
||||
in_dim=opt.in_dim,
|
||||
depth=opt.depth)
|
||||
self.crt = CRT(encoder=self.encoder,
|
||||
decoder_dim=opt.dim,
|
||||
in_dim=opt.in_dim,
|
||||
patch_len=opt.patch_len)
|
||||
|
||||
def forward(self, x, clicks, ratio = 0.5):
|
||||
return self.crt(x, clicks, mask_ratio=ratio, beta2=self.opt.beta2)
|
||||
|
||||
def self_supervised_learning(model, X, clicks, opt, modelfile, min_ratio=0.3, max_ratio=0.8):
|
||||
optimizer = optim.Adam(model.parameters(), opt.AElearningRate)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20)
|
||||
|
||||
model.to(opt.device)
|
||||
model.train()
|
||||
|
||||
dataset = SSLDataSet(X, clicks)
|
||||
dataloader = DataLoader(dataset, batch_size=opt.AEbatchSize, shuffle=True)
|
||||
|
||||
losses = []
|
||||
for _ in range(opt.AEepochs):
|
||||
for idx, batch in enumerate(dataloader):
|
||||
x, clicks = tuple(t.to(opt.device) for t in batch)
|
||||
loss = model.to(opt.device)(x, clicks, ratio=max(min_ratio, min(max_ratio, _ / opt.AEepochs)))
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
losses.append(float(loss))
|
||||
if idx % 20 == 0:
|
||||
print(idx,'/',len(dataloader), sum(losses) / len(losses))
|
||||
scheduler.step(_)
|
||||
torch.save(model.to('cpu'), modelfile)
|
||||
|
||||
class encoder_classifier(nn.Module):
|
||||
def __init__(self, AE, opt):
|
||||
super().__init__()
|
||||
self.opt = opt
|
||||
hidden_channels = 64
|
||||
self.inputdim = opt.dim
|
||||
self.encoder = copy.deepcopy(AE.encoder)
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Linear(self.inputdim, hidden_channels),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(p = opt.Headdropout),
|
||||
nn.Linear(hidden_channels, hidden_channels),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(p = opt.Headdropout),
|
||||
nn.Linear(hidden_channels, opt.num_class)
|
||||
)
|
||||
# initialize the classifier
|
||||
for m in self.classifier.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_normal_(m.weight)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
def forward(self, x):
|
||||
features = self.encoder(x, self.opt)
|
||||
out = self.classifier(features)
|
||||
return out
|
||||
|
||||
def weight_init_layer(m):
|
||||
if hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
def weight_init_whole(m):
|
||||
children = list(m.children())
|
||||
if len(children)==0: # No more children layers
|
||||
m.apply(weight_init_layer)
|
||||
else: # Get children layers
|
||||
for c in children:
|
||||
weight_init_whole(c)
|
||||
|
||||
def finetuning(AE, train_set, valid_set, opt, modelfile, multi_label=True):
|
||||
stage_info = {0: 'Train the classifier only', 1: 'Finetune the whole model'}
|
||||
print('Stage %d: '%(opt.stage), stage_info[opt.stage])
|
||||
|
||||
train_loader = DataLoader(train_set, batch_size=opt.HeadbatchSize, shuffle=True)
|
||||
model = encoder_classifier(AE, opt)
|
||||
model.train()
|
||||
|
||||
best_res = 0
|
||||
step = 0
|
||||
if opt.stage == 0:
|
||||
optimizer = optim.Adam(model.classifier.parameters(), lr=opt.HeadlearningRate)
|
||||
else:
|
||||
lr = opt.HeadlearningRate/2
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20)
|
||||
for epoch in range(opt.Headepochs):
|
||||
for batch_idx, batch in enumerate(train_loader):
|
||||
step += 1
|
||||
x, y, clicks = tuple(t.to(opt.device) for t in batch)
|
||||
pred = model.to(opt.device)(x)
|
||||
if not multi_label:
|
||||
pos_weight = (y==0.).sum()/y.sum()
|
||||
y = F.one_hot(y, num_classes=2)
|
||||
#loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(pred, y.squeeze().float())
|
||||
loss = nn.BCEWithLogitsLoss()(pred, y.squeeze().float())
|
||||
else:
|
||||
if pred.shape[0]==1:
|
||||
loss = nn.CrossEntropyLoss()(pred, y.long())
|
||||
else:
|
||||
loss = nn.CrossEntropyLoss()(pred, y.squeeze().long())
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if step % 10 == 0:
|
||||
[trainacc, trainf1] = test(model, train_set, opt.HeadbatchSize, multi_label)
|
||||
if trainf1 > best_res:
|
||||
best_res = trainf1
|
||||
scheduler.step(best_res)
|
||||
|
||||
def test(model, dataset, batch_size, multi_label):
|
||||
def topAcc(GTlabel, pred_proba, top):
|
||||
count = 0
|
||||
for i in range(GTlabel.shape[0]):
|
||||
if (GTlabel[i]) in np.argsort(-pred_proba[i])[:top]:
|
||||
count+=1.0
|
||||
return count/GTlabel.shape[0]
|
||||
|
||||
model.eval()
|
||||
model.to(model.opt.device)
|
||||
testloader = DataLoader(dataset, batch_size=batch_size)
|
||||
pred_prob = None
|
||||
with torch.no_grad():
|
||||
for batch in testloader:
|
||||
x, y, clicks = tuple(t.to(model.opt.device) for t in batch)
|
||||
pred = model(x)
|
||||
if pred_prob is None:
|
||||
pred_prob = pred.cpu().detach().numpy()
|
||||
else:
|
||||
pred_prob = np.concatenate([pred_prob, pred.cpu().detach().numpy()], axis=0)
|
||||
acc = topAcc(dataset.label.squeeze(), pred_prob, 1)
|
||||
pred = np.argmax(pred_prob, axis=1)
|
||||
f1 = f1_score(dataset.label.squeeze(), pred, average='macro')
|
||||
|
||||
model.train()
|
||||
return [acc, f1]
|
Loading…
Add table
Add a link
Reference in a new issue