物体分類

keras や pytorch などのパッケージを使用すると、ニューラルネットワークや畳み込み演算などを容易に計算できるようになる。そのため、画像分類を行うための予測器を構築することが容易になってきた。このページではユーザー数が多く、情報量の多い pytorch を使用して物体分類用のモデルの構築方法を示していく。

ニューラルネットワークによる物体分類

この項目では、pytorch を使用して非常に単純なニューラルネットワークを構築して、画像分類を行う方法を示す。この項目で示した例は、pytorch の training a classifier ページを参照して作成した。まず、この項目で使うモジュールなどをインポートする。

In [ ]:
import torch
import torchvision
import torchvision.transforms as transforms

# パソコンに OpenMP ランタイムが複数ある場合、異常終了するので、回避策として次の環境変数を設定する
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

ここでデータセットを用意します。とりあえず、pytorch が用意されている関数を使って、サンプルデータセットをダウンロードしてくる。このデータセットは、学習用とテスト用の両方に別れているので、ここで両方をダウンロードしてくる。データのダウンロードは torchvision.datasets.CIFAR10 関数で行う。ダウンロードした後に、torch.utils.data.DataLoader 関数で、データを pytroch にデータを認識させて、学習に使える状態にする。

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 訓練データ
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=1)

# テストデータ
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=1)
Files already downloaded and verified
Files already downloaded and verified

データの用意が終わったので、簡単な浅いニューラルネットワークを作成する。ここで作成するネットワークは、畳み込み層、プーリング層、畳み込み層、プーリング層、全結合層、全結合層のような構成にする。

In [ ]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        
        # 畳み込み演算用のカーネルのサイズを定義
        # 引数は(入力チャンネル数, 出力チャンネル数, カーネルサイズ)
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        
        # プーリング演算用のカーネルのサイズの定義
        # 引数は(カーネルサイズ、移動ステップ数)
        self.pool = nn.MaxPool2d(2, 2)
                
        # 全結合層の構造を定義
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

        
    def forward(self, x):
        
        # 画像(行列型)のデータに対して畳み込み演算とプーリング演算を 2 回繰り返す
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        
        # 行列型のデータをベクトルに整形
        x = x.view(-1, 16 * 5 * 5)
        
        # ベクトルを全結合層に代入
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        
        # 出力層
        x = self.fc3(x)
        
        return x


net = Net()

損失関数を定義し、最適化を行うためのアルゴリズムを定義する。多クラス分類問題の場合は、損失関数として交差エントロピーを使用するのが一般的であるから、ここで CrossEntropyLoss を使用する。また、最適化アルゴリズムは、とりあえずもっとも基礎的なアルゴリズムである確率的勾配降下法 SGD を使用する。

In [ ]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()

# 上で定義したニューラルネットワークのパラメーターをすべて微分可能な形で optimizer に保存
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

以上で、データの準備、ニューラルネットワークの構造、損失関数の定義、最適化アルゴリズムの準備および設定を終えた。これで学習を行うための作業をすべて終わり、早速、for 文を使って、構築したニューラルネットワークを学習させてみます。

In [5]:
n_epochs = 2

for epoch in range(n_epochs):
    running_loss = 0.0
    
    # 訓練データローダーからデータを 1 batch ずつ取り出す
    for i, data in enumerate(trainloader):
        
        inputs, labels = data
        
        # 前回の学習時に伝播してきた誤差をゼロにする
        optimizer.zero_grad()

        # ニューラルネットワークに学習データを代入して結果を得る
        outputs = net(inputs)
        
        # その出力と教師ラベルを比べて、両者の損失を計算する
        loss = criterion(outputs, labels)
        
        # その損失を誤差逆伝播法で偏微分可能なパラメーター全体に伝播させ、
        loss.backward()
        
        # 全パラメータを更新する(勾配に学習率をかけて更新)
        optimizer.step()
        
        # 以降、学習の進捗状況を知りたいので、1万 batch ずつ、途中進捗を出力
        running_loss += loss.item()
        if i % 2000 == 1999: 
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

            
print('Finished Training')
[1,  2000] loss: 2.207
[1,  4000] loss: 1.844
[1,  6000] loss: 1.638
[1,  8000] loss: 1.550
[1, 10000] loss: 1.517
[1, 12000] loss: 1.467
[2,  2000] loss: 1.399
[2,  4000] loss: 1.340
[2,  6000] loss: 1.320
[2,  8000] loss: 1.289
[2, 10000] loss: 1.284
[2, 12000] loss: 1.265
Finished Training

画像 2 回繰り返して学習を行なった。では、テストデータを使って検証結果をみていくことにする。pytorch では、net(images) を実行するとパラメーターが一時的にメモリに保存され、誤差逆伝播法によるパラメーター更新を高速にしている。しかし、テスト時は、パラメーター更新を行わないので、無駄にメモリを使わないために、torch.no_grad の制約下で、パラメーターを保持しないように指定して実行する。

In [6]:
n_correct = 0
n_total = 0

# 
with torch.no_grad():
    
    # テストデータを 1 batch ずつロードする
    for data in testloader:
        
        images, labels = data
        
        # 画像をネットワークに代入して、予測値を得る
        outputs = net(images)
        
        # 出力値が確率なので、最大確率のラベルを取得
        _, predicted = torch.max(outputs.data, 1)
        
        # これまでにテストした画像の合計枚数
        n_total += labels.size(0)
        
        # これまでのテストで正解数
        n_correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * n_correct / n_total))
Accuracy of the network on the 10000 test images: 56 %

次に各クラスごとの正解率を個別に調べてみる。

In [7]:
# 全クラスの種類
classes= ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


class_correct = list(0. for i in range(len(classes)))
class_total = list(0. for i in range(len(classes)))


with torch.no_grad():
    
    for data in testloader:
        images, labels = data
        
        outputs = net(images)
        
        _, predicted = torch.max(outputs, 1)
        
        c = (predicted == labels).squeeze()
        
        # テストデータをロードする時バッチサイズを 4 にしてあったので、1 枚ずつチェックしていく
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(len(classes)):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))
Accuracy of plane : 63 %
Accuracy of   car : 64 %
Accuracy of  bird : 41 %
Accuracy of   cat : 47 %
Accuracy of  deer : 45 %
Accuracy of   dog : 34 %
Accuracy of  frog : 66 %
Accuracy of horse : 70 %
Accuracy of  ship : 54 %
Accuracy of truck : 75 %

学習済みのネットワークを保存するとき、torch.save 関数を使用する。この際に、保存先のパスを指定する。この関数で保存されるのはネットワーク中のパラメーターなどであり、ネットワーク構造が保存されない。

In [ ]:
torch.save(net.state_dict(), './cifar_net.pth')

そのため、次に、このネットワークをファイルから読み込んで使用する時に、まずネットワーク構造のインスタンスを一度生成してから、そのインスタンスにパラメーターをセットアップする形で読み込む。

In [9]:
net = Net()
net.load_state_dict(torch.load('./cifar_net.pth'))
Out[9]:
<All keys matched successfully>

GPU を使用して学習と検証を行う場合は、次のようにしてニューラルネットワークとデータを GPU 上に転送する必要がある。

In [10]:
# CUDA (GPU向けの汎用並列コンピューティングプラットフォーム) 見つかれば CUDA を使用
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# データセット
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=1)

# テストデータ
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=1)


# 学習
n_epochs = 2

net = Net()
net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(n_epochs):
    for i, data in enumerate(trainloader):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
            
print('Finished Training')



# テスト
n_correct = 0
n_total = 0

with torch.no_grad():
    for data in testloader:
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        n_total += labels.size(0)
        n_correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * n_correct / n_total))
Files already downloaded and verified
Files already downloaded and verified
Finished Training
Accuracy of the network on the 10000 test images: 54 %

転移学習

ニューラルネットワークを構築して、すべての重みを初期化して、空の状態で学習を始めると、学習の進み具合が遅い。これに対して、他のデータセットである程度学習を済ませたモデルを持ってきて使用すると、学習が早く進む。例えを言うならば、何もわかっていない状態でフランス語を勉強するのと、英語を習得した上でフランス語を勉強するのに似ている。

この項目は、pytorch ウェブサイトの TRANSFER LEARNING FOR COMPUTER VISION TUTORIAL ページを参照して作成した。

In [ ]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

ここではハチとアリのサンプルデータセットを使用して転移学習を行う例を示す。サンプルデータセットは次の URL でダウンロードできる。ダウンロードして、展開後、Jupyter Notebook のファイルと同じ場所に置く。

In [12]:
!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
!unzip hymenoptera_data.zip
--2020-01-17 00:55:21--  https://download.pytorch.org/tutorial/hymenoptera_data.zip
Resolving download.pytorch.org (download.pytorch.org)... 13.226.42.89, 13.226.42.62, 13.226.42.64, ...
Connecting to download.pytorch.org (download.pytorch.org)|13.226.42.89|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 47286322 (45M) [application/zip]
Saving to: ‘hymenoptera_data.zip.3’

hymenoptera_data.zi 100%[===================>]  45.10M  64.4MB/s    in 0.7s    

2020-01-17 00:55:22 (64.4 MB/s) - ‘hymenoptera_data.zip.3’ saved [47286322/47286322]

Archive:  hymenoptera_data.zip
replace hymenoptera_data/train/ants/0013035.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
  inflating: hymenoptera_data/train/ants/0013035.jpg  
  inflating: hymenoptera_data/train/ants/1030023514_aad5c608f9.jpg  
  inflating: hymenoptera_data/train/ants/1095476100_3906d8afde.jpg  
  inflating: hymenoptera_data/train/ants/1099452230_d1949d3250.jpg  
  inflating: hymenoptera_data/train/ants/116570827_e9c126745d.jpg  
  inflating: hymenoptera_data/train/ants/1225872729_6f0856588f.jpg  
  inflating: hymenoptera_data/train/ants/1262877379_64fcada201.jpg  
  inflating: hymenoptera_data/train/ants/1269756697_0bce92cdab.jpg  
  inflating: hymenoptera_data/train/ants/1286984635_5119e80de1.jpg  
  inflating: hymenoptera_data/train/ants/132478121_2a430adea2.jpg  
  inflating: hymenoptera_data/train/ants/1360291657_dc248c5eea.jpg  
  inflating: hymenoptera_data/train/ants/1368913450_e146e2fb6d.jpg  
  inflating: hymenoptera_data/train/ants/1473187633_63ccaacea6.jpg  
  inflating: hymenoptera_data/train/ants/148715752_302c84f5a4.jpg  
  inflating: hymenoptera_data/train/ants/1489674356_09d48dde0a.jpg  
  inflating: hymenoptera_data/train/ants/149244013_c529578289.jpg  
  inflating: hymenoptera_data/train/ants/150801003_3390b73135.jpg  
  inflating: hymenoptera_data/train/ants/150801171_cd86f17ed8.jpg  
  inflating: hymenoptera_data/train/ants/154124431_65460430f2.jpg  
  inflating: hymenoptera_data/train/ants/162603798_40b51f1654.jpg  
  inflating: hymenoptera_data/train/ants/1660097129_384bf54490.jpg  
  inflating: hymenoptera_data/train/ants/167890289_dd5ba923f3.jpg  
  inflating: hymenoptera_data/train/ants/1693954099_46d4c20605.jpg  
  inflating: hymenoptera_data/train/ants/175998972.jpg  
  inflating: hymenoptera_data/train/ants/178538489_bec7649292.jpg  
  inflating: hymenoptera_data/train/ants/1804095607_0341701e1c.jpg  
  inflating: hymenoptera_data/train/ants/1808777855_2a895621d7.jpg  
  inflating: hymenoptera_data/train/ants/188552436_605cc9b36b.jpg  
  inflating: hymenoptera_data/train/ants/1917341202_d00a7f9af5.jpg  
  inflating: hymenoptera_data/train/ants/1924473702_daa9aacdbe.jpg  
  inflating: hymenoptera_data/train/ants/196057951_63bf063b92.jpg  
  inflating: hymenoptera_data/train/ants/196757565_326437f5fe.jpg  
  inflating: hymenoptera_data/train/ants/201558278_fe4caecc76.jpg  
  inflating: hymenoptera_data/train/ants/201790779_527f4c0168.jpg  
  inflating: hymenoptera_data/train/ants/2019439677_2db655d361.jpg  
  inflating: hymenoptera_data/train/ants/207947948_3ab29d7207.jpg  
  inflating: hymenoptera_data/train/ants/20935278_9190345f6b.jpg  
  inflating: hymenoptera_data/train/ants/224655713_3956f7d39a.jpg  
  inflating: hymenoptera_data/train/ants/2265824718_2c96f485da.jpg  
  inflating: hymenoptera_data/train/ants/2265825502_fff99cfd2d.jpg  
  inflating: hymenoptera_data/train/ants/226951206_d6bf946504.jpg  
  inflating: hymenoptera_data/train/ants/2278278459_6b99605e50.jpg  
  inflating: hymenoptera_data/train/ants/2288450226_a6e96e8fdf.jpg  
  inflating: hymenoptera_data/train/ants/2288481644_83ff7e4572.jpg  
  inflating: hymenoptera_data/train/ants/2292213964_ca51ce4bef.jpg  
  inflating: hymenoptera_data/train/ants/24335309_c5ea483bb8.jpg  
  inflating: hymenoptera_data/train/ants/245647475_9523dfd13e.jpg  
  inflating: hymenoptera_data/train/ants/255434217_1b2b3fe0a4.jpg  
  inflating: hymenoptera_data/train/ants/258217966_d9d90d18d3.jpg  
  inflating: hymenoptera_data/train/ants/275429470_b2d7d9290b.jpg  
  inflating: hymenoptera_data/train/ants/28847243_e79fe052cd.jpg  
  inflating: hymenoptera_data/train/ants/318052216_84dff3f98a.jpg  
  inflating: hymenoptera_data/train/ants/334167043_cbd1adaeb9.jpg  
  inflating: hymenoptera_data/train/ants/339670531_94b75ae47a.jpg  
  inflating: hymenoptera_data/train/ants/342438950_a3da61deab.jpg  
  inflating: hymenoptera_data/train/ants/36439863_0bec9f554f.jpg  
  inflating: hymenoptera_data/train/ants/374435068_7eee412ec4.jpg  
  inflating: hymenoptera_data/train/ants/382971067_0bfd33afe0.jpg  
  inflating: hymenoptera_data/train/ants/384191229_5779cf591b.jpg  
  inflating: hymenoptera_data/train/ants/386190770_672743c9a7.jpg  
  inflating: hymenoptera_data/train/ants/392382602_1b7bed32fa.jpg  
  inflating: hymenoptera_data/train/ants/403746349_71384f5b58.jpg  
  inflating: hymenoptera_data/train/ants/408393566_b5b694119b.jpg  
  inflating: hymenoptera_data/train/ants/424119020_6d57481dab.jpg  
  inflating: hymenoptera_data/train/ants/424873399_47658a91fb.jpg  
  inflating: hymenoptera_data/train/ants/450057712_771b3bfc91.jpg  
  inflating: hymenoptera_data/train/ants/45472593_bfd624f8dc.jpg  
  inflating: hymenoptera_data/train/ants/459694881_ac657d3187.jpg  
  inflating: hymenoptera_data/train/ants/460372577_f2f6a8c9fc.jpg  
  inflating: hymenoptera_data/train/ants/460874319_0a45ab4d05.jpg  
  inflating: hymenoptera_data/train/ants/466430434_4000737de9.jpg  
  inflating: hymenoptera_data/train/ants/470127037_513711fd21.jpg  
  inflating: hymenoptera_data/train/ants/474806473_ca6caab245.jpg  
  inflating: hymenoptera_data/train/ants/475961153_b8c13fd405.jpg  
  inflating: hymenoptera_data/train/ants/484293231_e53cfc0c89.jpg  
  inflating: hymenoptera_data/train/ants/49375974_e28ba6f17e.jpg  
  inflating: hymenoptera_data/train/ants/506249802_207cd979b4.jpg  
  inflating: hymenoptera_data/train/ants/506249836_717b73f540.jpg  
  inflating: hymenoptera_data/train/ants/512164029_c0a66b8498.jpg  
  inflating: hymenoptera_data/train/ants/512863248_43c8ce579b.jpg  
  inflating: hymenoptera_data/train/ants/518773929_734dbc5ff4.jpg  
  inflating: hymenoptera_data/train/ants/522163566_fec115ca66.jpg  
  inflating: hymenoptera_data/train/ants/522415432_2218f34bf8.jpg  
  inflating: hymenoptera_data/train/ants/531979952_bde12b3bc0.jpg  
  inflating: hymenoptera_data/train/ants/533848102_70a85ad6dd.jpg  
  inflating: hymenoptera_data/train/ants/535522953_308353a07c.jpg  
  inflating: hymenoptera_data/train/ants/540889389_48bb588b21.jpg  
  inflating: hymenoptera_data/train/ants/541630764_dbd285d63c.jpg  
  inflating: hymenoptera_data/train/ants/543417860_b14237f569.jpg  
  inflating: hymenoptera_data/train/ants/560966032_988f4d7bc4.jpg  
  inflating: hymenoptera_data/train/ants/5650366_e22b7e1065.jpg  
  inflating: hymenoptera_data/train/ants/6240329_72c01e663e.jpg  
  inflating: hymenoptera_data/train/ants/6240338_93729615ec.jpg  
  inflating: hymenoptera_data/train/ants/649026570_e58656104b.jpg  
  inflating: hymenoptera_data/train/ants/662541407_ff8db781e7.jpg  
  inflating: hymenoptera_data/train/ants/67270775_e9fdf77e9d.jpg  
  inflating: hymenoptera_data/train/ants/6743948_2b8c096dda.jpg  
  inflating: hymenoptera_data/train/ants/684133190_35b62c0c1d.jpg  
  inflating: hymenoptera_data/train/ants/69639610_95e0de17aa.jpg  
  inflating: hymenoptera_data/train/ants/707895295_009cf23188.jpg  
  inflating: hymenoptera_data/train/ants/7759525_1363d24e88.jpg  
  inflating: hymenoptera_data/train/ants/795000156_a9900a4a71.jpg  
  inflating: hymenoptera_data/train/ants/822537660_caf4ba5514.jpg  
  inflating: hymenoptera_data/train/ants/82852639_52b7f7f5e3.jpg  
  inflating: hymenoptera_data/train/ants/841049277_b28e58ad05.jpg  
  inflating: hymenoptera_data/train/ants/886401651_f878e888cd.jpg  
  inflating: hymenoptera_data/train/ants/892108839_f1aad4ca46.jpg  
  inflating: hymenoptera_data/train/ants/938946700_ca1c669085.jpg  
  inflating: hymenoptera_data/train/ants/957233405_25c1d1187b.jpg  
  inflating: hymenoptera_data/train/ants/9715481_b3cb4114ff.jpg  
  inflating: hymenoptera_data/train/ants/998118368_6ac1d91f81.jpg  
  inflating: hymenoptera_data/train/ants/ant photos.jpg  
  inflating: hymenoptera_data/train/ants/Ant_1.jpg  
  inflating: hymenoptera_data/train/ants/army-ants-red-picture.jpg  
  inflating: hymenoptera_data/train/ants/formica.jpeg  
  inflating: hymenoptera_data/train/ants/hormiga_co_por.jpg  
  inflating: hymenoptera_data/train/ants/imageNotFound.gif  
  inflating: hymenoptera_data/train/ants/kurokusa.jpg  
  inflating: hymenoptera_data/train/ants/MehdiabadiAnt2_600.jpg  
  inflating: hymenoptera_data/train/ants/Nepenthes_rafflesiana_ant.jpg  
  inflating: hymenoptera_data/train/ants/swiss-army-ant.jpg  
  inflating: hymenoptera_data/train/ants/termite-vs-ant.jpg  
  inflating: hymenoptera_data/train/ants/trap-jaw-ant-insect-bg.jpg  
  inflating: hymenoptera_data/train/ants/VietnameseAntMimicSpider.jpg  
  inflating: hymenoptera_data/train/bees/1092977343_cb42b38d62.jpg  
  inflating: hymenoptera_data/train/bees/1093831624_fb5fbe2308.jpg  
  inflating: hymenoptera_data/train/bees/1097045929_1753d1c765.jpg  
  inflating: hymenoptera_data/train/bees/1232245714_f862fbe385.jpg  
  inflating: hymenoptera_data/train/bees/129236073_0985e91c7d.jpg  
  inflating: hymenoptera_data/train/bees/1295655112_7813f37d21.jpg  
  inflating: hymenoptera_data/train/bees/132511197_0b86ad0fff.jpg  
  inflating: hymenoptera_data/train/bees/132826773_dbbcb117b9.jpg  
  inflating: hymenoptera_data/train/bees/150013791_969d9a968b.jpg  
  inflating: hymenoptera_data/train/bees/1508176360_2972117c9d.jpg  
  inflating: hymenoptera_data/train/bees/154600396_53e1252e52.jpg  
  inflating: hymenoptera_data/train/bees/16838648_415acd9e3f.jpg  
  inflating: hymenoptera_data/train/bees/1691282715_0addfdf5e8.jpg  
  inflating: hymenoptera_data/train/bees/17209602_fe5a5a746f.jpg  
  inflating: hymenoptera_data/train/bees/174142798_e5ad6d76e0.jpg  
  inflating: hymenoptera_data/train/bees/1799726602_8580867f71.jpg  
  inflating: hymenoptera_data/train/bees/1807583459_4fe92b3133.jpg  
  inflating: hymenoptera_data/train/bees/196430254_46bd129ae7.jpg  
  inflating: hymenoptera_data/train/bees/196658222_3fffd79c67.jpg  
  inflating: hymenoptera_data/train/bees/198508668_97d818b6c4.jpg  
  inflating: hymenoptera_data/train/bees/2031225713_50ed499635.jpg  
  inflating: hymenoptera_data/train/bees/2037437624_2d7bce461f.jpg  
  inflating: hymenoptera_data/train/bees/2053200300_8911ef438a.jpg  
  inflating: hymenoptera_data/train/bees/205835650_e6f2614bee.jpg  
  inflating: hymenoptera_data/train/bees/208702903_42fb4d9748.jpg  
  inflating: hymenoptera_data/train/bees/21399619_3e61e5bb6f.jpg  
  inflating: hymenoptera_data/train/bees/2227611847_ec72d40403.jpg  
  inflating: hymenoptera_data/train/bees/2321139806_d73d899e66.jpg  
  inflating: hymenoptera_data/train/bees/2330918208_8074770c20.jpg  
  inflating: hymenoptera_data/train/bees/2345177635_caf07159b3.jpg  
  inflating: hymenoptera_data/train/bees/2358061370_9daabbd9ac.jpg  
  inflating: hymenoptera_data/train/bees/2364597044_3c3e3fc391.jpg  
  inflating: hymenoptera_data/train/bees/2384149906_2cd8b0b699.jpg  
  inflating: hymenoptera_data/train/bees/2397446847_04ef3cd3e1.jpg  
  inflating: hymenoptera_data/train/bees/2405441001_b06c36fa72.jpg  
  inflating: hymenoptera_data/train/bees/2445215254_51698ff797.jpg  
  inflating: hymenoptera_data/train/bees/2452236943_255bfd9e58.jpg  
  inflating: hymenoptera_data/train/bees/2467959963_a7831e9ff0.jpg  
  inflating: hymenoptera_data/train/bees/2470492904_837e97800d.jpg  
  inflating: hymenoptera_data/train/bees/2477324698_3d4b1b1cab.jpg  
  inflating: hymenoptera_data/train/bees/2477349551_e75c97cf4d.jpg  
  inflating: hymenoptera_data/train/bees/2486729079_62df0920be.jpg  
  inflating: hymenoptera_data/train/bees/2486746709_c43cec0e42.jpg  
  inflating: hymenoptera_data/train/bees/2493379287_4100e1dacc.jpg  
  inflating: hymenoptera_data/train/bees/2495722465_879acf9d85.jpg  
  inflating: hymenoptera_data/train/bees/2528444139_fa728b0f5b.jpg  
  inflating: hymenoptera_data/train/bees/2538361678_9da84b77e3.jpg  
  inflating: hymenoptera_data/train/bees/2551813042_8a070aeb2b.jpg  
  inflating: hymenoptera_data/train/bees/2580598377_a4caecdb54.jpg  
  inflating: hymenoptera_data/train/bees/2601176055_8464e6aa71.jpg  
  inflating: hymenoptera_data/train/bees/2610833167_79bf0bcae5.jpg  
  inflating: hymenoptera_data/train/bees/2610838525_fe8e3cae47.jpg  
  inflating: hymenoptera_data/train/bees/2617161745_fa3ebe85b4.jpg  
  inflating: hymenoptera_data/train/bees/2625499656_e3415e374d.jpg  
  inflating: hymenoptera_data/train/bees/2634617358_f32fd16bea.jpg  
  inflating: hymenoptera_data/train/bees/2638074627_6b3ae746a0.jpg  
  inflating: hymenoptera_data/train/bees/2645107662_b73a8595cc.jpg  
  inflating: hymenoptera_data/train/bees/2651621464_a2fa8722eb.jpg  
  inflating: hymenoptera_data/train/bees/2652877533_a564830cbf.jpg  
  inflating: hymenoptera_data/train/bees/266644509_d30bb16a1b.jpg  
  inflating: hymenoptera_data/train/bees/2683605182_9d2a0c66cf.jpg  
  inflating: hymenoptera_data/train/bees/2704348794_eb5d5178c2.jpg  
  inflating: hymenoptera_data/train/bees/2707440199_cd170bd512.jpg  
  inflating: hymenoptera_data/train/bees/2710368626_cb42882dc8.jpg  
  inflating: hymenoptera_data/train/bees/2722592222_258d473e17.jpg  
  inflating: hymenoptera_data/train/bees/2728759455_ce9bb8cd7a.jpg  
  inflating: hymenoptera_data/train/bees/2756397428_1d82a08807.jpg  
  inflating: hymenoptera_data/train/bees/2765347790_da6cf6cb40.jpg  
  inflating: hymenoptera_data/train/bees/2781170484_5d61835d63.jpg  
  inflating: hymenoptera_data/train/bees/279113587_b4843db199.jpg  
  inflating: hymenoptera_data/train/bees/2792000093_e8ae0718cf.jpg  
  inflating: hymenoptera_data/train/bees/2801728106_833798c909.jpg  
  inflating: hymenoptera_data/train/bees/2822388965_f6dca2a275.jpg  
  inflating: hymenoptera_data/train/bees/2861002136_52c7c6f708.jpg  
  inflating: hymenoptera_data/train/bees/2908916142_a7ac8b57a8.jpg  
  inflating: hymenoptera_data/train/bees/29494643_e3410f0d37.jpg  
  inflating: hymenoptera_data/train/bees/2959730355_416a18c63c.jpg  
  inflating: hymenoptera_data/train/bees/2962405283_22718d9617.jpg  
  inflating: hymenoptera_data/train/bees/3006264892_30e9cced70.jpg  
  inflating: hymenoptera_data/train/bees/3030189811_01d095b793.jpg  
  inflating: hymenoptera_data/train/bees/3030772428_8578335616.jpg  
  inflating: hymenoptera_data/train/bees/3044402684_3853071a87.jpg  
  inflating: hymenoptera_data/train/bees/3074585407_9854eb3153.jpg  
  inflating: hymenoptera_data/train/bees/3079610310_ac2d0ae7bc.jpg  
  inflating: hymenoptera_data/train/bees/3090975720_71f12e6de4.jpg  
  inflating: hymenoptera_data/train/bees/3100226504_c0d4f1e3f1.jpg  
  inflating: hymenoptera_data/train/bees/342758693_c56b89b6b6.jpg  
  inflating: hymenoptera_data/train/bees/354167719_22dca13752.jpg  
  inflating: hymenoptera_data/train/bees/359928878_b3b418c728.jpg  
  inflating: hymenoptera_data/train/bees/365759866_b15700c59b.jpg  
  inflating: hymenoptera_data/train/bees/36900412_92b81831ad.jpg  
  inflating: hymenoptera_data/train/bees/39672681_1302d204d1.jpg  
  inflating: hymenoptera_data/train/bees/39747887_42df2855ee.jpg  
  inflating: hymenoptera_data/train/bees/421515404_e87569fd8b.jpg  
  inflating: hymenoptera_data/train/bees/444532809_9e931e2279.jpg  
  inflating: hymenoptera_data/train/bees/446296270_d9e8b93ecf.jpg  
  inflating: hymenoptera_data/train/bees/452462677_7be43af8ff.jpg  
  inflating: hymenoptera_data/train/bees/452462695_40a4e5b559.jpg  
  inflating: hymenoptera_data/train/bees/457457145_5f86eb7e9c.jpg  
  inflating: hymenoptera_data/train/bees/465133211_80e0c27f60.jpg  
  inflating: hymenoptera_data/train/bees/469333327_358ba8fe8a.jpg  
  inflating: hymenoptera_data/train/bees/472288710_2abee16fa0.jpg  
  inflating: hymenoptera_data/train/bees/473618094_8ffdcab215.jpg  
  inflating: hymenoptera_data/train/bees/476347960_52edd72b06.jpg  
  inflating: hymenoptera_data/train/bees/478701318_bbd5e557b8.jpg  
  inflating: hymenoptera_data/train/bees/507288830_f46e8d4cb2.jpg  
  inflating: hymenoptera_data/train/bees/509247772_2db2d01374.jpg  
  inflating: hymenoptera_data/train/bees/513545352_fd3e7c7c5d.jpg  
  inflating: hymenoptera_data/train/bees/522104315_5d3cb2758e.jpg  
  inflating: hymenoptera_data/train/bees/537309131_532bfa59ea.jpg  
  inflating: hymenoptera_data/train/bees/586041248_3032e277a9.jpg  
  inflating: hymenoptera_data/train/bees/760526046_547e8b381f.jpg  
  inflating: hymenoptera_data/train/bees/760568592_45a52c847f.jpg  
  inflating: hymenoptera_data/train/bees/774440991_63a4aa0cbe.jpg  
  inflating: hymenoptera_data/train/bees/85112639_6e860b0469.jpg  
  inflating: hymenoptera_data/train/bees/873076652_eb098dab2d.jpg  
  inflating: hymenoptera_data/train/bees/90179376_abc234e5f4.jpg  
  inflating: hymenoptera_data/train/bees/92663402_37f379e57a.jpg  
  inflating: hymenoptera_data/train/bees/95238259_98470c5b10.jpg  
  inflating: hymenoptera_data/train/bees/969455125_58c797ef17.jpg  
  inflating: hymenoptera_data/train/bees/98391118_bdb1e80cce.jpg  
  inflating: hymenoptera_data/val/ants/10308379_1b6c72e180.jpg  
  inflating: hymenoptera_data/val/ants/1053149811_f62a3410d3.jpg  
  inflating: hymenoptera_data/val/ants/1073564163_225a64f170.jpg  
  inflating: hymenoptera_data/val/ants/1119630822_cd325ea21a.jpg  
  inflating: hymenoptera_data/val/ants/1124525276_816a07c17f.jpg  
  inflating: hymenoptera_data/val/ants/11381045_b352a47d8c.jpg  
  inflating: hymenoptera_data/val/ants/119785936_dd428e40c3.jpg  
  inflating: hymenoptera_data/val/ants/1247887232_edcb61246c.jpg  
  inflating: hymenoptera_data/val/ants/1262751255_c56c042b7b.jpg  
  inflating: hymenoptera_data/val/ants/1337725712_2eb53cd742.jpg  
  inflating: hymenoptera_data/val/ants/1358854066_5ad8015f7f.jpg  
  inflating: hymenoptera_data/val/ants/1440002809_b268d9a66a.jpg  
  inflating: hymenoptera_data/val/ants/147542264_79506478c2.jpg  
  inflating: hymenoptera_data/val/ants/152286280_411648ec27.jpg  
  inflating: hymenoptera_data/val/ants/153320619_2aeb5fa0ee.jpg  
  inflating: hymenoptera_data/val/ants/153783656_85f9c3ac70.jpg  
  inflating: hymenoptera_data/val/ants/157401988_d0564a9d02.jpg  
  inflating: hymenoptera_data/val/ants/159515240_d5981e20d1.jpg  
  inflating: hymenoptera_data/val/ants/161076144_124db762d6.jpg  
  inflating: hymenoptera_data/val/ants/161292361_c16e0bf57a.jpg  
  inflating: hymenoptera_data/val/ants/170652283_ecdaff5d1a.jpg  
  inflating: hymenoptera_data/val/ants/17081114_79b9a27724.jpg  
  inflating: hymenoptera_data/val/ants/172772109_d0a8e15fb0.jpg  
  inflating: hymenoptera_data/val/ants/1743840368_b5ccda82b7.jpg  
  inflating: hymenoptera_data/val/ants/181942028_961261ef48.jpg  
  inflating: hymenoptera_data/val/ants/183260961_64ab754c97.jpg  
  inflating: hymenoptera_data/val/ants/2039585088_c6f47c592e.jpg  
  inflating: hymenoptera_data/val/ants/205398178_c395c5e460.jpg  
  inflating: hymenoptera_data/val/ants/208072188_f293096296.jpg  
  inflating: hymenoptera_data/val/ants/209615353_eeb38ba204.jpg  
  inflating: hymenoptera_data/val/ants/2104709400_8831b4fc6f.jpg  
  inflating: hymenoptera_data/val/ants/212100470_b485e7b7b9.jpg  
  inflating: hymenoptera_data/val/ants/2127908701_d49dc83c97.jpg  
  inflating: hymenoptera_data/val/ants/2191997003_379df31291.jpg  
  inflating: hymenoptera_data/val/ants/2211974567_ee4606b493.jpg  
  inflating: hymenoptera_data/val/ants/2219621907_47bc7cc6b0.jpg  
  inflating: hymenoptera_data/val/ants/2238242353_52c82441df.jpg  
  inflating: hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg  
  inflating: hymenoptera_data/val/ants/239161491_86ac23b0a3.jpg  
  inflating: hymenoptera_data/val/ants/263615709_cfb28f6b8e.jpg  
  inflating: hymenoptera_data/val/ants/308196310_1db5ffa01b.jpg  
  inflating: hymenoptera_data/val/ants/319494379_648fb5a1c6.jpg  
  inflating: hymenoptera_data/val/ants/35558229_1fa4608a7a.jpg  
  inflating: hymenoptera_data/val/ants/412436937_4c2378efc2.jpg  
  inflating: hymenoptera_data/val/ants/436944325_d4925a38c7.jpg  
  inflating: hymenoptera_data/val/ants/445356866_6cb3289067.jpg  
  inflating: hymenoptera_data/val/ants/459442412_412fecf3fe.jpg  
  inflating: hymenoptera_data/val/ants/470127071_8b8ee2bd74.jpg  
  inflating: hymenoptera_data/val/ants/477437164_bc3e6e594a.jpg  
  inflating: hymenoptera_data/val/ants/488272201_c5aa281348.jpg  
  inflating: hymenoptera_data/val/ants/502717153_3e4865621a.jpg  
  inflating: hymenoptera_data/val/ants/518746016_bcc28f8b5b.jpg  
  inflating: hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg  
  inflating: hymenoptera_data/val/ants/562589509_7e55469b97.jpg  
  inflating: hymenoptera_data/val/ants/57264437_a19006872f.jpg  
  inflating: hymenoptera_data/val/ants/573151833_ebbc274b77.jpg  
  inflating: hymenoptera_data/val/ants/649407494_9b6bc4949f.jpg  
  inflating: hymenoptera_data/val/ants/751649788_78dd7d16ce.jpg  
  inflating: hymenoptera_data/val/ants/768870506_8f115d3d37.jpg  
  inflating: hymenoptera_data/val/ants/800px-Meat_eater_ant_qeen_excavating_hole.jpg  
  inflating: hymenoptera_data/val/ants/8124241_36b290d372.jpg  
  inflating: hymenoptera_data/val/ants/8398478_50ef10c47a.jpg  
  inflating: hymenoptera_data/val/ants/854534770_31f6156383.jpg  
  inflating: hymenoptera_data/val/ants/892676922_4ab37dce07.jpg  
  inflating: hymenoptera_data/val/ants/94999827_36895faade.jpg  
  inflating: hymenoptera_data/val/ants/Ant-1818.jpg  
  inflating: hymenoptera_data/val/ants/ants-devouring-remains-of-large-dead-insect-on-red-tile-in-Stellenbosch-South-Africa-closeup-1-DHD.jpg  
  inflating: hymenoptera_data/val/ants/desert_ant.jpg  
  inflating: hymenoptera_data/val/ants/F.pergan.28(f).jpg  
  inflating: hymenoptera_data/val/ants/Hormiga.jpg  
  inflating: hymenoptera_data/val/bees/1032546534_06907fe3b3.jpg  
  inflating: hymenoptera_data/val/bees/10870992_eebeeb3a12.jpg  
  inflating: hymenoptera_data/val/bees/1181173278_23c36fac71.jpg  
  inflating: hymenoptera_data/val/bees/1297972485_33266a18d9.jpg  
  inflating: hymenoptera_data/val/bees/1328423762_f7a88a8451.jpg  
  inflating: hymenoptera_data/val/bees/1355974687_1341c1face.jpg  
  inflating: hymenoptera_data/val/bees/144098310_a4176fd54d.jpg  
  inflating: hymenoptera_data/val/bees/1486120850_490388f84b.jpg  
  inflating: hymenoptera_data/val/bees/149973093_da3c446268.jpg  
  inflating: hymenoptera_data/val/bees/151594775_ee7dc17b60.jpg  
  inflating: hymenoptera_data/val/bees/151603988_2c6f7d14c7.jpg  
  inflating: hymenoptera_data/val/bees/1519368889_4270261ee3.jpg  
  inflating: hymenoptera_data/val/bees/152789693_220b003452.jpg  
  inflating: hymenoptera_data/val/bees/177677657_a38c97e572.jpg  
  inflating: hymenoptera_data/val/bees/1799729694_0c40101071.jpg  
  inflating: hymenoptera_data/val/bees/181171681_c5a1a82ded.jpg  
  inflating: hymenoptera_data/val/bees/187130242_4593a4c610.jpg  
  inflating: hymenoptera_data/val/bees/203868383_0fcbb48278.jpg  
  inflating: hymenoptera_data/val/bees/2060668999_e11edb10d0.jpg  
  inflating: hymenoptera_data/val/bees/2086294791_6f3789d8a6.jpg  
  inflating: hymenoptera_data/val/bees/2103637821_8d26ee6b90.jpg  
  inflating: hymenoptera_data/val/bees/2104135106_a65eede1de.jpg  
  inflating: hymenoptera_data/val/bees/215512424_687e1e0821.jpg  
  inflating: hymenoptera_data/val/bees/2173503984_9c6aaaa7e2.jpg  
  inflating: hymenoptera_data/val/bees/220376539_20567395d8.jpg  
  inflating: hymenoptera_data/val/bees/224841383_d050f5f510.jpg  
  inflating: hymenoptera_data/val/bees/2321144482_f3785ba7b2.jpg  
  inflating: hymenoptera_data/val/bees/238161922_55fa9a76ae.jpg  
  inflating: hymenoptera_data/val/bees/2407809945_fb525ef54d.jpg  
  inflating: hymenoptera_data/val/bees/2415414155_1916f03b42.jpg  
  inflating: hymenoptera_data/val/bees/2438480600_40a1249879.jpg  
  inflating: hymenoptera_data/val/bees/2444778727_4b781ac424.jpg  
  inflating: hymenoptera_data/val/bees/2457841282_7867f16639.jpg  
  inflating: hymenoptera_data/val/bees/2470492902_3572c90f75.jpg  
  inflating: hymenoptera_data/val/bees/2478216347_535c8fe6d7.jpg  
  inflating: hymenoptera_data/val/bees/2501530886_e20952b97d.jpg  
  inflating: hymenoptera_data/val/bees/2506114833_90a41c5267.jpg  
  inflating: hymenoptera_data/val/bees/2509402554_31821cb0b6.jpg  
  inflating: hymenoptera_data/val/bees/2525379273_dcb26a516d.jpg  
  inflating: hymenoptera_data/val/bees/26589803_5ba7000313.jpg  
  inflating: hymenoptera_data/val/bees/2668391343_45e272cd07.jpg  
  inflating: hymenoptera_data/val/bees/2670536155_c170f49cd0.jpg  
  inflating: hymenoptera_data/val/bees/2685605303_9eed79d59d.jpg  
  inflating: hymenoptera_data/val/bees/2702408468_d9ed795f4f.jpg  
  inflating: hymenoptera_data/val/bees/2709775832_85b4b50a57.jpg  
  inflating: hymenoptera_data/val/bees/2717418782_bd83307d9f.jpg  
  inflating: hymenoptera_data/val/bees/272986700_d4d4bf8c4b.jpg  
  inflating: hymenoptera_data/val/bees/2741763055_9a7bb00802.jpg  
  inflating: hymenoptera_data/val/bees/2745389517_250a397f31.jpg  
  inflating: hymenoptera_data/val/bees/2751836205_6f7b5eff30.jpg  
  inflating: hymenoptera_data/val/bees/2782079948_8d4e94a826.jpg  
  inflating: hymenoptera_data/val/bees/2809496124_5f25b5946a.jpg  
  inflating: hymenoptera_data/val/bees/2815838190_0a9889d995.jpg  
  inflating: hymenoptera_data/val/bees/2841437312_789699c740.jpg  
  inflating: hymenoptera_data/val/bees/2883093452_7e3a1eb53f.jpg  
  inflating: hymenoptera_data/val/bees/290082189_f66cb80bfc.jpg  
  inflating: hymenoptera_data/val/bees/296565463_d07a7bed96.jpg  
  inflating: hymenoptera_data/val/bees/3077452620_548c79fda0.jpg  
  inflating: hymenoptera_data/val/bees/348291597_ee836fbb1a.jpg  
  inflating: hymenoptera_data/val/bees/350436573_41f4ecb6c8.jpg  
  inflating: hymenoptera_data/val/bees/353266603_d3eac7e9a0.jpg  
  inflating: hymenoptera_data/val/bees/372228424_16da1f8884.jpg  
  inflating: hymenoptera_data/val/bees/400262091_701c00031c.jpg  
  inflating: hymenoptera_data/val/bees/416144384_961c326481.jpg  
  inflating: hymenoptera_data/val/bees/44105569_16720a960c.jpg  
  inflating: hymenoptera_data/val/bees/456097971_860949c4fc.jpg  
  inflating: hymenoptera_data/val/bees/464594019_1b24a28bb1.jpg  
  inflating: hymenoptera_data/val/bees/485743562_d8cc6b8f73.jpg  
  inflating: hymenoptera_data/val/bees/540976476_844950623f.jpg  
  inflating: hymenoptera_data/val/bees/54736755_c057723f64.jpg  
  inflating: hymenoptera_data/val/bees/57459255_752774f1b2.jpg  
  inflating: hymenoptera_data/val/bees/576452297_897023f002.jpg  
  inflating: hymenoptera_data/val/bees/586474709_ae436da045.jpg  
  inflating: hymenoptera_data/val/bees/590318879_68cf112861.jpg  
  inflating: hymenoptera_data/val/bees/59798110_2b6a3c8031.jpg  
  inflating: hymenoptera_data/val/bees/603709866_a97c7cfc72.jpg  
  inflating: hymenoptera_data/val/bees/603711658_4c8cd2201e.jpg  
  inflating: hymenoptera_data/val/bees/65038344_52a45d090d.jpg  
  inflating: hymenoptera_data/val/bees/6a00d8341c630a53ef00e553d0beb18834-800wi.jpg  
  inflating: hymenoptera_data/val/bees/72100438_73de9f17af.jpg  
  inflating: hymenoptera_data/val/bees/759745145_e8bc776ec8.jpg  
  inflating: hymenoptera_data/val/bees/936182217_c4caa5222d.jpg  
  inflating: hymenoptera_data/val/bees/abeja.jpg  

画像データセットの前処理の手順を定義する。このデータに含まれている画像の枚数が少ないので、訓練データに対して水増しをやってみる。テストデータに対しては、標準化のみ行う。

In [ ]:
# 画像データの前処理を行うための手順を定義
data_transforms = {
    # 訓練データ画像に対する前処理
    'train': transforms.Compose([
        # 入力画像の任意の位置から 224x224 に切り抜き
        transforms.RandomResizedCrop(224),
        # 平行反転
        transforms.RandomHorizontalFlip(),
        # 画像データをテンソル(多次元ベクトル)に変換
        transforms.ToTensor(),
        # 0-225 の数値をほぼ 0-1 の範囲に収める
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    
    # テストデータに対する前処理
    'val': transforms.Compose([
        # 入力画像を強制的に 256x256 に縮小
        transforms.Resize(256),
        # 256x256 の中心から 224x224 の画像を切り出す
        transforms.CenterCrop(224),
        # テンソルに変換
        transforms.ToTensor(),
        # 標準化
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 画像を pytorch に認識させる
data_dir = './hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

# GPU を使えるならば GPU を使用する
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

モデルの学習手続きを定義する。

In [ ]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=20):
    since = time.time()

    # 学習途中でいいモデルができるかもしれないので、これまでの学習のなかで最適なモデルの重みを保存
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # num_epochs 回分だけ学習する
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # 1 エポックごとに学習と評価を繰り返す
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # 1 バッチごと学習と評価を繰り返す、評価結果である損失をこのエポックの損失として加算する
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # 勾配削除
                optimizer.zero_grad()

                # 学習時のみ行う重みを記録する
                with torch.set_grad_enabled(phase == 'train'):

                    # 学習
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)

                    # 損失計算
                    loss = criterion(outputs, labels)

                    # 損失を逆伝播、そして重みを更新
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # 損失を加算
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)


            if phase == 'train':
                # 学習率を徐々に下げる
                scheduler.step()

            # このエポックの損失と精度
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # 精度が過去最高であれば、その重みを保存
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # 学習過程で最高精度を出した重みをセットアップして、最適モデルとして返す
    model.load_state_dict(best_model_wts)
    return model
In [15]:
# 訓練済み ResNet を取得
model_ft = models.resnet18(pretrained=True)

print(model_ft)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

モデル構造を確認すると、ResNet 18 は畳み込み演算とプーリング演算の連続になっていて、ニューラルネットワークとなっているのは最後の結合層 (fc): Linear(in_features=512, out_features=1000, bias=True) となっていることが確認できる。

なお、torchsummary パッケージを使用した出力が見やすいので、torchsummary でモデルのアーキテクチャを出力してみるのもよい。

In [16]:
from torchsummary import summary
summary(model_ft.cuda(), (3, 224, 224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64, 56, 56]               0
           Conv2d-15           [-1, 64, 56, 56]          36,864
      BatchNorm2d-16           [-1, 64, 56, 56]             128
             ReLU-17           [-1, 64, 56, 56]               0
       BasicBlock-18           [-1, 64, 56, 56]               0
           Conv2d-19          [-1, 128, 28, 28]          73,728
      BatchNorm2d-20          [-1, 128, 28, 28]             256
             ReLU-21          [-1, 128, 28, 28]               0
           Conv2d-22          [-1, 128, 28, 28]         147,456
      BatchNorm2d-23          [-1, 128, 28, 28]             256
           Conv2d-24          [-1, 128, 28, 28]           8,192
      BatchNorm2d-25          [-1, 128, 28, 28]             256
             ReLU-26          [-1, 128, 28, 28]               0
       BasicBlock-27          [-1, 128, 28, 28]               0
           Conv2d-28          [-1, 128, 28, 28]         147,456
      BatchNorm2d-29          [-1, 128, 28, 28]             256
             ReLU-30          [-1, 128, 28, 28]               0
           Conv2d-31          [-1, 128, 28, 28]         147,456
      BatchNorm2d-32          [-1, 128, 28, 28]             256
             ReLU-33          [-1, 128, 28, 28]               0
       BasicBlock-34          [-1, 128, 28, 28]               0
           Conv2d-35          [-1, 256, 14, 14]         294,912
      BatchNorm2d-36          [-1, 256, 14, 14]             512
             ReLU-37          [-1, 256, 14, 14]               0
           Conv2d-38          [-1, 256, 14, 14]         589,824
      BatchNorm2d-39          [-1, 256, 14, 14]             512
           Conv2d-40          [-1, 256, 14, 14]          32,768
      BatchNorm2d-41          [-1, 256, 14, 14]             512
             ReLU-42          [-1, 256, 14, 14]               0
       BasicBlock-43          [-1, 256, 14, 14]               0
           Conv2d-44          [-1, 256, 14, 14]         589,824
      BatchNorm2d-45          [-1, 256, 14, 14]             512
             ReLU-46          [-1, 256, 14, 14]               0
           Conv2d-47          [-1, 256, 14, 14]         589,824
      BatchNorm2d-48          [-1, 256, 14, 14]             512
             ReLU-49          [-1, 256, 14, 14]               0
       BasicBlock-50          [-1, 256, 14, 14]               0
           Conv2d-51            [-1, 512, 7, 7]       1,179,648
      BatchNorm2d-52            [-1, 512, 7, 7]           1,024
             ReLU-53            [-1, 512, 7, 7]               0
           Conv2d-54            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-55            [-1, 512, 7, 7]           1,024
           Conv2d-56            [-1, 512, 7, 7]         131,072
      BatchNorm2d-57            [-1, 512, 7, 7]           1,024
             ReLU-58            [-1, 512, 7, 7]               0
       BasicBlock-59            [-1, 512, 7, 7]               0
           Conv2d-60            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-61            [-1, 512, 7, 7]           1,024
             ReLU-62            [-1, 512, 7, 7]               0
           Conv2d-63            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-64            [-1, 512, 7, 7]           1,024
             ReLU-65            [-1, 512, 7, 7]               0
       BasicBlock-66            [-1, 512, 7, 7]               0
AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
           Linear-68                 [-1, 1000]         513,000
================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 62.79
Params size (MB): 44.59
Estimated Total Size (MB): 107.96
----------------------------------------------------------------
In [ ]:
# ResNet の出力層のユニット数を取得
# (ResNet は ImageNet で訓練されたので出力層が 1000 ユニット)
num_ftrs = model_ft.fc.in_features

# 出力層 100 ユニットを 2 ユニットに置き換える(つなぎ直す)
model_ft.fc = nn.Linear(num_ftrs, len(class_names))

# モデルを GPU に送る
model_ft = model_ft.to(device)

# 損失関数を定義
criterion = nn.CrossEntropyLoss()

# モデル中の微分可能なパラメーターを最適化関数にセットアップ
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# 7 エポックごとに学習率を 0.1 倍だけ小さくするようにセットアップ
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
In [18]:
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=20)
Epoch 0/19
----------
train Loss: 0.5790 Acc: 0.6926
val Loss: 0.2692 Acc: 0.9150

Epoch 1/19
----------
train Loss: 0.6755 Acc: 0.7664
val Loss: 1.0073 Acc: 0.7320

Epoch 2/19
----------
train Loss: 0.4341 Acc: 0.8361
val Loss: 0.4415 Acc: 0.8497

Epoch 3/19
----------
train Loss: 0.5090 Acc: 0.8074
val Loss: 0.2810 Acc: 0.8824

Epoch 4/19
----------
train Loss: 0.5428 Acc: 0.7951
val Loss: 0.6174 Acc: 0.7647

Epoch 5/19
----------
train Loss: 0.5231 Acc: 0.8033
val Loss: 0.3569 Acc: 0.8758

Epoch 6/19
----------
train Loss: 0.3993 Acc: 0.8279
val Loss: 0.5651 Acc: 0.7974

Epoch 7/19
----------
train Loss: 0.5652 Acc: 0.7787
val Loss: 0.2478 Acc: 0.9281

Epoch 8/19
----------
train Loss: 0.3159 Acc: 0.8689
val Loss: 0.2463 Acc: 0.9150

Epoch 9/19
----------
train Loss: 0.3461 Acc: 0.8566
val Loss: 0.2453 Acc: 0.9281

Epoch 10/19
----------
train Loss: 0.3172 Acc: 0.8607
val Loss: 0.2362 Acc: 0.9085

Epoch 11/19
----------
train Loss: 0.4338 Acc: 0.7910
val Loss: 0.2410 Acc: 0.9281

Epoch 12/19
----------
train Loss: 0.3169 Acc: 0.8730
val Loss: 0.2383 Acc: 0.9216

Epoch 13/19
----------
train Loss: 0.3513 Acc: 0.8361
val Loss: 0.2184 Acc: 0.9346

Epoch 14/19
----------
train Loss: 0.1766 Acc: 0.9303
val Loss: 0.2271 Acc: 0.9346

Epoch 15/19
----------
train Loss: 0.2579 Acc: 0.8770
val Loss: 0.2399 Acc: 0.9281

Epoch 16/19
----------
train Loss: 0.2475 Acc: 0.8934
val Loss: 0.2178 Acc: 0.9281

Epoch 17/19
----------
train Loss: 0.2136 Acc: 0.9221
val Loss: 0.2337 Acc: 0.9281

Epoch 18/19
----------
train Loss: 0.2801 Acc: 0.8934
val Loss: 0.2194 Acc: 0.9346

Epoch 19/19
----------
train Loss: 0.2496 Acc: 0.8811
val Loss: 0.2650 Acc: 0.9020

Training complete in 1m 15s
Best val Acc: 0.934641

データセットにもよるが、転移学習を利用した場合はより少ないエポック数で高い精度を達成することができる。

Fine-tunning

転移学習では、他のデータセットで学習済みのモデルに、自分たちのデータで再学習させた。この際、学習済みのモデル中のすべての重みを更新しながら再学習を行なっていた。

これに対して、学習済みのモデル中の重みのうち一部だけを再学習させることもできる。例えば 18 層からなる学習済みモデルがあるとき、第 1 層〜第 10 層までの重みを固定させ、第 11 層〜第 18 層の重みだけを再学習させる、といったことができる。何層目まで固定し、何層目から自由に動かすかは、そのモデルが昔に学習したデータセットとこれから学習しようとするデータセットの類似度と量によってチューニングする必要がある。

この項目は pytorch ウェブサイトの FINETUNING TORCHVISION MODELS を参照して作成した。

In [19]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
PyTorch Version:  1.3.1
Torchvision Version:  0.4.2

学習手続きを定義する。

In [ ]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25):
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # for each mini_batch
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

ここで、転移学習などでよく使われる VGG16 と呼ばれる学習済みモデルを取得し fine-tuning を行う。VGG16 のアーキテクチャを出力し、その構造を確認してみる。

In [21]:
model_ft = models.vgg16(pretrained=True)

print(model_ft)
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

まずモデルの描くレイヤーへのアクセス方法を確認する。

In [22]:
for child_name, child in model_ft.named_children():
  print(child_name)
  for child_layer_i, param in enumerate(child.parameters()):
      print(child_layer_i)
features
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
avgpool
classifier
0
1
2
3
4
5

例えば、最後の classifier レイヤーの最後の 3 レイヤーのみを再学習させたければ次のようにする。

In [ ]:
#for child_name, child in model_ft.named_children():
#  for child_layer_i, param in enumerate(child.parameters()):
#
#      if child_name == 'classifier' and child_layer_id > 2:
#          param.requires_grad = True
#      else:
#          param.requires_grad = False
#

今回は特徴量抽出を行うので、以下のようにする。

In [ ]:
for param in model_ft.parameters():
    param.requires_grad = False

VGG16 の出力層 (classifier)(6) が 1000 となっているため、ここでは 2 クラスとなるように設定する。

In [ ]:
num_classes = 2

num_ftrs = model_ft.classifier[6].in_features
model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
input_size = 224
In [ ]:
data_dir = "./hymenoptera_data"
batch_size = 8

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val']}

次に損失を定義し、学習すべきパラメーターのみを最適化関数に代入する。

In [27]:
criterion = nn.CrossEntropyLoss()

params_to_update = model_ft.parameters()

params_to_update = []
for name, param in model_ft.named_parameters():
    if param.requires_grad == True:
        params_to_update.append(param)
        print("\t",name)


optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
	 classifier.6.weight
	 classifier.6.bias

GPU デバイスをセットアップし、学習を開始する。

In [28]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_ft = model_ft.to(device)

num_epochs = 10
model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft,
                             num_epochs=num_epochs)
Epoch 0/9
----------
train Loss: 0.2918 Acc: 0.8607
val Loss: 0.1196 Acc: 0.9542

Epoch 1/9
----------
train Loss: 0.1341 Acc: 0.9426
val Loss: 0.0969 Acc: 0.9542

Epoch 2/9
----------
train Loss: 0.1672 Acc: 0.9221
val Loss: 0.0956 Acc: 0.9542

Epoch 3/9
----------
train Loss: 0.0890 Acc: 0.9672
val Loss: 0.1007 Acc: 0.9412

Epoch 4/9
----------
train Loss: 0.1951 Acc: 0.9303
val Loss: 0.0957 Acc: 0.9412

Epoch 5/9
----------
train Loss: 0.1400 Acc: 0.9344
val Loss: 0.1009 Acc: 0.9477

Epoch 6/9
----------
train Loss: 0.1360 Acc: 0.9549
val Loss: 0.0983 Acc: 0.9542

Epoch 7/9
----------
train Loss: 0.1503 Acc: 0.9467
val Loss: 0.0840 Acc: 0.9542

Epoch 8/9
----------
train Loss: 0.1232 Acc: 0.9508
val Loss: 0.0990 Acc: 0.9542

Epoch 9/9
----------
train Loss: 0.0874 Acc: 0.9631
val Loss: 0.1252 Acc: 0.9542

Training complete in 0m 41s
Best val Acc: 0.954248