垃圾分类器

使用 PyTorch 1.0 版本。

垃圾分类数据集下载

整理数据集

数据集的文件结构为:

- RubbishClassification
    - 0
    - 1
    - ...
    - 15
    - train.json
    - val.json
    - ...

需要将其整理为:

- RubbishDatasets
  - train
    - 0
    - 1
    - ...
    - 15
  - valid
    - 0
    - 1
    - ...
    - 15

训练集和验证集已经以 json 文件(train.json 和 val.json)的形式划分好。

创建路径

1
2
3
4
5
6
7
8
9
10
11
12
13
import json
import os
from glob import glob

# Root of the reorganised dataset; images end up in <path>/<split>/<class id>/.
path = './RubbishDatasets'

# Create one folder per class id (0-15) under each split.
# os.makedirs also creates the intermediate split folder, and exist_ok=True
# keeps this cell safe to re-run (a plain os.mkdir raises FileExistsError the
# second time it is executed).
for t in ['train', 'valid']:
    for i in range(16):
        os.makedirs(os.path.join(path, t, str(i)), exist_ok=True)

划分数据集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Move every image listed in the two JSON index files into the matching
# <path>/<split>/<class folder>/ directory. Each JSON entry is read through
# its 'path' key, whose last two components are the class folder and the
# image file name.
for json_file, split in [('./RubbishClassification/train.json', 'train'),
                         ('./RubbishClassification/val.json', 'valid')]:
    with open(json_file, 'r') as f:
        entries = json.load(f)

    for item in entries:
        # The index may use Windows separators. Normalise once and use the
        # same normalised path for the existence check AND the move, so both
        # refer to the same file on every platform.
        file = item['path'].replace('\\', '/')
        folder, image = file.split('/')[-2:]
        # Entries whose source file is absent (e.g. already moved by an
        # earlier run) are skipped on purpose.
        if os.path.exists(file):
            os.rename(file, os.path.join(path, split, folder, image))

加载数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dset

# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
# NOTE(review): torchvision documents ImageNet statistics
# (mean 0.485/0.456/0.406, std 0.229/0.224/0.225) for its pretrained models;
# the 0.5 values are kept here so existing results stay comparable - confirm
# whether switching is wanted.
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ImageFolder derives the label from the sub-folder name.
# NOTE(review): class indices follow the sorted folder names, so with folders
# '0'..'15' the index is presumably not equal to the folder number
# ('10' sorts before '2'); use train.class_to_idx to map predictions back.
train = dset.ImageFolder('./RubbishDatasets/train/', transform)
valid = dset.ImageFolder('./RubbishDatasets/valid/', transform)

trainloader = torch.utils.data.DataLoader(train, batch_size=8, shuffle=True, num_workers=3)
# Validation metrics do not depend on sample order, so no shuffling is needed.
validloader = torch.utils.data.DataLoader(valid, batch_size=8, shuffle=False, num_workers=3)

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

定义网络模型

直接使用 ResNet18 模型。

1
2
3
4
5
6
import torch.nn as nn
# Start from a pretrained ResNet-18 and replace its final fully connected
# layer with a new 16-output classifier that keeps the same input width.
net = torchvision.models.resnet18(pretrained=True)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(in_features=num_ftrs, out_features=16)

# Move the whole network to the selected device.
net = net.to(device)

定义损失函数和优化器

1
2
3
4
import torch.optim as optim

# Cross-entropy loss over the class scores produced by the network.
criterion = nn.CrossEntropyLoss()

# SGD with momentum over all parameters of the network (no layer is frozen).
optimizer = optim.SGD(params=net.parameters(), momentum=0.9, lr=0.001)

训练模型

每一轮训练后进行验证, 保存最优模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from imghdr import tests
from msilib import datasizemask


import copy

# Per-phase data loaders and sample counts used by the loop below.
dataloaders = {
    'train': trainloader,
    'valid': validloader,
}
dataset_sizes = {
    'train': len(train.imgs),
    'valid': len(valid.imgs),
}

# Snapshot of the best weights seen so far. state_dict() hands back references
# to the live parameter tensors, so it has to be deep-copied; otherwise the
# "best" snapshot keeps changing as training continues.
best_model_wts = copy.deepcopy(net.state_dict())
best_acc = 0.0

for epoch in range(10):
    print('Epoch {}/{}'.format(epoch, 9))
    print('-'*10)

    # Every epoch runs a training pass followed by a validation pass.
    for phase in ['train', 'valid']:
        # Training mode for the training pass, evaluation mode for validation.
        net.train(phase == 'train')

        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in dataloaders[phase]:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Only build the autograd graph while training; validation needs
            # no gradients.
            with torch.set_grad_enabled(phase == 'train'):
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # criterion returns the mean loss of the batch; weight it by the
            # batch size so that dividing by the dataset size below yields a
            # true per-sample average.
            running_loss += loss.item() * inputs.size(0)
            # .item() turns the count into a Python int, so the accuracy
            # below is an ordinary float division.
            running_corrects += torch.sum(preds == labels).item()

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects / dataset_sizes[phase]

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        # Keep a copy of the weights with the best validation accuracy.
        if phase == 'valid' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(net.state_dict())

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

Epoch 0/9
----------
train Loss: 0.1502 Acc: 0.6335
valid Loss: 0.0661 Acc: 0.8271
Epoch 1/9
----------
train Loss: 0.0703 Acc: 0.8264
valid Loss: 0.0667 Acc: 0.8380
Epoch 2/9
----------
train Loss: 0.0434 Acc: 0.8927
valid Loss: 0.0639 Acc: 0.8489
Epoch 3/9
----------
train Loss: 0.0269 Acc: 0.9390
valid Loss: 0.0627 Acc: 0.8538
Epoch 4/9
----------
train Loss: 0.0203 Acc: 0.9537
valid Loss: 0.0581 Acc: 0.8680
Epoch 5/9
----------
train Loss: 0.0138 Acc: 0.9748
valid Loss: 0.0680 Acc: 0.8501
Epoch 6/9
...
Epoch 9/9
----------
train Loss: 0.0073 Acc: 0.9873
valid Loss: 0.0656 Acc: 0.8653

保存和加载最优模型

1
# Save to the same relative path that the loading code reads from.
# The previous literal 'D:\JupyterNotes\...' was a non-raw string containing
# unrecognised escape sequences (\J, \R) and did not match the load path.
torch.save(best_model_wts, './RubbishClassifyModel.pth')
1
2
3
4
5
# Rebuild the architecture used for training: ResNet-18 with a 16-output
# final layer, without downloading pretrained weights.
model = torchvision.models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 16)
# map_location='cpu': the checkpoint may contain CUDA tensors if training ran
# on a GPU; mapping them to the CPU lets it load on machines without CUDA and
# matches this model, which has not been moved to another device.
model.load_state_dict(torch.load('./RubbishClassifyModel.pth', map_location='cpu'))
# Switch to evaluation mode for inference.
model.eval()