Thông báo: Download 4 khóa học Python từ cơ bản đến nâng cao tại đây.
Triển khai Decision Tree bằng Python
Cây quyết định (Decision Tree) là một trong những thuật toán phổ biến và dễ hiểu nhất trong học máy. Với khả năng xử lý cả bài toán phân loại và hồi quy, Decision Tree thường được xem là công cụ mạnh mẽ và dễ áp dụng trong các bài toán thực tế.

Bài hướng dẫn này sẽ giúp bạn tự triển khai một thuật toán Decision Tree hoàn chỉnh chỉ sử dụng các module Python cơ bản và thư viện numpy. Qua đó, bạn sẽ hiểu rõ cách thức hoạt động, từ lý thuyết đến mã nguồn, mà không cần phụ thuộc vào các thư viện phức tạp.
.png)
Decision Tree sử dụng các module Python
Dưới đây là mã nguồn kèm giải thích, giúp bạn nắm rõ từng bước triển khai.
import numpy as np
from collections import Counter
# Hàm tính toán Entropy để đánh giá mức độ hỗn loạn trong dữ liệu
def entropy(y):
hist = np.bincount(y) # Tính tần suất xuất hiện của từng nhãn
ps = hist / len(y) # Phân phối xác suất
return -np.sum([p * np.log2(p) for p in ps if p > 0]) # Công thức tính entropy
# Lớp Node đại diện cho từng nút trong cây
class Node:
def __init__(self, feature=None, threshold=None, left=None, right=None, *, value=None):
self.feature = feature # Chỉ số đặc trưng được chọn để chia
self.threshold = threshold # Ngưỡng chia của đặc trưng
self.left = left # Nhánh trái
self.right = right # Nhánh phải
self.value = value # Giá trị tại nút lá (nếu là lá)
# Xác định nút hiện tại có phải nút lá không
def is_leaf_node(self):
return self.value is not None
# Lớp chính để xây dựng và huấn luyện Decision Tree
class DecisionTree:
def __init__(self, min_samples_split=2, max_depth=100, n_feats=None):
self.min_samples_split = min_samples_split # Số lượng mẫu tối thiểu để tiếp tục chia
self.max_depth = max_depth # Độ sâu tối đa của cây
self.n_feats = n_feats # Số lượng đặc trưng sử dụng khi xây dựng
self.root = None # Gốc của cây
# Huấn luyện cây quyết định
def fit(self, X, y):
self.n_feats = X.shape[1] if not self.n_feats else min(self.n_feats, X.shape[1])
self.root = self._grow_tree(X, y)
# Dự đoán nhãn cho các mẫu
def predict(self, X):
return np.array([self._traverse_tree(x, self.root) for x in X])
# Xây dựng cây bằng cách chọn chia cắt tối ưu
def _grow_tree(self, X, y, depth=0):
n_samples, n_features = X.shape
n_labels = len(np.unique(y))
# Điều kiện dừng: độ sâu, số lượng nhãn hoặc số lượng mẫu tối thiểu
if (depth >= self.max_depth or n_labels == 1 or n_samples < self.min_samples_split):
leaf_value = self._most_common_label(y)
return Node(value=leaf_value)
# Chọn ngẫu nhiên một số lượng đặc trưng
feat_idxs = np.random.choice(n_features, self.n_feats, replace=False)
# Tìm tiêu chí chia cắt tốt nhất
best_feat, best_thresh = self._best_criteria(X, y, feat_idxs)
# Phân nhánh cây
left_idxs, right_idxs = self._split(X[:, best_feat], best_thresh)
left = self._grow_tree(X[left_idxs, :], y[left_idxs], depth + 1)
right = self._grow_tree(X[right_idxs, :], y[right_idxs], depth + 1)
return Node(best_feat, best_thresh, left, right)
# Tìm tiêu chí chia cắt (feature và threshold) tối ưu
def _best_criteria(self, X, y, feat_idxs):
best_gain = -1 # Lợi ích thông tin tốt nhất (initial)
split_idx, split_thresh = None, None
for feat_idx in feat_idxs:
X_column = X[:, feat_idx]
thresholds = np.unique(X_column)
for threshold in thresholds:
gain = self._information_gain(y, X_column, threshold)
if gain > best_gain:
best_gain = gain
split_idx = feat_idx
split_thresh = threshold
return split_idx, split_thresh
# Tính lợi ích thông tin dựa trên Entropy trước và sau chia
def _information_gain(self, y, X_column, split_thresh):
parent_entropy = entropy(y) # Entropy ban đầu
# Chia dữ liệu thành 2 nhóm
left_idxs, right_idxs = self._split(X_column, split_thresh)
if len(left_idxs) == 0 or len(right_idxs) == 0:
return 0
# Entropy sau khi chia
n = len(y)
n_l, n_r = len(left_idxs), len(right_idxs)
e_l, e_r = entropy(y[left_idxs]), entropy(y[right_idxs])
child_entropy = (n_l / n) * e_l + (n_r / n) * e_r
return parent_entropy - child_entropy
# Chia dữ liệu dựa trên ngưỡng threshold
def _split(self, X_column, split_thresh):
left_idxs = np.argwhere(X_column <= split_thresh).flatten()
right_idxs = np.argwhere(X_column > split_thresh).flatten()
return left_idxs, right_idxs
# Đi qua từng nút của cây để dự đoán
def _traverse_tree(self, x, node):
if node.is_leaf_node():
return node.value
if x[node.feature] <= node.threshold:
return self._traverse_tree(x, node.left)
return self._traverse_tree(x, node.right)
# Xác định nhãn phổ biến nhất ở nút lá
def _most_common_label(self, y):
counter = Counter(y)
most_common = counter.most_common(1)[0][0]
return most_common
Giải thích mã
Bài viết này được đăng tại [free tuts .net]
Entropy: Đo độ hỗn loạn của dữ liệu. Giá trị càng nhỏ thì dữ liệu càng thuần nhất.
Node: Đại diện cho một nút trong cây quyết định. Có thể là nút chia hoặc nút lá.
DecisionTree:
fit: Xây dựng cây bằng cách lặp qua từng cấp độ, chọn chia cắt tối ưu.predict: Duyệt qua cây từ gốc đến lá để trả về nhãn cho từng mẫu._best_criteria: Chọn đặc trưng và ngưỡng chia tốt nhất dựa trên Information Gain._grow_tree: Quy trình đệ quy xây dựng cây._split: Chia dữ liệu thành 2 nhóm theo điều kiện.
Kết bài
Decision Tree là một thuật toán dễ hiểu và mạnh mẽ trong học máy. Việc triển khai từ đầu giúp bạn nắm rõ cách cây hoạt động, từ việc chia dữ liệu đến dự đoán. Đây cũng là bước nền để học các mô hình nâng cao hơn như Random Forest hay Gradient Boosted Trees.

Các kiểu dữ liệu trong C ( int - float - double - char ...)
Thuật toán tìm ước chung lớn nhất trong C/C++
Cấu trúc lệnh switch case trong C++ (có bài tập thực hành)
ComboBox - ListBox trong lập trình C# winforms
Random trong Python: Tạo số random ngẫu nhiên
Lệnh cin và cout trong C++
Cách khai báo biến trong PHP, các loại biến thường gặp
Download và cài đặt Vertrigo Server
Thẻ li trong HTML
Thẻ article trong HTML5
Cấu trúc HTML5: Cách tạo template HTML5 đầu tiên
Cách dùng thẻ img trong HTML và các thuộc tính của img
Thẻ a trong HTML và các thuộc tính của thẻ a thường dùng