PyTorchで画像分類モデルをファインチューニングする【転移学習の実装】

PyTorchで画像分類モデルをファインチューニングする【転移学習の実装】 AI資格・学習

はじめに

画像分類を実装するとき、ゼロからCNNを学習させるより「事前学習済みモデルをファインチューニング」する方が圧倒的に精度が高く、データも少なくて済みます。PyTorchを使ったファインチューニングの実装を解説します。

ファインチューニングとは

ImageNet(100万枚以上)で事前学習済みのモデル(ResNet・EfficientNet・ViTなど)の重みを、自分のタスク用データで微調整する手法です。特に転移学習では、ベースモデルの重みを固定して最終層だけ学習する「特徴抽出」と、全層を少ない学習率で再学習する「ファインチューニング」の2段階アプローチが効果的です。

ResNet50のファインチューニング実装

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.models import resnet50, ResNet50_Weights

# デバイス設定
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'使用デバイス: {device}')

# データの前処理
transform = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

# データセットの読み込み(フォルダ構造: data/train/class1/, data/val/class1/)
train_dataset = ImageFolder('data/train', transform=transform['train'])
val_dataset   = ImageFolder('data/val',   transform=transform['val'])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset,   batch_size=32, shuffle=False)

num_classes = len(train_dataset.classes)
print(f'クラス数: {num_classes}, クラス: {train_dataset.classes}')

# ResNet50の読み込みとファインチューニング設定
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# ステップ1:最終層以外を凍結(特徴抽出)
for param in model.parameters():
    param.requires_grad = False

# 最終層を自分のクラス数に変更
model.fc = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.fc.in_features, num_classes)
)
model = model.to(device)

# 学習設定
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# 学習ループ
def train_epoch(model, loader, criterion, optimizer):
    model.train()
    total_loss, correct = 0, 0
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        outputs = model(X)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += (outputs.argmax(1) == y).sum().item()
    return total_loss / len(loader), correct / len(loader.dataset)

for epoch in range(10):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
    scheduler.step()
    print(f'Epoch {epoch+1:2d}: loss={train_loss:.4f}, acc={train_acc:.4f}')

# ステップ2:全層をアンフリーズしてファインチューニング
for param in model.parameters():
    param.requires_grad = True
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)  # 低い学習率

for epoch in range(5):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
    print(f'Finetune Epoch {epoch+1}: loss={train_loss:.4f}, acc={train_acc:.4f}')

torch.save(model.state_dict(), 'finetuned_resnet50.pth')

まとめ

ファインチューニングは「最終層のみ学習→全層を低学習率で学習」の2段階アプローチが効果的です。データ拡張(RandomCrop・Flip・ColorJitter)も精度向上に欠かせません。製造業の外観検査・農業の病気判定・医療画像解析など、少ないデータで高精度の画像分類が実現できます。

📌 プログラミング・AI学習のおすすめスクール

※本記事にはアフィリエイトリンクが含まれます。

コメント

タイトルとURLをコピーしました