使用VGG
处理图像, 查看中间层的特征图。
导入包
1 2 3 4 5 6 7 8 9 10
| import torch import torchvision import torchvision.transforms as transforms import torchvision.datasets as dset
import matplotlib.pyplot as plt import numpy as np
from PIL import Image import numpy as np
|
建立模型和相应函数
建立VGG模型
只使用VGG
的特征层, 不需要分类用的全连接层。
1 2
| vgg = torchvision.models.vgg19(pretrained=True) model = vgg.features
|
预处理图片
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225]
def image_loader(image_dir): img = Image.open(image_dir).convert('RGB')
image_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std) ])
img = image_transform(img) img = img.unsqueeze(0) return img
|
显示特征图
plt.ion()
进入交互模式, 使用imshow()
即可显示图片, 不需要show()
plt.figure()
创建画板
plt.axis('off')
关闭坐标轴
plt.imshow()
输入的图片需转换成numpy
类型, 且通道数不超过3。 需要灰度图可以加入参数 cmap='gray'
1 2 3 4 5 6 7 8 9 10 11 12
| plt.ion() def show_feature_map(img): img = img.squeeze(0) img = img.cpu().numpy() img_num = img.shape[0] row_num = np.ceil(np.sqrt(img_num)) plt.figure()
for index in range(1, img_num+1): plt.subplot(int(row_num), int(row_num), index) plt.imshow(img[index-1]) plt.axis('off')
|
提取特征图
想要提取对应层的特征图, 可以用全面层去处理图像。
1 2 3 4 5 6
| def get_feature_map(model, k, x): with torch.no_grad(): for index, layer in enumerate(model): x = layer(x) if(index == k): return x
|
显示特征图
1 2 3 4 5
| image_dir = './8.jpg' img = image_loader(image_dir)
feature_map = get_feature_map(model, 1, img) show_feature_map(feature_map)
|