使用 PyTorch 1.0 版本
垃圾分类数据集下载
整理数据集
数据集的文件结构为：
- RubbishClassification
  - 0
  - 1
  - ...
  - 15
  - train.json
  - val.json
  - ...
需要将其整理为 `torchvision.datasets.ImageFolder` 所要求的“每个类别一个子文件夹”的结构：
- RubbishDatasets
  - train
    - 0
    - 1
    - ...
    - 15
  - valid
    - 0
    - 1
    - ...
    - 15
训练集和验证集已经分别通过 train.json 和 val.json 划分好，下面按照这两个 JSON 文件移动图片。
创建路径
1 2 3 4 5 6 7 8 9 10 11 12 13
| import json import os from glob import glob
# Root of the reorganised dataset: one folder per split ('train', 'valid'),
# each holding one folder per class label 0..15, which is the layout that
# torchvision's ImageFolder reads later on.
path = './RubbishDatasets'
for t in ['train', 'valid']:
    for i in range(16):
        # makedirs also creates the missing split folder, and exist_ok=True
        # keeps the cell re-runnable (plain os.mkdir raises FileExistsError
        # when the folder is already there).
        os.makedirs(os.path.join(path, t, str(i)), exist_ok=True)
|
划分数据集
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
def _move_split(json_file, split):
    """Move every image listed in *json_file* into ``path/<split>/<label>/``.

    Each JSON entry must carry a ``'path'`` key; the second-to-last path
    component is taken as the class folder and the last one as the image
    file name. Entries whose source file does not exist (for example because
    an earlier run already moved it) are skipped.
    """
    with open(json_file, 'r') as f:
        items = json.load(f)
    for item in items:
        # The stored paths may use Windows '\' separators; normalise them so
        # both the split below and the file-system calls work on any OS.
        src = item['path'].replace('\\', '/')
        parts = src.split('/')
        folder, image = parts[-2], parts[-1]
        if os.path.exists(src):
            # Move exactly the (normalised) path that was just checked; the
            # raw backslash path does not name the file on POSIX systems.
            os.rename(src, os.path.join(path, split, folder, image))


_move_split('./RubbishClassification/train.json', 'train')
_move_split('./RubbishClassification/val.json', 'valid')
|
加载数据
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| import torch import torchvision import torchvision.transforms as transforms import torchvision.datasets as dset
# Resize every image to 224x224, convert it to a tensor in [0, 1] and map each
# channel to [-1, 1] with mean 0.5 / std 0.5.
# NOTE(review): the ImageNet-pretrained torchvision backbone was trained with
# mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225); consider switching to
# those values (and using the same ones at inference time).
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ImageFolder assigns class indices by sorting the sub-folder names as strings
# ('0', '1', '10', '11', ..., '2', ...), so index k corresponds to the folder
# train.classes[k], not necessarily to folder name str(k).
train = dset.ImageFolder('./RubbishDatasets/train/', transform)
valid = dset.ImageFolder('./RubbishDatasets/valid/', transform)

trainloader = torch.utils.data.DataLoader(train, batch_size=8, shuffle=True, num_workers=3)
# Sample order does not influence validation metrics, so do not shuffle.
validloader = torch.utils.data.DataLoader(valid, batch_size=8, shuffle=False, num_workers=3)

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
定义网络模型
直接使用在 ImageNet 上预训练的 ResNet18 模型，并把最后的全连接层替换为 16 类输出。
1 2 3 4 5 6
import torch.nn as nn

# Transfer learning: keep the ImageNet-pretrained ResNet-18 backbone and swap
# only its final fully connected layer for a fresh 16-class head.
net = torchvision.models.resnet18(pretrained=True)
num_ftrs = net.fc.in_features  # input width of the original classifier
net.fc = nn.Linear(in_features=num_ftrs, out_features=16)
net = net.to(device)  # same device the input batches are moved to
|
定义损失函数和优化器
1 2 3 4
| import torch.optim as optim
# Loss: multi-class cross-entropy computed on the network's raw logits.
criterion = nn.CrossEntropyLoss()
# Optimiser: SGD with momentum over every network parameter, so the pretrained
# backbone is fine-tuned together with the new classification head.
optimizer = optim.SGD(params=net.parameters(), lr=1e-3, momentum=0.9)
|
训练模型
每一轮训练结束后在验证集上评估，并保存验证准确率最高的模型参数。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
import copy

# One loader and one sample count per phase, so a single loop body can handle
# both training and validation.
dataloaders = {
    'train': trainloader,
    'valid': validloader
}
dataset_sizes = {
    'train': len(train.imgs),
    'valid': len(valid.imgs)
}

num_epochs = 10

# state_dict() returns references to the live parameter tensors, which keep
# changing while training continues; deep-copy it so the snapshot really keeps
# the best weights instead of silently tracking the latest ones.
best_model_wts = copy.deepcopy(net.state_dict())
best_acc = 0.0

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    for phase in ['train', 'valid']:
        # Switch layers such as batch norm between training and inference
        # behaviour.
        if phase == 'train':
            net.train()
        else:
            net.eval()
        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in dataloaders[phase]:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Only build the autograd graph when we will back-propagate, so
            # validation uses far less memory.
            with torch.set_grad_enabled(phase == 'train'):
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # CrossEntropyLoss returns the batch mean by default; weight it by
            # the batch size so dividing by the dataset size below yields the
            # per-sample average loss of the epoch.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()

        epoch_loss = running_loss / dataset_sizes[phase]
        # Both operands are Python ints, so this is true division (integer
        # tensor division truncates to 0 on old PyTorch versions).
        epoch_acc = running_corrects / dataset_sizes[phase]

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        if phase == 'valid' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(net.state_dict())
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
| Epoch 0/9 ---------- train Loss: 0.1502 Acc: 0.6335 valid Loss: 0.0661 Acc: 0.8271 Epoch 1/9 ---------- train Loss: 0.0703 Acc: 0.8264 valid Loss: 0.0667 Acc: 0.8380 Epoch 2/9 ---------- train Loss: 0.0434 Acc: 0.8927 valid Loss: 0.0639 Acc: 0.8489 Epoch 3/9 ---------- train Loss: 0.0269 Acc: 0.9390 valid Loss: 0.0627 Acc: 0.8538 Epoch 4/9 ---------- train Loss: 0.0203 Acc: 0.9537 valid Loss: 0.0581 Acc: 0.8680 Epoch 5/9 ---------- train Loss: 0.0138 Acc: 0.9748 valid Loss: 0.0680 Acc: 0.8501 Epoch 6/9 ... Epoch 9/9 ---------- train Loss: 0.0073 Acc: 0.9873 valid Loss: 0.0656 Acc: 0.8653
|
保存和加载最优模型
1
# Persist the best validation weights under the same relative path the model
# is loaded back from below. A relative path is portable, and it avoids
# unescaped backslashes in a non-raw Windows path literal.
torch.save(best_model_wts, './RubbishClassifyModel.pth')
|
1 2 3 4 5
# Rebuild the training architecture (ResNet-18 with a 16-way head) without
# pretrained ImageNet weights, then load the fine-tuned parameters into it.
model = torchvision.models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 16)
# map_location='cpu' lets weights saved from a CUDA run load on a machine
# without a GPU; the freshly built model lives on the CPU anyway.
model.load_state_dict(torch.load('./RubbishClassifyModel.pth', map_location='cpu'))
model.eval()  # put batch-norm and similar layers into inference mode
|