神经风格迁移

  2019-9-21 


风格迁移

内容图像(C),生成的图像(G ),样式图像(S)

合成图像为唯一需要学习(更新)的参数(视为模型参数),也是生成的结果,而不是VGG

输入:C

初始化:可以选择让G=C,也可以随机初始化一个G

这个算法里面的参数(也就是是合成图片里面的每个像素点,我们可以将内容图片直接 copy 成合成图片,然后训练使得他的风格和我们的风格图片相似,同时也可以随机化一张图片作为合成图片(两种初始化),然后训练他使得他与内容图片以及风格图片具有相似性。特征的提取

model

这里不训练模型,直接使用预训练的vgg来提取特征

def get_model():
    vgg = models.vgg19(pretrained=True).features
    for param in vgg.parameters():
        param.requires_grad = False
    return vgg

extract loss

这里使用以前用过的register_forward_hook()函数来hook

class LayerActivations():
    features = []
    def __init__(self,model,layer_nums):
        #这里要hook多层,要保存多层钩子
        self.hooks = []
        for layer_num in layer_nums:
            self.hooks.append(model[layer_num].register_forward_hook(self.hook_fn))
    def hook_fn(self,module,input,output):
        self.features.append(output)
    #捕捉完输出后不能忘掉remove方法,否则所有输入都累加在一起会内存溢出 
    def remove(self):
        for hook in self.hooks:
            hook.remove()

#该函数用于将图片输入进模型,经过钩子获取指定多层的特征输出      
def extract_layers(layers,img,model):
    la = LayerActivations(model=model,layer_nums=layers)
    #清空缓存
    la.features=[]
    #运行模型,开钩
    out = model(img)       
    #已经获取到特征,这是我们关注的东西,然后注销钩子
    la.remove()
    #注意这里返回的是列表,一次性钩了多层
    return la.features

定义损失

# content loss
class ContentLoss(nn.Module):
    def __init__(self,weight):
        super().__init__()
        self.weight = weight
        self.mseloss = nn.MSELoss()
    def forward(self,inputs,targets):
        out = self.mseloss(inputs,targets)
        return out * self.weight

衡量风格损失使用的是Gram Matrix,对于$k$个向量$\alpha_1,\alpha_2,\cdots,\alpha_k$

把特征提取输出变换为$k$行$h*w$列的矩阵$X$,那么$X=<\alpha_1,\alpha_2,\cdots,\alpha_k>$,其中向量$x_i$代表了通道$i$上的样式特征,于是Gram矩阵实际上计算出了各个通道的两两相关性

# style loss
class GramMatrix(nn.Module):
    def forward(self,inputs):
        b,c,h,w = inputs.size()
        #hw维度扁平化
        features = inputs.view(b,c,h*w)
        #bmm(A) =  AxA^T
        gram_matrix = torch.bmm(features,features.transpose(1,2))
        #为防止值过大,做归一化
        gram_matrix.div_(h*w)
        return gram_matrix
class StyleLoss(nn.Module):
    def __init__(self,weight):
        super().__init__()
        self.gramfunc = GramMatrix()
        self.mseloss = nn.MSELoss()
        self.weight = weight
    def forward(self,inputs,targets):
        gram_inputs = self.gramfunc(inputs)
        gram_targets = self.gramfunc(targets)
        out = self.mseloss(gram_inputs,gram_targets)
        return out * self.weight
# loss function
def loss_fn(content_layers,style_layers,content_weight,style_weight,content_img,style_img,model,input_param,times):
    loss = 0 
    #提取风格图片的特征并计算,注意这里是提取input(合成)和风格/内容的,都要提取
    style_features = extract_layers(style_layers,style_img,model)  
    input_features = extract_layers(style_layers,input_param,model) 
    for style_layer_index in range(len(style_layers)):
        style_loss = StyleLoss(style_weight)
        style_loss_this = style_loss(input_features[style_layer_index],style_features[style_layer_index])      
        
    #提取内容图片的特征并计算   
    content_features = extract_layers(content_layers,content_img,model)
    input_features = extract_layers(content_layers,input_param,model)     
    for content_layer_index in range(len(content_layers)):
        content_loss = ContentLoss(content_weight)
        content_loss_this = content_loss(input_features[content_layer_index],content_features[content_layer_index])    
    
    loss = style_loss_this+content_loss_this
    
    if times % 5 == 0:
        print('style loss:{:.4f} , content loss:{:.4f}'.format(style_loss_this,content_loss_this))
    return loss

TRAIN

使用LBGFS作为优化函数,LBFGS使用的时候就需要用到闭包

def train(Epoch,input_img,
          content_weight,style_weight,content_img,style_img,
          content_layers=[21],style_layers=[1,6,11,20,25]):
    #默认vgg模型
    model = get_model().cuda()
    
    #初始化合成图片,将input_img作为初始的合成图片并参数化
    input_param = nn.Parameter(input_img.data)    
    
    #初始化优化器
    optimizer = torch.optim.LBFGS([input_param])
    for epoch in range(Epoch):
        print('epoch:{}'.format(epoch))
        #开始训练
        global times
        times=0
        def closure():
            optimizer.zero_grad() #勿忘
            global times
            times+=1
            loss = loss_fn(content_layers,style_layers,content_weight,style_weight,content_img,style_img,model,input_param,times)
            loss.backward() 
            return loss
        optimizer.step(closure) 
    return input_param.data 

def showImg(res):
    img = res[0]
    img = img.cpu()
    img = transforms.ToPILImage()(img) 
    plt.imshow(img)

且听风吟