Thông báo: Download 4 khóa học Python từ cơ bản đến nâng cao tại đây.
Hướng dẫn cơ bản mạng Nơ-ron Tích Chập (CNN) trong PyTorch
Trong bài hướng dẫn này, mình sẽ cùng tìm hiểu cách xây dựng mạng nơ-ron tích chập (Convolutional Neural Network - CNN), một mô hình mạnh mẽ thường được sử dụng để xử lý dữ liệu hình ảnh. Với bộ dữ liệu nổi tiếng CIFAR-10, gồm 60,000 hình ảnh thuộc 10 lớp khác nhau, mình sẽ từng bước triển khai một mạng CNN từ cơ bản đến hoàn thiện trong PyTorch.
Hướng dẫn này sẽ cung cấp cái nhìn toàn diện về cách hoạt động của các thành phần quan trọng trong mạng CNN, như bộ lọc tích chập (Convolutional Filter), kỹ thuật pooling (Max Pooling), và cách xác định kích thước các tầng một cách phù hợp. Từ đó, bạn sẽ học cách áp dụng kiến thức này để xây dựng kiến trúc CNN, huấn luyện mô hình và kiểm tra hiệu quả phân loại hình ảnh. Toàn bộ mã nguồn từ bài viết đều được chia sẻ đầy đủ để bạn có thể thực hành và tự khám phá thêm.
Mục lục
CNN trong PyTorch
Lưu ý: Kiến trúc CNN trong hướng dẫn này dựa trên mô hình LeNet-5 nổi tiếng. Bạn có thể tìm hiểu thêm về kiến trúc này tại đây.
1 2 3 4 5 6 7 8 | # Import các thư viện cần thiết import torch import torch.nn as nn import torch.nn.functional as F import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt import numpy as np |
Cấu hình thiết bị và các tham số
1 2 3 4 5 6 7 | # Xác định thiết bị (CPU hoặc GPU) device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) # Các tham số quan trọng num_epochs = 5 batch_size = 4 learning_rate = 0.001 |
Tải và xử lý dữ liệu
Bộ dữ liệu CIFAR-10 gồm 60,000 hình ảnh màu kích thước 32x32, thuộc 10 lớp, với mỗi lớp chứa 6,000 hình.
Hình ảnh sẽ được chuẩn hóa về dải giá trị từ [-1, 1].
1 2 3 4 5 6 7 8 9 10 | transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(( 0.5 , 0.5 , 0.5 ), ( 0.5 , 0.5 , 0.5 ))]) # Tải dữ liệu huấn luyện và kiểm tra train_dataset = torchvision.datasets.CIFAR10(root = './data' , train = True , download = True , transform = transform) test_dataset = torchvision.datasets.CIFAR10(root = './data' , train = False , download = True , transform = transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True ) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = batch_size, shuffle = False ) |
Các lớp trong bộ dữ liệu CIFAR-10:
Bài viết này được đăng tại [free tuts .net]
1 | classes = ( 'plane' , 'car' , 'bird' , 'cat' , 'deer' , 'dog' , 'frog' , 'horse' , 'ship' , 'truck' ) |
Hiển thị một số hình ảnh mẫu:
1 2 3 4 5 6 7 8 9 10 11 | def imshow(img): img = img / 2 + 0.5 # Phục hồi ảnh về dải giá trị gốc npimg = img.numpy() plt.imshow(np.transpose(npimg, ( 1 , 2 , 0 ))) plt.show() # Lấy và hiển thị ảnh ngẫu nhiên từ tập huấn luyện dataiter = iter (train_loader) images, labels = dataiter. next () imshow(torchvision.utils.make_grid(images)) |
Xây dựng kiến trúc CNN
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | class ConvNet(nn.Module): def __init__( self ): super (ConvNet, self ).__init__() self .conv1 = nn.Conv2d( 3 , 6 , 5 ) # Lớp tích chập 1 self .pool = nn.MaxPool2d( 2 , 2 ) # Lớp pooling self .conv2 = nn.Conv2d( 6 , 16 , 5 ) # Lớp tích chập 2 self .fc1 = nn.Linear( 16 * 5 * 5 , 120 ) # Lớp fully connected 1 self .fc2 = nn.Linear( 120 , 84 ) # Lớp fully connected 2 self .fc3 = nn.Linear( 84 , 10 ) # Lớp output def forward( self , x): x = self .pool(F.relu( self .conv1(x))) # -> Kết quả sau lớp tích chập 1 x = self .pool(F.relu( self .conv2(x))) # -> Kết quả sau lớp tích chập 2 x = x.view( - 1 , 16 * 5 * 5 ) # Flatten (chuẩn bị cho fully connected) x = F.relu( self .fc1(x)) x = F.relu( self .fc2(x)) x = self .fc3(x) return x |
Định nghĩa hàm mất mát và bộ tối ưu
1 2 3 | model = ConvNet().to(device) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate) |
Huấn luyện mô hình
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | n_total_steps = len (train_loader) for epoch in range (num_epochs): for i, (images, labels) in enumerate (train_loader): images = images.to(device) labels = labels.to(device) # Truyền xuôi outputs = model(images) loss = criterion(outputs, labels) # Lan truyền ngược và tối ưu hóa optimizer.zero_grad() loss.backward() optimizer.step() # Hiển thị thông tin sau mỗi 2,000 bước if (i + 1 ) % 2000 = = 0 : print (f 'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{n_total_steps}], Loss: {loss.item():.4f}' ) |
Kiểm tra mô hình và đánh giá độ chính xác
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 | with torch.no_grad(): n_correct = 0 n_samples = 0 n_class_correct = [ 0 for i in range ( 10 )] n_class_samples = [ 0 for i in range ( 10 )] for images, labels in test_loader: images = images.to(device) labels = labels.to(device) outputs = model(images) _, predicted = torch. max (outputs, 1 ) n_samples + = labels.size( 0 ) n_correct + = (predicted = = labels). sum ().item() for i in range (batch_size): label = labels[i] pred = predicted[i] if label = = pred: n_class_correct[label] + = 1 n_class_samples[label] + = 1 acc = 100.0 * n_correct / n_samples print (f 'Dộ chính xác của mạng: {acc:.2f} %' ) for i in range ( 10 ): acc = 100.0 * n_class_correct[i] / n_class_samples[i] print (f 'Dộ chính xác của lớp {classes[i]}: {acc:.2f} %' ) |
Kết bài
Qua bài hướng dẫn này, mình đã triển khai thành công một mạng nơ-ron tích chập (CNN) trong PyTorch để phân loại hình ảnh trên bộ dữ liệu CIFAR-10. Từ việc tìm hiểu cấu trúc CNN, bộ lọc tích chập, kỹ thuật pooling, đến việc huấn luyện và đánh giá mô hình, bạn đã nắm được các bước cơ bản để xây dựng và áp dụng CNN vào bài toán thực tế.
Cách tiếp cận này không chỉ giúp bạn hiểu rõ hơn về nguyên lý hoạt động của mạng CNN, mà còn mở ra cơ hội để bạn thử nghiệm các cấu trúc phức tạp hơn hoặc áp dụng vào các bộ dữ liệu khác. Hãy tiếp tục khám phá, thực hành và phát triển những ý tưởng mới để tận dụng tối đa tiềm năng của học sâu trong xử lý hình ảnh. Chúc bạn thành công!
Cùng chuyên mục:
Python Regex
Python Concurrency
Tkinter Tutorial
PyQt Tutorial
Python căn bản
Python nâng cao
Django
Matplotlib
Pandas
Numpy
Python function