Thông báo: Download 4 khóa học Python từ cơ bản đến nâng cao tại đây.
Học chuyển giao (Transfer Learning) trong PyTorch Beginner
Trong lĩnh vực trí tuệ nhân tạo, học chuyển giao (Transfer Learning) đã trở thành một kỹ thuật quan trọng và hiệu quả, đặc biệt là khi làm việc với các bài toán thị giác máy tính như phân loại hình ảnh. Với sự hỗ trợ của các mô hình tiền huấn luyện, Transfer Learning giúp giảm thời gian huấn luyện và nâng cao độ chính xác, đặc biệt khi dữ liệu huấn luyện bị hạn chế.

Bài viết này sẽ hướng dẫn bạn cách triển khai Transfer Learning trong PyTorch, một framework mạnh mẽ và linh hoạt cho học sâu. Cụ thể, bạn sẽ học:
- Hiểu khái niệm học chuyển giao và lợi ích của nó.
- Cách sử dụng mô hình tiền huấn luyện ResNet-18.
- Phương pháp phân loại hình ảnh kiến và ong.
- Tùy chỉnh lớp Fully Connected cuối cùng để phù hợp với bài toán cụ thể.
- So sánh hai chiến lược phổ biến:
- Tinh chỉnh toàn bộ mạng (Finetuning).
- Huấn luyện chỉ lớp cuối cùng.
- Thực hiện đánh giá kết quả để tìm ra phương pháp tối ưu.
Hãy cùng đi sâu vào chi tiết và khám phá sức mạnh của Transfer Learning!
Học chuyển giao trong PyTorch

Bộ dữ liệu có thể tải từ trang web chính thức của PyTorch.
Bài viết này được đăng tại [free tuts .net]
Thiết lập và tiền xử lý dữ liệu
Đầu tiên, chúng ta thiết lập các hàm chuyển đổi để chuẩn hóa và tạo mẫu cho tập dữ liệu:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
# Trung bình và độ lệch chuẩn để chuẩn hóa dữ liệu
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.25, 0.25, 0.25])
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean, std)
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean, std)
]),
}
# Đường dẫn dữ liệu
data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=0)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(class_names)
Hiển thị một số hình ảnh mẫu
def imshow(inp, title):
"""Hiển thị hình ảnh từ Tensor."""
inp = inp.numpy().transpose((1, 2, 0))
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
plt.title(title)
plt.show()
# Hiển thị một lô hình ảnh
inputs, classes = next(iter(dataloaders['train']))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])
Hàm huấn luyện mô hình trong PyTorch
Hàm train_model được sử dụng để huấn luyện và đánh giá mô hình.
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
since = time.time()
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
# Pha huấn luyện và đánh giá
for phase in ['train', 'val']:
if phase == 'train':
model.train() # Chế độ huấn luyện
else:
model.eval() # Chế độ đánh giá
running_loss = 0.0
running_corrects = 0
# Duyệt qua các lô dữ liệu
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device)
# Tiến hành
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
# Tính toán đạo hàm và cập nhật trọng số trong pha huấn luyện
if phase == 'train':
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Cập nhật các thông số
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
if phase == 'train':
scheduler.step()
epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase]
print('{} Loss: {:.4f} Acc: {:.4f}'.format(
phase, epoch_loss, epoch_acc))
# Lưu mô hình tốt nhất
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict())
print()
time_elapsed = time.time() - since
print('Huấn luyện hoàn thành trong {:.0f}m {:.0f}s'.format(
time_elapsed // 60, time_elapsed % 60))
print('Độ chính xác tốt nhất trên tập validation: {:4f}'.format(best_acc))
# Load trọng số mô hình tốt nhất
model.load_state_dict(best_model_wts)
return model
Phương pháp 1: Tinh chỉnh toàn bộ mạng
- Dùng mô hình tiền huấn luyện ResNet-18.
- Thay thế lớp Fully Connected cuối cùng.
- Huấn luyện lại toàn bộ mạng.
model = models.resnet18(pretrained=True) num_ftrs = model.fc.in_features model.fc = nn.Linear(num_ftrs, 2) model = model.to(device) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001) step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1) model = train_model(model, criterion, optimizer, step_lr_scheduler, num_epochs=25)
Phương pháp 2: Sử dụng mạng làm trích xuất đặc trưng
- Đóng băng tất cả trọng số trừ lớp cuối cùng.
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
param.requires_grad = False
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)
model_conv = model_conv.to(device)
criterion = nn.CrossEntropyLoss()
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler, num_epochs=25)

Các kiểu dữ liệu trong C ( int - float - double - char ...)
Thuật toán tìm ước chung lớn nhất trong C/C++
Cấu trúc lệnh switch case trong C++ (có bài tập thực hành)
ComboBox - ListBox trong lập trình C# winforms
Random trong Python: Tạo số random ngẫu nhiên
Lệnh cin và cout trong C++
Cách khai báo biến trong PHP, các loại biến thường gặp
Download và cài đặt Vertrigo Server
Thẻ li trong HTML
Thẻ article trong HTML5
Cấu trúc HTML5: Cách tạo template HTML5 đầu tiên
Cách dùng thẻ img trong HTML và các thuộc tính của img
Thẻ a trong HTML và các thuộc tính của thẻ a thường dùng