Mục tiêu bài học
- Làm quen với cách tổ chức mô hình trong PyTorch
- Biết cách xử lý dữ liệu huấn luyện (dataloader)
- Xây dựng một mạng nơ-ron feedforward đơn giản
- Huấn luyện mô hình trên dữ liệu giả định
- Biết cách tính loss, backprop, và cập nhật trọng số
1. Tại sao dùng PyTorch?
Dù có thể tự viết bằng NumPy, nhưng PyTorch giúp:
- Tự động tính đạo hàm với
autograd
- Có module mô hình hóa (
nn.Module
) rất trực quan - Dễ debug từng bước như code Python thuần
- Hỗ trợ GPU, mở rộng sang CNN, RNN dễ dàng sau này
2. Cài đặt môi trường
Cài PyTorch:
pip install torch torchvision
3. Tạo dữ liệu huấn luyện
Ở đây ta dùng bài toán XOR tương tự như bài 2:
import torch
import torch.nn as nn
import torch.optim as optim
# Dữ liệu XOR
X = torch.tensor([[0., 0.],
[0., 1.],
[1., 0.],
[1., 1.]])
Y = torch.tensor([[0.],
[1.],
[1.],
[0.]])
4. Xây mạng nơ-ron với nn.Module
class XORNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(2, 4) # input 2 -> hidden 4
self.fc2 = nn.Linear(4, 1) # hidden 4 -> output 1
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.sigmoid(self.fc1(x))
x = self.sigmoid(self.fc2(x))
return x
model = XORNet()
5. Huấn luyện mô hình
# Hàm mất mát và optimizer
criterion = nn.BCELoss() # Binary Cross Entropy
optimizer = optim.SGD(model.parameters(), lr=0.1)
# Vòng lặp huấn luyện
for epoch in range(10000):
y_pred = model(X)
loss = criterion(y_pred, Y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch % 1000 == 0:
print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
6. Kiểm tra kết quả
with torch.no_grad():
output = model(X)
predicted = (output > 0.5).float()
print("Dự đoán:n", predicted)
print("Thực tế:n", Y)
Kết quả mong đợi:
Dự đoán:
tensor([[0.],
[1.],
[1.],
[0.]])
7. Tổng kết
- Bạn đã biết cách định nghĩa mô hình với PyTorch
- Hiểu cách thiết lập loss, backprop và optimizer
- Biết dùng
.forward()
,.backward()
và.step()
- Từ đây có thể mở rộng mạng sang bài toán phân loại ảnh, chuỗi…
bài học tiếp theo
- Bài 4: Giới thiệu về CNN – Mạng nơ-ron tích chập và bài toán phân loại ảnh
- Bài 4 (thực hành): Dùng PyTorch huấn luyện CNN trên MNIST
- Giải thích trực quan về Gradient Descent và vai trò Learning Rate
Sign up