Preparation

Download Python

https://www.python.org/ftp/python/3.12.6/python-3.12.6-amd64.exe

After downloading, run the installer and make sure the "Add Python to PATH" option is checked.

Install PyTorch

pip install torch torchvision
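
To confirm that Python is on PATH and that PyTorch installed correctly, you can run:

python --version
python -c "import torch; print(torch.__version__)"

Both commands should print a version number rather than an error.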

Code

  • mnist_train.py
# mnist_train.py
# The simplest possible PyTorch neural network: handwritten digit recognition

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


def main():
    # 1️⃣ Preprocessing: convert the images to tensors and normalize them
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # 2️⃣ Download and load the MNIST dataset
    train_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=64, shuffle=False)

    # 3️⃣ Define a simple neural network
    class SimpleNN(nn.Module):
        def __init__(self):
            super(SimpleNN, self).__init__()
            self.flatten = nn.Flatten()  # flatten the 28x28 image into a 1-D vector
            self.fc1 = nn.Linear(28 * 28, 128)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(128, 10)  # 10 output classes (digits 0-9)

        def forward(self, x):
            x = self.flatten(x)
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            return x

    model = SimpleNN()

    # 4️⃣ Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # 5️⃣ Train the model
    num_epochs = 5
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            # forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # backward pass and optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

    # 6️⃣ Evaluate accuracy on the test set
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"✅ Accuracy on test set: {100 * correct / total:.2f}%")

    print("🎉 训练完成!模型已经学会识别手写数字。")


if __name__ == "__main__":
    main()

Run

python mnist_train.py

Results

Output from a run:

>python mnist_train.py
100.0%
100.0%
100.0%
100.0%
Epoch [1/5], Loss: 0.3926
Epoch [2/5], Loss: 0.2096
Epoch [3/5], Loss: 0.1493
Epoch [4/5], Loss: 0.1209
Epoch [5/5], Loss: 0.1010
✅ Accuracy on test set: 96.60%
🎉 Training complete! The model has learned to recognize handwritten digits.

Using the trained model

Notes

After the model is trained, we can save it and reuse it later.

Saving

After training finishes, add the following:

import os  # at the top of the file

os.makedirs("model", exist_ok=True)  # torch.save does not create directories for you
torch.save(model.state_dict(), "model/mnist_model.pth")
print("✅ Model saved to model/mnist_model.pth")

This writes the model parameters to model/mnist_model.pth under the current directory.

Run it again:

python mnist_train.py
Epoch [1/5], Loss: 0.3826
Epoch [2/5], Loss: 0.1920
Epoch [3/5], Loss: 0.1380
Epoch [4/5], Loss: 0.1108
Epoch [5/5], Loss: 0.0957
✅ Accuracy on test set: 97.10%
🎉 Training complete! The model has learned to recognize handwritten digits.
✅ Model saved to model/mnist_model.pth

Using the saved model

Create a new Python file, e.g. mnist_predict.py, with the following content:

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# Define the same model structure as used in training (the fc1/fc2 names must match the saved state_dict)
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# Hyperparameters (must match the training settings)
input_size = 28 * 28
hidden_size = 128
num_classes = 10

# Create the model and load the trained weights
model = NeuralNet(input_size, hidden_size, num_classes)
model.load_state_dict(torch.load("model/mnist_model.pth"))
model.eval()  # switch to inference mode

# Image preprocessing: grayscale, resize, to tensor, normalize
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # convert to grayscale
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # must match the normalization used during training
])

# Path of the image you want to recognize
img_path = "digit.png"
image = Image.open(img_path).convert('L')  # convert to grayscale
image = transform(image).view(-1, 28*28)

# Run the prediction
with torch.no_grad():
    output = model(image)
    _, predicted = torch.max(output.data, 1)
    print(f"🧠 预测结果是数字:{predicted.item()}")

Preparing the image

Image requirements:

  • the background should be plain white or black;

  • the digit itself should be clear;

  • any size is fine (the code resizes it to 28×28);

  • you can take a photo with your phone and crop it to a square.

Save it as digit.png in the same directory as mnist_predict.py.
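
If you don't have a suitable photo at hand, a low-effort sanity check (a small sketch; the script name is just an example) is to export an image straight from the MNIST test set, which is already MNIST-style, i.e. a white digit on a black background:

# export_digit.py (illustrative name): save one MNIST test image as digit.png
import torchvision

test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True)
img, label = test_dataset[0]   # without a transform, each item is (PIL image, label)
img.save("digit.png")
print(f"Saved digit.png, true label: {label}")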

Test

It recognized a 7 as a 0, so it is clearly not that accurate.

> python mnist_predict.py
🧠 Predicted digit: 0

Why did it fail?

This code really is the beginner version of MNIST: easy to understand, but limited in what it can do.

It uses only two fully connected layers and no convolutional layers (CNN), so it falls over easily when you feed it a real-world screenshot.

The MNIST training images are:

  • white digits on a black background

  • clean, centered, 28×28

while your screenshot of a digit may be:

  • black digits on a white background (inverted)

  • on a cluttered background, with the digit off-center

  • a different size (not 28×28)

  • grayscale, noisy, or with blurry edges
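
A quick way to see this mismatch for yourself (a small sketch; the output file name is just an example) is to dump what the model actually receives after preprocessing and compare it with an MNIST sample:

# Visualize what the model "sees" after preprocessing
import torchvision.transforms as transforms
from PIL import Image

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
])

img = transform(Image.open("digit.png").convert("RGB"))   # tensor of shape (1, 28, 28)
transforms.ToPILImage()(img).save("digit_as_model_sees_it.png")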

Upgraded version: CNN

Training script

  • mnist_train_cnn.py
# mnist_train_cnn.py
# A stronger CNN model for handwritten digit recognition

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


def main():
    # 1️⃣ Preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # 2️⃣ The MNIST dataset
    train_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=64, shuffle=False)

    # 3️⃣ Define the convolutional neural network
    class CNN(nn.Module):
        def __init__(self):
            super(CNN, self).__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
            self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            self.dropout = nn.Dropout(0.25)

            # Run a dummy input through the conv layers to infer the flattened size
            x = torch.zeros(1, 1, 28, 28)
            x = self._forward_features(x)
            flatten_size = x.view(1, -1).size(1)

            self.fc1 = nn.Linear(flatten_size, 128)
            self.fc2 = nn.Linear(128, 10)

        def _forward_features(self, x):
            x = torch.relu(self.conv1(x))
            x = self.pool(torch.relu(self.conv2(x)))
            x = self.dropout(x)
            return x

        def forward(self, x):
            x = self._forward_features(x)
            x = x.view(x.size(0), -1)
            x = torch.relu(self.fc1(x))
            x = self.fc2(x)
            return x

    model = CNN()

    # 4️⃣ Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # 5️⃣ Train the model
    num_epochs = 5
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

    # 6️⃣ Evaluate accuracy on the test set
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"✅ Accuracy on test set: {100 * correct / total:.2f}%")

    # Save the model (torch.save does not create the directory for us)
    os.makedirs("model", exist_ok=True)
    torch.save(model.state_dict(), "model/mnist_cnn.pth")
    print("🎉 Training complete! Model saved to model/mnist_cnn.pth")


if __name__ == "__main__":
    main()

Run the training

> python mnist_train_cnn.py
Epoch [1/5], Loss: 0.1438
Epoch [2/5], Loss: 0.0460
Epoch [3/5], Loss: 0.0300
Epoch [4/5], Loss: 0.0203
Epoch [5/5], Loss: 0.0157
✅ Accuracy on test set: 98.74%
🎉 Training complete! Model saved to model/mnist_cnn.pth

This feels noticeably slower to train than the first version.
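
If you have a CUDA-capable GPU, one common way to speed this up (a sketch, not something the original script does) is to move the model and every batch onto the GPU. Only the lines below inside main() of mnist_train_cnn.py would change; the evaluation loop would need the same .to(device) treatment for its batches.

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN().to(device)   # put the model on the GPU when one is available

    # ... criterion / optimizer unchanged ...

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)   # move each batch as well
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()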

Inference script (single-image prediction)

  • mnist_cnn_predict.py:
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms

# The model structure must match the one used in training
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)              # one pooling step: 28 → 14
        x = self.dropout(x)
        x = x.view(-1, 64 * 14 * 14)  # flatten to the matching size
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Load the model
model = CNN()
model.load_state_dict(torch.load("model/mnist_cnn.pth"))
model.eval()

# Image preprocessing
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Path to your screenshot
img_path = "digit.png"
image = Image.open(img_path).convert("RGB")
image = transform(image).unsqueeze(0)

# Inference
with torch.no_grad():
    outputs = model(image)
    _, predicted = torch.max(outputs, 1)
    print(f"🧠 模型预测结果: {predicted.item()}")

Run the test:

🧠 Predicted digit: 7

This time it got it right.

Later we can look at why a CNN works better from a theoretical point of view.

v3: real-time handwritten digit recognition

Idea

The CNN above can already recognize digits, but that isn't very fun.

Let's build a drawing board, write digits by hand with the mouse, and have the model predict them.

Install dependencies

We can build the drawing board with OpenCV.

First, install the dependencies:

pip install torch torchvision opencv-python pillow numpy

Implementation

Save the following as mnist_draw_predict.py:

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms

# ======================
# ✅ 1. Define the CNN model structure (must match training)
# ======================
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        # note: 64 * 14 * 14 (because there is only one pooling step)
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = self.dropout(x)
        x = x.view(-1, 64 * 14 * 14)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# ======================
# ✅ 2. Load the model
# ======================
model = CNN()
model.load_state_dict(torch.load("model/mnist_cnn.pth", map_location="cpu"))
model.eval()

# ======================
# ✅ 3. Image preprocessing function
# ======================
def preprocess_image_for_mnist(image):
    # convert to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # invert: black digit on white -> white digit on black (MNIST style)
    gray = cv2.bitwise_not(gray)
    # convert to a PIL image
    pil = Image.fromarray(gray)
    # handle both old and new Pillow versions
    try:
        resample = Image.Resampling.LANCZOS
    except AttributeError:
        resample = Image.ANTIALIAS
    pil = pil.resize((28, 28), resample)
    # convert to a tensor
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    tensor = transform(pil).unsqueeze(0)
    return tensor


# ======================
# ✅ 4. OpenCV drawing window
# ======================
canvas = np.ones((280, 280, 3), dtype=np.uint8) * 255
drawing = False
last_point = None

def draw(event, x, y, flags, param):
    global drawing, last_point
    if event == cv2.EVENT_LBUTTONDOWN:
        drawing = True
        last_point = (x, y)
    elif event == cv2.EVENT_MOUSEMOVE and drawing:
        cv2.line(canvas, last_point, (x, y), (0, 0, 0), 12)
        last_point = (x, y)
    elif event == cv2.EVENT_LBUTTONUP:
        drawing = False
        last_point = None

cv2.namedWindow("MNIST Draw Board")
cv2.setMouseCallback("MNIST Draw Board", draw)

print("🎨 Draw a digit with the mouse.")
print("✅ Press 's' to save & predict.")
print("🧹 Press 'c' to clear.")
print("❌ Press 'q' to quit.")

# ======================
# ✅ 5. Main loop
# ======================
while True:
    cv2.imshow("MNIST Draw Board", canvas)
    key = cv2.waitKey(1) & 0xFF

    if key == ord('q'):
        break

    elif key == ord('c'):
        canvas[:] = 255  # clear the canvas

    elif key == ord('s'):
        # copy the current canvas
        img_for_pred = canvas.copy()

        # preprocess it into a model input
        tensor = preprocess_image_for_mnist(img_for_pred)

        # run the prediction
        with torch.no_grad():
            outputs = model(tensor)
            _, predicted = torch.max(outputs, 1)
            result = predicted.item()

        print(f"🧠 Predicted digit: {result}")

        # overlay the result on the current drawing
        display = img_for_pred.copy()
        cv2.putText(display, f"RES: {result}", (10, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 0, 255), 3)

        # show the prediction result
        cv2.imshow("MNIST Draw Board", display)

        # wait 2 seconds so you can read the result
        key2 = cv2.waitKey(2000) & 0xFF
        if key2 == ord('q'):
            break
        elif key2 == ord('c'):
            canvas[:] = 255  # clear
        else:
            # after 2 seconds, go back to the original canvas
            cv2.imshow("MNIST Draw Board", canvas)


cv2.destroyAllWindows()

Test

python mnist_draw_predict.py

In practice, though, the prediction quality turns out to be mediocre.

Why isn't it more accurate?

The real reason: the distribution of your handwritten images is not the same as the distribution of the MNIST training data.

🧠 Why does the model misread your handwriting?

Let's break down the differences:

1️⃣ The background is different

  • MNIST training images are a pure black background with white digits
  • your canvas is a white background with black strokes
  • Normalize((0.5,), (0.5,)) alone does not fix this: the colour direction is reversed.

✅ Fix: invert the image colours before predicting:

from PIL import ImageOps

pil = ImageOps.invert(pil)

Put this inside your preprocess_image_for_mnist() function. (The drawing-board script above already does the same inversion with cv2.bitwise_not(), so make sure the colours are flipped exactly once.)


2️⃣ The digit is off-center or too close to the edge

MNIST digits are centered in the 28x28 image, while the digit you draw may sit too high, too far left, or be too large. A CNN is still somewhat sensitive to spatial position, so:

  • shift the digit a little and the feature-map activations change;
  • the model may decide it belongs to a different class.

✅ Fix:

  • add an automatic centering step to the preprocessing:

    • binarize the handwritten image;
    • find the digit's bounding box;
    • crop it out and paste it back, centered, into a 28x28 image.

A sketch of such an enhanced preprocessing step is shown below.
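
Here is a minimal sketch of that auto-centering step (the helper name center_digit is made up; it assumes the OpenCV/NumPy stack already used above and a black-background, white-digit grayscale image, i.e. the array right after cv2.bitwise_not()):

import cv2
import numpy as np

def center_digit(gray):
    """Crop the digit's bounding box and re-center it on a 28x28 canvas (sketch)."""
    # Binarize, then find the bounding box of the non-zero (digit) pixels
    _, thresh = cv2.threshold(gray, 30, 255, cv2.THRESH_BINARY)
    coords = cv2.findNonZero(thresh)
    if coords is None:                      # empty canvas: nothing to center
        return cv2.resize(gray, (28, 28))
    x, y, w, h = cv2.boundingRect(coords)
    digit = gray[y:y + h, x:x + w]

    # Scale the digit so its longer side is about 20 px, roughly like MNIST
    scale = 20.0 / max(w, h)
    digit = cv2.resize(digit, (max(1, int(w * scale)), max(1, int(h * scale))))

    # Paste it into the middle of a black 28x28 canvas
    out = np.zeros((28, 28), dtype=np.uint8)
    dh, dw = digit.shape
    top, left = (28 - dh) // 2, (28 - dw) // 2
    out[top:top + dh, left:left + dw] = digit
    return out

In preprocess_image_for_mnist() you could call gray = center_digit(gray) right after the inversion and then drop the later 28×28 resize, since the output is already 28×28.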


3️⃣ Stroke width and style differ

  • MNIST strokes come from scanned pen writing and are smooth and fairly thick;
  • lines drawn with a mouse may be:

    • too thin;
    • visibly jagged;
    • broken.

✅ Fixes (a small helper is sketched after this list):

  • draw with a thicker line (e.g. cv2.line(..., thickness=20));
  • or blur the image a little during preprocessing (cv2.GaussianBlur);
  • or add style/noise augmentation at training time to improve generalization.
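
For the first two points, a small helper along these lines could be dropped into the preprocessing (a sketch; the name thicken_and_smooth is made up, and it assumes the same black-background, white-digit grayscale input as center_digit above):

import cv2
import numpy as np

def thicken_and_smooth(gray):
    """Soften jagged mouse strokes and make thin lines a bit bolder (sketch)."""
    gray = cv2.GaussianBlur(gray, (5, 5), 0)                          # smooth the jagged edges
    gray = cv2.dilate(gray, np.ones((3, 3), np.uint8), iterations=1)  # thicken the stroke
    return gray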

4️⃣ The model was trained only on MNIST

MNIST is a classic dataset, but it is old and very clean. Your own handwriting is effectively a new distribution; the model has never seen this style, so it misclassifies.

✅ The more fundamental fix (a fine-tuning sketch follows this list):

  • collect your own handwriting data (for example, 100 digit drawings);
  • fine-tune the model on that data;
  • the model immediately becomes much more sensitive to your writing style.
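
Below is a minimal fine-tuning sketch under some assumptions: the file name finetune_mnist_cnn.py and the folder my_digits/ are just examples, the samples are stored as my_digits/0 through my_digits/9 (the folder name is the label), and they are already MNIST-style, i.e. a white digit on a black background.

# finetune_mnist_cnn.py (illustrative name)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets

class CNN(nn.Module):  # same structure as the trained model above
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = self.dropout(x)
        x = x.view(-1, 64 * 14 * 14)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# ImageFolder uses the sub-folder name as the class label (folders "0".."9" map to labels 0..9)
dataset = datasets.ImageFolder("my_digits", transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

model = CNN()
model.load_state_dict(torch.load("model/mnist_cnn.pth", map_location="cpu"))

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)  # small learning rate: only nudge the weights

model.train()
for epoch in range(3):  # a few epochs is usually enough for a small personal dataset
    for images, labels in loader:
        loss = criterion(model(images), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

torch.save(model.state_dict(), "model/mnist_cnn_finetuned.pth")
print("✅ Fine-tuned model saved to model/mnist_cnn_finetuned.pth")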

Optimized version

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms

# ======================
# ✅ 1. Define the CNN model structure (must match training)
# ======================
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = self.dropout(x)
        x = x.view(-1, 64 * 14 * 14)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# ======================
# ✅ 2. Load the model
# ======================
model = CNN()
model.load_state_dict(torch.load("model/mnist_cnn.pth", map_location="cpu"))
model.eval()

# ======================
# ✅ 3. Image preprocessing function
# ======================
def preprocess_image_for_mnist(image):
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = cv2.bitwise_not(gray)
    pil = Image.fromarray(gray)
    try:
        resample = Image.Resampling.LANCZOS
    except AttributeError:
        resample = Image.ANTIALIAS
    pil = pil.resize((28, 28), resample)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    tensor = transform(pil).unsqueeze(0)
    return tensor

# ======================
# ✅ 4. Drawing window setup
# ======================
canvas = np.ones((280, 280, 3), dtype=np.uint8) * 255
drawing = False
last_point = None
prediction_text = None  # store the latest prediction so it can be displayed

def draw(event, x, y, flags, param):
    global drawing, last_point
    if event == cv2.EVENT_LBUTTONDOWN:
        drawing = True
        last_point = (x, y)
    elif event == cv2.EVENT_MOUSEMOVE and drawing:
        cv2.line(canvas, last_point, (x, y), (0, 0, 0), 12)
        last_point = (x, y)
    elif event == cv2.EVENT_LBUTTONUP:
        drawing = False
        last_point = None

cv2.namedWindow("🖌 MNIST Draw Board", cv2.WINDOW_NORMAL)
cv2.setMouseCallback("🖌 MNIST Draw Board", draw)

print("🎨 用鼠标左键画数字")
print("✅ 按 's' 识别数字")
print("🧹 按 'c' 清空画布")
print("❌ 按 'q' 退出")

# ======================
# ✅ 5. Main loop
# ======================
while True:
    # if there is a prediction, overlay it on the canvas
    display = canvas.copy()
    if prediction_text is not None:
        cv2.putText(display, f"RES: {prediction_text}", (10, 60),
                    cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 3)

    cv2.imshow("🖌 MNIST Draw Board", display)
    key = cv2.waitKey(10) & 0xFF

    if key == ord('q'):
        break

    elif key == ord('c'):
        canvas[:] = 255
        prediction_text = None

    elif key == ord('s'):
        img_for_pred = canvas.copy()
        tensor = preprocess_image_for_mnist(img_for_pred)
        with torch.no_grad():
            outputs = model(tensor)
            _, predicted = torch.max(outputs, 1)
            prediction_text = str(predicted.item())
            print(f"🧠 Predicted digit: {prediction_text}")

cv2.destroyAllWindows()

The results are decent.