from model import common

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class VGG(nn.Module):
    """Feature-space MSE loss computed on a truncated, frozen VGG-19.

    Both inputs are passed through ``common.MeanShift`` and then through
    the leading layers of torchvision's pretrained VGG-19 ``features``
    stack; the loss is the mean squared error between the two resulting
    feature maps.

    Args:
        conv_index: String selecting the truncation point. If it contains
            ``'22'`` the first 8 feature layers are used; otherwise, if it
            contains ``'54'``, the first 35 are used. ``'22'`` is checked
            first, so it wins when both substrings are present.
        rgb_range: Scale factor forwarded to ``common.MeanShift`` and
            multiplied into the per-channel std values (default ``1``).
            Presumably the maximum pixel value of the inputs -- confirm
            against ``common.MeanShift``.

    Raises:
        ValueError: If ``conv_index`` contains neither ``'22'`` nor
            ``'54'``.
    """

    def __init__(self, conv_index, rgb_range=1):
        super(VGG, self).__init__()

        # Resolve the truncation point before fetching pretrained weights so
        # an unsupported ``conv_index`` fails immediately with a clear error
        # instead of leaving ``self.vgg`` undefined until ``forward`` runs.
        if '22' in conv_index:
            num_layers = 8
        elif '54' in conv_index:
            num_layers = 35
        else:
            raise ValueError(
                "unsupported conv_index {!r}: expected a string containing "
                "'22' or '54'".format(conv_index)
            )

        # NOTE(review): recent torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; left unchanged to stay compatible with
        # whichever torchvision version this project targets -- confirm.
        vgg_features = models.vgg19(pretrained=True).features
        self.vgg = nn.Sequential(*list(vgg_features)[:num_layers])

        # Per-channel statistics handed to MeanShift; std is pre-scaled by
        # ``rgb_range``. Presumably the standard ImageNet mean/std -- verify
        # against how ``common.MeanShift`` consumes them.
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)

        # Freeze every parameter registered on this module (VGG layers and
        # whatever parameters MeanShift exposes).
        for p in self.parameters():
            p.requires_grad = False

    def _features(self, x):
        """Apply the mean shift, then the truncated VGG stack, to ``x``."""
        return self.vgg(self.sub_mean(x))

    def forward(self, sr, hr):
        """Return the MSE between the VGG features of ``sr`` and ``hr``.

        Args:
            sr: Tensor whose features are computed with autograd enabled,
                so gradients can flow back through it.
            hr: Reference tensor; it is detached and processed under
                ``torch.no_grad()``, so it contributes no gradient.

        Returns:
            The value of ``F.mse_loss`` on the two feature maps.
        """
        vgg_sr = self._features(sr)
        with torch.no_grad():
            vgg_hr = self._features(hr.detach())
        return F.mse_loss(vgg_sr, vgg_hr)