Thông báo: Download 4 khóa học Python từ cơ bản đến nâng cao tại đây.
Học chuyển giao (Transfer Learning) trong PyTorch Beginner
Trong lĩnh vực trí tuệ nhân tạo, học chuyển giao (Transfer Learning) đã trở thành một kỹ thuật quan trọng và hiệu quả, đặc biệt là khi làm việc với các bài toán thị giác máy tính như phân loại hình ảnh. Với sự hỗ trợ của các mô hình tiền huấn luyện, Transfer Learning giúp giảm thời gian huấn luyện và nâng cao độ chính xác, đặc biệt khi dữ liệu huấn luyện bị hạn chế.
Bài viết này sẽ hướng dẫn bạn cách triển khai Transfer Learning trong PyTorch, một framework mạnh mẽ và linh hoạt cho học sâu. Cụ thể, bạn sẽ học:
- Hiểu khái niệm học chuyển giao và lợi ích của nó.
- Cách sử dụng mô hình tiền huấn luyện ResNet-18.
- Phương pháp phân loại hình ảnh kiến và ong.
- Tùy chỉnh lớp Fully Connected cuối cùng để phù hợp với bài toán cụ thể.
- So sánh hai chiến lược phổ biến:
- Tinh chỉnh toàn bộ mạng (Finetuning).
- Huấn luyện chỉ lớp cuối cùng.
- Thực hiện đánh giá kết quả để tìm ra phương pháp tối ưu.
Hãy cùng đi sâu vào chi tiết và khám phá sức mạnh của Transfer Learning!
Học chuyển giao trong PyTorch
Bộ dữ liệu có thể tải từ trang web chính thức của PyTorch.
Bài viết này được đăng tại [free tuts .net]
Thiết lập và tiền xử lý dữ liệu
Đầu tiên, chúng ta thiết lập các hàm chuyển đổi để chuẩn hóa và tạo mẫu cho tập dữ liệu:
import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler import numpy as np import torchvision from torchvision import datasets, models, transforms import matplotlib.pyplot as plt import time import os import copy # Trung bình và độ lệch chuẩn để chuẩn hóa dữ liệu mean = np.array([0.5, 0.5, 0.5]) std = np.array([0.25, 0.25, 0.25]) data_transforms = { 'train': transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean, std) ]), 'val': transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean, std) ]), } # Đường dẫn dữ liệu data_dir = 'data/hymenoptera_data' image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']} dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=0) for x in ['train', 'val']} dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} class_names = image_datasets['train'].classes device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(class_names)
Hiển thị một số hình ảnh mẫu
def imshow(inp, title): """Hiển thị hình ảnh từ Tensor.""" inp = inp.numpy().transpose((1, 2, 0)) inp = std * inp + mean inp = np.clip(inp, 0, 1) plt.imshow(inp) plt.title(title) plt.show() # Hiển thị một lô hình ảnh inputs, classes = next(iter(dataloaders['train'])) out = torchvision.utils.make_grid(inputs) imshow(out, title=[class_names[x] for x in classes])
Hàm huấn luyện mô hình trong PyTorch
Hàm train_model
được sử dụng để huấn luyện và đánh giá mô hình.
def train_model(model, criterion, optimizer, scheduler, num_epochs=25): since = time.time() best_model_wts = copy.deepcopy(model.state_dict()) best_acc = 0.0 for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch, num_epochs - 1)) print('-' * 10) # Pha huấn luyện và đánh giá for phase in ['train', 'val']: if phase == 'train': model.train() # Chế độ huấn luyện else: model.eval() # Chế độ đánh giá running_loss = 0.0 running_corrects = 0 # Duyệt qua các lô dữ liệu for inputs, labels in dataloaders[phase]: inputs = inputs.to(device) labels = labels.to(device) # Tiến hành with torch.set_grad_enabled(phase == 'train'): outputs = model(inputs) _, preds = torch.max(outputs, 1) loss = criterion(outputs, labels) # Tính toán đạo hàm và cập nhật trọng số trong pha huấn luyện if phase == 'train': optimizer.zero_grad() loss.backward() optimizer.step() # Cập nhật các thông số running_loss += loss.item() * inputs.size(0) running_corrects += torch.sum(preds == labels.data) if phase == 'train': scheduler.step() epoch_loss = running_loss / dataset_sizes[phase] epoch_acc = running_corrects.double() / dataset_sizes[phase] print('{} Loss: {:.4f} Acc: {:.4f}'.format( phase, epoch_loss, epoch_acc)) # Lưu mô hình tốt nhất if phase == 'val' and epoch_acc > best_acc: best_acc = epoch_acc best_model_wts = copy.deepcopy(model.state_dict()) print() time_elapsed = time.time() - since print('Huấn luyện hoàn thành trong {:.0f}m {:.0f}s'.format( time_elapsed // 60, time_elapsed % 60)) print('Độ chính xác tốt nhất trên tập validation: {:4f}'.format(best_acc)) # Load trọng số mô hình tốt nhất model.load_state_dict(best_model_wts) return model
Phương pháp 1: Tinh chỉnh toàn bộ mạng
- Dùng mô hình tiền huấn luyện ResNet-18.
- Thay thế lớp Fully Connected cuối cùng.
- Huấn luyện lại toàn bộ mạng.
model = models.resnet18(pretrained=True) num_ftrs = model.fc.in_features model.fc = nn.Linear(num_ftrs, 2) model = model.to(device) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001) step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1) model = train_model(model, criterion, optimizer, step_lr_scheduler, num_epochs=25)
Phương pháp 2: Sử dụng mạng làm trích xuất đặc trưng
- Đóng băng tất cả trọng số trừ lớp cuối cùng.
model_conv = torchvision.models.resnet18(pretrained=True) for param in model_conv.parameters(): param.requires_grad = False num_ftrs = model_conv.fc.in_features model_conv.fc = nn.Linear(num_ftrs, 2) model_conv = model_conv.to(device) criterion = nn.CrossEntropyLoss() optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9) exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1) model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler, num_epochs=25)