图像风格转换(Neural-Style论文复现)

复现论文:A Neural Algorithm of Artistic Style

代码参考: 使用PyTorch进行图像风格转换 (基于原论文对模型进行了微调)

导入包并配置设备

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.models as models

import copy

# Run on the GPU when CUDA is available and fall back to the CPU otherwise.
# Every tensor and module in this script is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

导入图片并进行预处理

这里img_size为处理和生成的图片大小。 越大效果越好, 但训练速度会变慢, 对显存要求也更高。

根据自己需要和设备进行适当调整。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Size passed to Resize for every loaded image.  Larger values cost more
# time and memory.
img_size = 720

# Preprocessing pipeline applied to each image: resize first, then convert
# the PIL image to a tensor.
_preprocess_steps = [
    transforms.Resize(img_size),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)

def image_loader(image_dir):
    """Load an image file as a batched float tensor on ``device``.

    Args:
        image_dir: Path of the image file to open.

    Returns:
        The module-level ``transform`` applied to the image, with a leading
        batch dimension added, moved to ``device`` as ``torch.float``.
    """
    # Force three RGB channels: grayscale or RGBA files would otherwise give
    # 1- or 4-channel tensors that do not match the 3-channel normalisation
    # constants.  The context manager closes the underlying file handle.
    with Image.open(image_dir) as source:
        image = source.convert('RGB')
    image = transform(image).unsqueeze(0)
    return image.to(device, torch.float)

style_image = image_loader('./1.jpg')
content_image = image_loader('./2.jpg')

# This script requires both images to have identical tensor shapes.  Raise
# explicitly instead of using ``assert``, which is stripped under
# ``python -O`` and reports nothing about the offending shapes.
if style_image.size() != content_image.size():
    raise ValueError(
        'Style and content images must have the same size, got {} and {}'.format(
            tuple(style_image.size()), tuple(content_image.size())))

绘图方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Converter from tensors back to PIL images, used only for display.
unloader = transforms.ToPILImage()

# Switch matplotlib to interactive mode.
plt.ion()


def imshow(tensor, title=None):
    """Show an image tensor in a new matplotlib figure.

    Args:
        tensor: Image tensor; dimension 0 is squeezed away before conversion
            (presumably a batch of one -- confirm against callers).
        title: Optional figure title; no title is set when ``None``.
    """
    # Work on a CPU copy so the caller's tensor is never modified.
    picture = unloader(tensor.cpu().clone().squeeze(0))

    plt.figure(figsize=(12, 6))
    plt.axis('off')
    plt.imshow(picture)
    if title is not None:
        plt.title(title)
    # Short pause to give the figure a chance to be drawn.
    plt.pause(0.001)

# Preview both source images before optimisation starts.
for _preview, _caption in (
        (style_image, 'Style Image'),
        (content_image, 'Content Image'),
):
    imshow(_preview, title=_caption)

损失函数

需要定义内容损失和风格损失。

这里将损失函数以层的方式定义, 后面构建模型会将其加入对应位置。

构造时需要传入对应图片(内容图片、风格图片)经过VGG处理后的特征图。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class ContentLoss(nn.Module):
    """Pass-through layer that records the content loss of its input.

    The layer leaves the data flowing through the network unchanged.  On
    every forward pass it stores the mean squared error between its input
    and a fixed target in ``self.loss``.

    Args:
        target: Feature map to compare against; it is detached so it acts as
            a constant.
    """

    def __init__(self, target) -> None:
        super().__init__()
        # Registered as a buffer rather than a plain attribute so that
        # ``module.to(...)`` / ``module.double()`` convert the target together
        # with the rest of the model.
        self.register_buffer('target', target.detach())
        # Defined up front so the attribute exists before the first forward.
        self.loss = None

    def forward(self, input):
        """Store the MSE against the target and return ``input`` unchanged."""
        self.loss = F.mse_loss(input, self.target)
        return input

def gram_matrix(input):
    """Return the normalised Gram matrix of a 4-D tensor.

    Args:
        input: Tensor of size ``(a, b, c, d)``.

    Returns:
        An ``(a*b, a*b)`` tensor: the product of the ``(a*b, c*d)`` flattened
        features with their transpose, divided by the total element count
        ``a*b*c*d``.
    """
    a, b, c, d = input.size()
    # ``reshape`` instead of ``view`` so non-contiguous inputs (for example
    # permuted feature maps) are accepted as well.
    features = input.reshape(a * b, c * d)
    G = torch.mm(features, features.t())
    return G.div(a * b * c * d)

class StyleLoss(nn.Module):
    """Pass-through layer that records the style loss of its input.

    The target is the Gram matrix (see ``gram_matrix``) of the feature map
    given at construction time.  On every forward pass the layer stores the
    mean squared error between the Gram matrix of its input and that target
    in ``self.loss``, then returns the input unchanged.

    Args:
        target_feature: Feature map whose Gram matrix becomes the fixed
            target.
    """

    def __init__(self, target_feature) -> None:
        super().__init__()
        # Registered as a buffer rather than a plain attribute so that
        # ``module.to(...)`` moves the target together with the model.
        self.register_buffer('target', gram_matrix(target_feature).detach())
        # Defined up front so the attribute exists before the first forward.
        self.loss = None

    def forward(self, input):
        """Store the Gram-matrix MSE and return ``input`` unchanged."""
        G = gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input

数据预处理层

将图片数据的标准化(Normalization)以层的方式定义, 放在模型第1层。

1
2
3
4
5
6
7
8
9
10
11
# Per-channel mean and standard deviation used by the Normalization layer.
# NOTE(review): these look like the usual ImageNet statistics expected by
# torchvision's pretrained models -- confirm.
normalization_mean = torch.tensor([0.485, 0.456, 0.406], device=device)
normalization_std = torch.tensor([0.229, 0.224, 0.225], device=device)

class Normalization(nn.Module):
    """Layer that standardises an image with fixed per-channel statistics.

    Args:
        mean: 1-D tensor of per-channel means.
        std: 1-D tensor of per-channel standard deviations.
    """

    def __init__(self, mean, std) -> None:
        super().__init__()
        # Reshape to (C, 1, 1) so the statistics broadcast over the trailing
        # two dimensions of the image.  They are constants, so they are
        # stored as buffers without gradient tracking: ``.to(device)`` then
        # moves them with the module, and autograd never accumulates
        # gradients for them.
        self.register_buffer('mean', mean.clone().detach().view(-1, 1, 1))
        self.register_buffer('std', std.clone().detach().view(-1, 1, 1))

    def forward(self, img):
        """Return ``(img - mean) / std``."""
        return (img - self.mean) / self.std

构建模型

首先查看VGG的特征层。

1
2
3
# Keep only the ``features`` sub-module of the pretrained VGG-19.
# NOTE(review): newer torchvision releases deprecate ``pretrained=True`` in
# favour of the ``weights=`` argument -- confirm against the installed version.
vgg_features = models.vgg19(pretrained=True).features
# ``Module.to`` and ``Module.eval`` modify the module in place.
vgg_features.to(device)
vgg_features.eval()

vgg_features
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(17): ReLU(inplace=True)
(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(24): ReLU(inplace=True)
(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(26): ReLU(inplace=True)
(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): ReLU(inplace=True)
(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(33): ReLU(inplace=True)
(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): ReLU(inplace=True)
(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

根据论文的描述结合论文的Fig1, 内容层应该为第10个卷积层(conv4_2), 风格层应该为第1、3、5、9、13个卷积层(即conv1_1、conv2_1、conv3_1、conv4_1、conv5_1)。

接下来构建模型, 在对应的内容层和风格层后插入损失函数层。

最后一个损失函数层之后的部分可以截掉, 因为生成图片只根据这些损失函数层记录的损失来优化, 其后的网络层不会被用到。

需要注意的是, VGG模型中ReLU层的`inplace`属性默认为True, 可能需要设置为False, 以免原地操作修改损失函数层已经接收到的特征图。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Layer names follow the running convolution counter used in
# get_style_model_and_losses: ``conv_1`` is the first Conv2d of the network.
content_layers_default = ['conv_10']
style_layers_default = ['conv_1', 'conv_3', 'conv_5', 'conv_9', 'conv_13']


def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_image, content_image,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default):
    """Build the truncated network with content and style loss layers.

    A deep copy of ``cnn`` is walked child by child.  Each child is appended
    to a new ``nn.Sequential`` that starts with a ``Normalization`` layer.
    After every layer whose name is listed in ``content_layers`` or
    ``style_layers``, a ``ContentLoss`` or ``StyleLoss`` layer is inserted.
    Its target is the output of the model built so far for ``content_image``
    or ``style_image`` respectively.  Everything after the last loss layer is
    dropped.

    Args:
        cnn: Module whose direct children are Conv2d, ReLU, MaxPool2d or
            BatchNorm2d layers.  It is deep-copied, not modified.
        normalization_mean: Per-channel mean for the ``Normalization`` layer.
        normalization_std: Per-channel std for the ``Normalization`` layer.
        style_image: Image tensor used to compute the style targets.
        content_image: Image tensor used to compute the content targets.
        content_layers: Names (``conv_<n>`` etc.) after which a content loss
            is inserted.
        style_layers: Names after which a style loss is inserted.

    Returns:
        ``(model, style_losses, content_losses)``: the truncated
        ``nn.Sequential`` and the lists of inserted loss layers.

    Raises:
        RuntimeError: If ``cnn`` contains a child of an unsupported type.
    """
    cnn = copy.deepcopy(cnn)
    # Only the input image is optimised.  Freezing the copied weights stops
    # every backward pass from computing gradients that are never used.
    for param in cnn.parameters():
        param.requires_grad_(False)

    normalization = Normalization(normalization_mean, normalization_std).to(device)

    content_losses = []
    style_losses = []
    model = nn.Sequential(normalization)

    conv_index = 0  # number of Conv2d layers seen so far
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            conv_index += 1
            name = 'conv_{}'.format(conv_index)
        elif isinstance(layer, nn.ReLU):
            # Turn off in-place mode on the copied layer so it does not
            # overwrite the tensor that a preceding loss layer passed through.
            layer.inplace = False
            name = 'relu_{}'.format(conv_index)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(conv_index)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(conv_index)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in content_layers:
            # Targets are constants, so no autograd graph is needed for them.
            with torch.no_grad():
                target = model(content_image)
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(conv_index), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            with torch.no_grad():
                target_feature = model(style_image)
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(conv_index), style_loss)
            style_losses.append(style_loss)

    # Find the last loss layer and cut off everything after it.  If no loss
    # layer was inserted, only the first module is kept.
    last_loss_index = 0
    for index in range(len(model) - 1, -1, -1):
        if isinstance(model[index], (ContentLoss, StyleLoss)):
            last_loss_index = index
            break

    model = model[:last_loss_index + 1]

    return model, style_losses, content_losses

配置优化器

注意L-BFGS算法需要多次计算, 一次step可能会跑多轮迭代。 所以最大迭代次数会比预设的num_steps多一些, 并不是发生未知错误导致死循环。

实测设置num_steps=4时大概会跑80到100轮。

1
2
3
def get_input_optimizer(input_image, max_iter=20, max_eval=25):
    """Create the L-BFGS optimiser that updates the image itself.

    Args:
        input_image: Tensor to optimise; gradient tracking is switched on for
            it in place.
        max_iter: Forwarded to ``optim.LBFGS`` as ``max_iter``.
        max_eval: Forwarded to ``optim.LBFGS`` as ``max_eval``.

    Returns:
        An ``optim.LBFGS`` instance whose only parameter is ``input_image``.
    """
    optimizer = optim.LBFGS([input_image.requires_grad_()],
                            max_iter=max_iter, max_eval=max_eval)
    return optimizer

训练函数

按照论文设置$\alpha/\beta=10^{-3}$。 将style_weight($\beta$)和content_weight($\alpha$)分别设置为10000和10, 是为了让损失下降更快一些(仅为无根据的猜想, 没有进行实验确认是否有用)。

训练轮次建议先设小点测试一下。 `style_weight`和`content_weight`可根据需要自行调整。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def run_style_transfer(cnn, normalization_mean, normalization_std,
                       content_image, style_image, input_image,
                       num_steps=300, style_weight=10000, content_weight=10,
                       detect_anomaly=False):
    """Optimise ``input_image`` towards the content and style targets.

    Args:
        cnn: Feature network passed on to ``get_style_model_and_losses``.
        normalization_mean: Per-channel mean for the normalisation layer.
        normalization_std: Per-channel std for the normalisation layer.
        content_image: Image tensor providing the content targets.
        style_image: Image tensor providing the style targets.
        input_image: Image tensor that is optimised in place.
        num_steps: The loop keeps calling ``optimizer.step`` while the number
            of loss evaluations is ``<= num_steps``.  One ``step`` may
            evaluate the loss several times, so the final count can exceed
            this value.
        style_weight: Multiplier applied to the summed style losses.
        content_weight: Multiplier applied to the summed content losses.
        detect_anomaly: When true, run the optimisation inside
            ``torch.autograd.detect_anomaly()`` (debugging aid).

    Returns:
        ``input_image``, clamped in place to ``[0, 1]``.
    """
    import contextlib

    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, normalization_mean, normalization_std, style_image, content_image)
    optimizer = get_input_optimizer(input_image)

    print('Optimizing..')
    run = [0]  # one-element list so the closure can update the counter

    def closure():
        # Keep the image inside the valid value range before each evaluation.
        with torch.no_grad():
            input_image.clamp_(0, 1)

        optimizer.zero_grad()
        # The forward pass is run for its side effect: every loss layer
        # refreshes its ``loss`` attribute.
        model(input_image)

        style_score = 0
        content_score = 0
        for sl in style_losses:
            style_score = style_score + sl.loss
        for cl in content_losses:
            content_score = content_score + cl.loss

        style_score = style_score * style_weight
        content_score = content_score * content_weight

        loss = style_score + content_score
        loss.backward()

        run[0] += 1

        print("run {}".format(run[0]))
        print('Style Loss : {:4f} Content Loss {:4f}'.format(
            style_score.item(), content_score.item()
        ))
        print()

        return loss

    # The previous ``torch.autograd.set_detect_anomaly = True`` did not enable
    # anything: it replaced the torch API attribute with a bool.  Anomaly
    # detection is now opt-in and scoped to this optimisation loop.
    if detect_anomaly:
        anomaly_context = torch.autograd.detect_anomaly()
    else:
        anomaly_context = contextlib.nullcontext()

    with anomaly_context:
        while run[0] <= num_steps:
            optimizer.step(closure)

    with torch.no_grad():
        input_image.clamp_(0, 1)

    return input_image

进行训练

训练过程中风格损失是逐渐减少的; 由于输入图片初始化为内容图片的副本, 内容损失初始为零, 并逐渐增大。

1
2
3
4
5
6
# Start the optimisation from a copy of the content image, so the original
# content tensor stays untouched.
input_image = content_image.clone()

output = run_style_transfer(
    cnn=vgg_features,
    normalization_mean=normalization_mean,
    normalization_std=normalization_std,
    content_image=content_image,
    style_image=style_image,
    input_image=input_image,
)

imshow(output, title='Output Image')

图片保存

关于保存图片问题, 我是保存在编辑区中显示的图片。 通过修改imshow函数可以修改显示图片的尺寸。

imsave保存的图片完全不能看。 解决方法并没有找到。