提取中间层特征图

使用VGG处理图像, 查看中间层的特征图。

导入包

1
2
3
4
5
6
7
8
9
10
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dset

import matplotlib.pyplot as plt
import numpy as np

from PIL import Image
import numpy as np

建立模型和相应函数

建立VGG模型

只使用VGG的特征层, 不需要分类用的全连接层。

1
2
vgg = torchvision.models.vgg19(pretrained=True)
model = vgg.features

预处理图片

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def image_loader(image_dir):
img = Image.open(image_dir).convert('RGB')

image_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])

img = image_transform(img)
img = img.unsqueeze(0)
return img

显示特征图

  • plt.ion() 进入交互模式, 使用imshow()即可显示图片, 不需要show()
  • plt.figure() 创建画板
  • plt.axis('off') 关闭坐标轴
  • plt.imshow() 输入的图片需转换成numpy类型, 且通道数不超过3。 需要灰度图可以加入参数 cmap='gray'
1
2
3
4
5
6
7
8
9
10
11
12
plt.ion()
def show_feature_map(img):
img = img.squeeze(0)
img = img.cpu().numpy()
img_num = img.shape[0]
row_num = np.ceil(np.sqrt(img_num))
plt.figure()

for index in range(1, img_num+1):
plt.subplot(int(row_num), int(row_num), index)
plt.imshow(img[index-1])
plt.axis('off')

提取特征图

想要提取对应层的特征图, 可以用全面层去处理图像。

1
2
3
4
5
6
def get_feature_map(model, k, x):
with torch.no_grad():
for index, layer in enumerate(model):
x = layer(x)
if(index == k):
return x

显示特征图

1
2
3
4
5
image_dir = './8.jpg'
img = image_loader(image_dir)

feature_map = get_feature_map(model, 1, img)
show_feature_map(feature_map)