Thông báo: Download 4 khóa học Python từ cơ bản đến nâng cao tại đây.
Cách lưu trữ và tải lại Models trong PyTorch
Trong bài viết này, bạn sẽ học cách lưu trữ và tải lại Models trong PyTorch. Hướng dẫn này sẽ giới thiệu các hàm quan trọng mà bạn cần nhớ và các phương pháp lưu trữ Models khác nhau. Ngoài ra, bài viết cũng sẽ đề cập đến những điều cần lưu ý khi bạn làm việc với GPU.
Những hàm cần biết:
torch.save()
: Lưu dữ liệu (Models , tensor hoặc dictionary).torch.load()
: Tải lại dữ liệu đã lưu.load_state_dict()
: Tải lại trạng thái (state dictionary) của Models
Hai phương pháp lưu trữ phổ biến
1. Lưu toàn bộ Models (Cách "lười")
- Ưu điểm: Tiện lợi, bạn không cần tạo lại Models trước khi tải.
- Nhược điểm: Bị phụ thuộc vào định nghĩa lớp Models khi khôi phục.
2. Lưu chỉ state_dict
(Khuyến nghị)
- Ưu điểm: Linh hoạt, dễ dàng áp dụng cho các cấu hình khác nhau.
- Nhược điểm: Phải tạo lại Models trước khi tải.
Cách lưu trữ và tải lại Models trong PyTorch
1. Định nghĩa Models :
Bài viết này được đăng tại [free tuts .net]
import torch import torch.nn as nn class Model(nn.Module): def __init__(self, n_input_features): super(Model, self).__init__() self.linear = nn.Linear(n_input_features, 1) def forward(self, x): y_pred = torch.sigmoid(self.linear(x)) return y_pred model = Model(n_input_features=6) # Huấn luyện Models ...
Phương pháp 1: Lưu và tải toàn bộ Models
Lưu Models :
FILE = "model.pth" torch.save(model, FILE) # Lưu toàn bộ Models vào file
Tải lại Models :
loaded_model = torch.load(FILE) # Tải lại từ file loaded_model.eval() # Chuyển sang chế độ đánh giá (evaluation mode)
Lưu ý: Bạn cần có định nghĩa lớp Model
giống như trước.
Phương pháp 2: Lưu và tải state_dict
của Models (Khuyến nghị)
Lưu state_dict
:
FILE = "model.pth" torch.save(model.state_dict(), FILE) # Lưu state_dict vào file
Tải lại state_dict
:
loaded_model = Model(n_input_features=6) # Tạo lại Models loaded_model.load_state_dict(torch.load(FILE)) # Nạp state_dict từ file loaded_model.eval() # Chuyển sang chế độ đánh giá
Lưu và tải "checkpoint" trong PyTorch
"Checkpoint" là một snapshot (bản sao) lưu lại trạng thái Models cùng với trạng thái tối ưu hóa, giúp bạn tiếp tục huấn luyện sau đó.
Lưu checkpoint:
learning_rate = 0.01 optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) checkpoint = { "epoch": 90, "model_state": model.state_dict(), "optim_state": optimizer.state_dict() } FILE = "checkpoint.pth" torch.save(checkpoint, FILE)
Tải checkpoint:
checkpoint = torch.load(FILE) model = Model(n_input_features=6) optimizer = torch.optim.SGD(model.parameters(), lr=0) # Khởi tạo tối ưu hóa model.load_state_dict(checkpoint['model_state']) optimizer.load_state_dict(checkpoint['optim_state']) epoch = checkpoint['epoch'] model.eval() # Hoặc model.train() để tiếp tục huấn luyện
Lưu ý khi sử dụng GPU và CPU trong PyTorch
Lưu trên GPU, tải trên CPU
device = torch.device("cuda") model.to(device) torch.save(model.state_dict(), "model.pth") device = torch.device('cpu') model = Model(*args, **kwargs) model.load_state_dict(torch.load("model.pth", map_location=device))
Lưu trên GPU, tải trên GPU
device = torch.device("cuda") model.to(device) torch.save(model.state_dict(), "model.pth") model = Model(*args, **kwargs) model.load_state_dict(torch.load("model.pth")) model.to(device)
Lưu trên CPU, tải trên GPU
torch.save(model.state_dict(), "model.pth") device = torch.device("cuda") model = Model(*args, **kwargs) model.load_state_dict(torch.load("model.pth", map_location="cuda:0")) model.to(device)
Quan trọng
Khi chuyển Models sang chế độ eval()
, các tầng như dropout hoặc batch normalization sẽ hoạt động đúng chế độ suy luận. Để tiếp tục huấn luyện, hãy chuyển sang train()
.
Kết bài
Trong bài viết này, chúng ta đã tìm hiểu cách lưu trữ và tải lại Models trong PyTorch với hai phương pháp: lưu toàn bộ Models và chỉ lưu state_dict
. Việc áp dụng checkpoint sẽ giúp quá trình huấn luyện không bị gián đoạn. Nếu bạn làm việc với GPU, cần cẩn trọng khi xử lý chuyển đổi giữa GPU và CPU để đảm bảo tính chính xác.
Chúc bạn thành công với các Models PyTorch của mình!