# Deep Back-Projection Networks For Super-Resolution
# https://arxiv.org/abs/1803.02735

from model import common

import torch
import torch.nn as nn


def make_model(args, parent=False):
    """Build the D-DBPN network configured by ``args``.

    ``parent`` is accepted but is not used here.
    """
    return DDBPN(args)


def projection_conv(in_channels, out_channels, scale, up=True):
    """Create the resampling convolution used by a projection unit.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        scale: resampling factor; must be 2, 4 or 8 (any other value raises
            ``KeyError`` from the lookup below).
        up: when truthy, return an ``nn.ConvTranspose2d``; otherwise return
            an ``nn.Conv2d`` with the same kernel size, stride and padding.

    Returns:
        The freshly constructed convolution module.
    """
    # scale -> (kernel_size, stride, padding). With these values the
    # transposed convolution multiplies the spatial size by ``scale`` and the
    # strided convolution divides it by ``scale`` (for sizes divisible by it).
    geometry = {
        2: (6, 2, 2),
        4: (8, 4, 2),
        8: (12, 8, 2),
    }
    kernel_size, stride, padding = geometry[scale]

    layer_cls = nn.ConvTranspose2d if up else nn.Conv2d
    return layer_cls(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
    )


class DenseProjection(nn.Module):
    """Back-projection unit with an optional 1x1 bottleneck in front.

    The unit projects its input to the other resolution (``conv_1``),
    projects that result back (``conv_2``), subtracts the input to obtain a
    residual, projects the residual (``conv_3``) and adds it to the first
    projection.
    """

    def __init__(self, in_channels, nr, scale, up=True, bottleneck=True):
        """Assemble the bottleneck (if requested) and the three projections.

        Args:
            in_channels: channel count of the tensor passed to ``forward``.
            nr: channel count produced by the unit (and by the bottleneck).
            scale: resampling factor forwarded to ``projection_conv``.
            up: direction of ``conv_1`` and ``conv_3``; ``conv_2`` always
                goes the opposite way.
            bottleneck: when truthy, a 1x1 convolution + PReLU first reduces
                the input to ``nr`` channels.
        """
        super(DenseProjection, self).__init__()

        if bottleneck:
            inter_channels = nr
            self.bottleneck = nn.Sequential(
                nn.Conv2d(in_channels, nr, 1),
                nn.PReLU(nr),
            )
        else:
            inter_channels = in_channels
            self.bottleneck = None

        def stage(n_in, n_out, upsample):
            # One projection step: resampling convolution followed by a
            # per-channel PReLU.
            return nn.Sequential(
                projection_conv(n_in, n_out, scale, upsample),
                nn.PReLU(n_out),
            )

        self.conv_1 = stage(inter_channels, nr, up)
        self.conv_2 = stage(nr, inter_channels, not up)
        self.conv_3 = stage(inter_channels, nr, up)

    def forward(self, x):
        """Apply the projection unit to ``x`` and return the refined result."""
        if self.bottleneck is not None:
            x = self.bottleneck(x)

        projected = self.conv_1(x)
        # Error between the back-projected estimate and the unit's input.
        residual = self.conv_2(projected) - x
        return projected + self.conv_3(residual)


class DDBPN(nn.Module):
    """Dense Deep Back-Projection Network.

    Alternates up- and down-projection units; every unit receives the
    channel-wise concatenation of all outputs produced so far at the
    resolution it consumes.
    """

    def __init__(self, args):
        """Build the network.

        Reads ``args.scale[0]``, ``args.rgb_range`` and ``args.n_colors``.
        """
        super(DDBPN, self).__init__()
        scale = args.scale[0]

        n0 = 128  # width of the first 3x3 feature-extraction convolution
        nr = 32  # width produced by every projection unit
        self.depth = 6  # number of up-projection units

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)

        self.initial = nn.Sequential(
            nn.Conv2d(args.n_colors, n0, 3, padding=1),
            nn.PReLU(n0),
            nn.Conv2d(n0, nr, 1),
            nn.PReLU(nr),
        )

        # Up unit i consumes the concatenation of the i previous down outputs
        # (unit 0 consumes the initial features instead), hence
        # nr * max(i, 1) input channels. The first two units see only ``nr``
        # channels and therefore skip the bottleneck.
        self.upmodules = nn.ModuleList([
            DenseProjection(nr * max(i, 1), nr, scale, True, i > 1)
            for i in range(self.depth)
        ])
        # Down unit i consumes the concatenation of the i + 1 up outputs
        # available when it runs; only the first one skips the bottleneck.
        self.downmodules = nn.ModuleList([
            DenseProjection(nr * (i + 1), nr, scale, False, i != 0)
            for i in range(self.depth - 1)
        ])

        self.reconstruction = nn.Sequential(
            nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1),
        )

        # NOTE(review): the trailing 1 presumably makes MeanShift add the
        # mean back instead of subtracting it - confirm in model/common.py.
        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

    def forward(self, x):
        """Run the network on ``x`` and return the reconstructed output."""
        features = self.initial(self.sub_mean(x))

        up_outputs = []
        down_outputs = []
        unit_input = features
        for i in range(self.depth - 1):
            up_outputs.append(self.upmodules[i](unit_input))
            down_outputs.append(
                self.downmodules[i](torch.cat(up_outputs, dim=1))
            )
            # The next up unit sees every down output produced so far.
            unit_input = torch.cat(down_outputs, dim=1)

        # The last up unit has no matching down unit.
        up_outputs.append(self.upmodules[-1](unit_input))

        out = self.reconstruction(torch.cat(up_outputs, dim=1))
        return self.add_mean(out)