import torch
import torch.nn as nn
import einops
from einops.layers.torch import Rearrange
import math

from .helpers import (
    SinusoidalPosEmb,
    Downsample1d,
    Upsample1d,
    Conv1dBlock,
)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        # Run GroupNorm in float32 and cast back to the input dtype; this is
        # the usual trick for numerical stability under mixed precision.
        return super().forward(x.float()).type(x.dtype)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module
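

# Note: zeroing the final layer of a residual branch makes that branch output
# zeros at initialization, so the surrounding block starts out as an identity
# mapping. AttentionBlock below relies on this by wrapping its output
# projection in zero_module.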


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(self, channels, num_heads=1, use_checkpoint=False):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint

        self.norm = normalization(channels)
        # self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.qkv = Conv1dBlock(channels, channels * 3, 2)

        self.attention = QKVAttention()
        # self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
        self.proj_out = zero_module(Conv1dBlock(channels, channels, 4))

    def forward(self, x):
        # Flatten any trailing spatial/temporal dims into one axis, attend
        # over it, then restore the original shape at the end.
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        # Fold the heads into the batch dimension for the attention op.
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
        h = self.attention(qkv)
        h = h.reshape(b, -1, h.shape[-1])
        h = self.proj_out(h)
        # Residual connection around the attention branch.
        return (x + h).reshape(b, c, *spatial)
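

# Usage sketch (editorial addition; shapes assume Conv1dBlock from .helpers is
# a padded 1-D convolution block that preserves the temporal length):
#   attn = AttentionBlock(channels=64, num_heads=4)
#   x = torch.randn(2, 64, 32)   # [batch, channels, horizon]
#   y = attn(x)                  # residual output, same shape as x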


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention.
    """

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x C x T] tensor after attention.
        """
        ch = qkv.shape[1] // 3
        q, k, v = torch.split(qkv, ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        return torch.einsum("bts,bcs->bct", weight, v)
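

# Numerical note (editorial addition): scaling both q and k by 1/sqrt(sqrt(ch))
# before the matmul is algebraically identical to dividing the attention logits
# by sqrt(ch), but keeps intermediate values smaller, which is friendlier to
# fp16. Quick sanity-check sketch:
#   q, k = torch.randn(2, 8, 16), torch.randn(2, 8, 16)
#   s = 1 / math.sqrt(math.sqrt(8))
#   a = torch.einsum("bct,bcs->bts", q * s, k * s)
#   b = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(8)
#   torch.allclose(a, b, atol=1e-5)  # True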


class ResidualTemporalBlock(nn.Module):

    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=3):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(inp_channels, out_channels, kernel_size),
            Conv1dBlock(out_channels, out_channels, kernel_size, if_zero=True),
        ])
        # Maps the time embedding to a per-channel bias; should be removed for
        # the Noise and Deterministic baselines.
        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            Rearrange('batch t -> batch t 1'),
        )
        self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
            if inp_channels != out_channels else nn.Identity()

        self.dropout = nn.Dropout(0.5)

    def forward(self, x, t):
        out = self.blocks[0](x) + self.time_mlp(t)  # for diffusion
        # out = self.blocks[0](x)  # for Noise and Deterministic baselines
        out = self.blocks[1](out)
        # Skip connection, with dropout applied on the input path.
        return out + self.residual_conv(self.dropout(x))
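

# Shape sketch (editorial addition; assumes Conv1dBlock preserves the horizon):
#   block = ResidualTemporalBlock(inp_channels=14, out_channels=32, embed_dim=32)
#   x = torch.randn(4, 14, 128)   # [batch, channels, horizon]
#   t = torch.randn(4, 32)        # time embedding, [batch, embed_dim]
#   out = block(x, t)             # -> [4, 32, 128]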


class TemporalUnet(nn.Module):

    def __init__(
        self,
        transition_dim,
        num_class,
        dim=32,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = dim
        # Timestep embedding; should be removed for the Noise and Deterministic
        # baselines.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        # Class-label embedding; currently unused because the conditioning
        # path in forward() is commented out.
        self.label_embed = nn.Embedding(num_class, time_dim)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim),
                AttentionBlock(dim_out, use_checkpoint=False, num_heads=4),
                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

            # Alternative (disabled) blocks sized for a doubled embedding,
            # e.g. when the class embedding is concatenated with the time
            # embedding (see the commented-out torch.cat in forward()).
            '''self.downs.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim*2),
                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim*2),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))'''

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
        self.attention = AttentionBlock(mid_dim, use_checkpoint=False, num_heads=16)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
        '''self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim*2)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim*2)'''

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim),
                AttentionBlock(dim_in, use_checkpoint=False, num_heads=4),
                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))
            '''self.ups.append(nn.ModuleList([
                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim*2),
                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim*2),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))'''

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=3, if_zero=True),
            nn.Conv1d(dim, transition_dim, 1),
        )
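
    # Channel schedule sketch (editorial addition; assumes Downsample1d and
    # Upsample1d from .helpers are the usual stride-2 conv / transposed conv):
    # with the defaults dim=32, dim_mults=(1, 2, 4, 8), dims is
    # [transition_dim, 32, 64, 128, 256], so the encoder maps
    # transition_dim -> 32 -> 64 -> 128 -> 256 while halving the horizon at
    # every level except the last, and the decoder mirrors this with skip
    # connections concatenated along the channel axis.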

    def forward(self, x, time, class_label):
        # x arrives as [batch, horizon, transition_dim]; the convolutions run
        # over the horizon, so move the transition features to the channel axis.
        x = einops.rearrange(x, 'b h t -> b t h')

        # t = None  # for Noise and Deterministic baselines
        t = self.time_mlp(time)  # for diffusion
        # Class conditioning (currently disabled):
        # y_emb = self.label_embed(class_label)
        # t = t + y_emb
        # t = torch.cat((t, y_emb), 1)
        h = []

        for resnet, attn, resnet2, downsample in self.downs:
            x = resnet(x, t)
            x = attn(x)
            x = resnet2(x, t)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.attention(x)
        x = self.mid_block2(x, t)

        for resnet, attn, resnet2, upsample in self.ups:
            # Concatenate the encoder skip connection along channels.
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = attn(x)
            x = resnet2(x, t)
            x = upsample(x)

        x = self.final_conv(x)
        x = einops.rearrange(x, 'b t h -> b h t')
        return x
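

# End-to-end usage sketch (editorial addition; shapes assume the .helpers
# blocks preserve or exactly halve/double the horizon, and that the horizon is
# divisible by 2 ** (len(dim_mults) - 1)):
#   model = TemporalUnet(transition_dim=14, num_class=10, dim=32)
#   x = torch.randn(4, 128, 14)           # [batch, horizon, transition_dim]
#   time = torch.randint(0, 1000, (4,))   # diffusion timesteps
#   labels = torch.randint(0, 10, (4,))   # currently unused by forward()
#   out = model(x, time, labels)          # -> [4, 128, 14]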


class TemporalUnetNoAttn(nn.Module):
    """
    Same architecture as TemporalUnet, but with the attention blocks removed.
    """

    def __init__(
        self,
        transition_dim,
        num_class,
        dim=32,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = dim
        # Timestep embedding; should be removed for the Noise and Deterministic
        # baselines.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        # Class-label embedding; currently unused because the conditioning
        # path in forward() is commented out.
        self.label_embed = nn.Embedding(num_class, time_dim)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim),
                # AttentionBlock(dim_out, use_checkpoint=False, num_heads=4),
                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

            '''self.downs.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim*2),
                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim*2),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))'''

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
        # self.attention = AttentionBlock(mid_dim, use_checkpoint=False, num_heads=16)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
        '''self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim*2)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim*2)'''

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim),
                # AttentionBlock(dim_in, use_checkpoint=False, num_heads=4),
                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))
            '''self.ups.append(nn.ModuleList([
                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim*2),
                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim*2),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))'''

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=3, if_zero=True),
            nn.Conv1d(dim, transition_dim, 1),
        )

    def forward(self, x, time, class_label):
        x = einops.rearrange(x, 'b h t -> b t h')

        # t = None  # for Noise and Deterministic baselines
        t = self.time_mlp(time)  # for diffusion
        # Class conditioning (currently disabled):
        # y_emb = self.label_embed(class_label)
        # t = t + y_emb
        # t = torch.cat((t, y_emb), 1)
        h = []

        for resnet, resnet2, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)

        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)

        x = self.final_conv(x)
        x = einops.rearrange(x, 'b t h -> b h t')
        return x