chat
如果我想实现连续多个数字的图片识别,要如何实现?
解决思路是什么呢?
我们可以在不修改单数字 CNN 识别模型的前提下实现这个功能。
要支持**“一张图片中包含多个数字”**(比如 123),我们可以不改模型结构,
而是在**推理阶段**(inference)做「多数字分割 + 批量识别」。
也就是说,CNN 依然是识别“单个数字”,我们让它对每个数字小块依次预测。
这样你不需要重新训练新模型,也能直接识别「连续多个数字」。
完整实现:支持多数字识别(改进版)
下面这份是你代码的直接改进版,核心改动只有两个地方:
1. 识别前,用 OpenCV 查找轮廓,自动切割出每个数字;
2. 逐个送入 CNN 模型识别后,拼接成结果字符串。
import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
# ======================
# ✅ 1. 定义 CNN 模型结构
# ======================
class CNN(nn.Module):
    """Two-convolution classifier that emits 10 raw class scores per image.

    ``fc1`` expects ``64 * 14 * 14`` features, which is what a single 2x2
    pooling step produces from a one-channel 28x28 input (both convolutions
    use ``padding=1`` with a 3x3 kernel, so they keep the spatial size).
    Layer attribute names are unchanged so existing ``state_dict``
    checkpoints keep loading.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return class scores for ``x``.

        Args:
            x: Image tensor whose last three dimensions are ``(1, 28, 28)``.

        Returns:
            Tensor with 10 scores per input image.

        Raises:
            RuntimeError: If the pooled feature map is not ``(64, 14, 14)``,
                i.e. the input was not 28x28.
        """
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.dropout(self.pool(x))
        # ``view(-1, ...)`` alone would silently fold extra spatial positions
        # into the batch dimension for inputs of the wrong size (a 56x56
        # image would come back as four predictions), so check first.
        if tuple(x.shape[-3:]) != (64, 14, 14):
            raise RuntimeError(
                "CNN expects 1x28x28 inputs; got a pooled feature map of "
                f"shape {tuple(x.shape[-3:])} instead of (64, 14, 14)"
            )
        x = x.view(-1, 64 * 14 * 14)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
# ======================
# ✅ 2. 加载模型
# ======================
# Build the network, restore the trained weights on the CPU, and switch to
# inference mode so dropout is disabled during prediction.
# NOTE(review): torch.load unpickles the file - only load checkpoints that
# come from a trusted source.
model = CNN()
_state_dict = torch.load("model/mnist_cnn.pth", map_location="cpu")
model.load_state_dict(_state_dict)
model.eval()
# ======================
# ✅ 3. 图像预处理
# ======================
def preprocess_image_for_mnist(image):
    """Convert a BGR digit image into a normalized tensor for the CNN.

    The image is converted to grayscale, inverted (so dark strokes on a
    light background become light strokes on a dark background), resized to
    28x28 with Lanczos resampling, and normalized with mean 0.5 / std 0.5.

    Args:
        image: BGR image array as accepted by ``cv2.cvtColor``.

    Returns:
        The transformed tensor with a leading batch dimension of 1.
    """
    inverted = cv2.bitwise_not(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))

    # Newer Pillow exposes the filter under Image.Resampling; older releases
    # only have the legacy Image.ANTIALIAS alias.
    if hasattr(Image, "Resampling"):
        lanczos = Image.Resampling.LANCZOS
    else:
        lanczos = Image.ANTIALIAS
    resized = Image.fromarray(inverted).resize((28, 28), lanczos)

    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    return to_normalized_tensor(resized).unsqueeze(0)
# ======================
# ✅ 4. 分割多个数字
# ======================
def segment_digits(image, *, threshold=30, min_size=10):
    """Cut a dark-on-light BGR image into one square crop per digit.

    External contours are found on the inverted, thresholded image; their
    bounding boxes are ordered left to right, and each sufficiently large
    box is centered inside a square before being inverted back to
    dark-on-light.

    Args:
        image: BGR image array as accepted by ``cv2.cvtColor``.
        threshold: Binarization threshold applied to the inverted grayscale
            image.
        min_size: Boxes whose width or height is not strictly greater than
            this many pixels are discarded as noise.

    Returns:
        List of square 2-D ``uint8`` arrays, ordered by the left edge of
        each bounding box. Empty if nothing passes the size filter.
    """
    inverted = cv2.bitwise_not(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))
    _, binary = cv2.threshold(inverted, threshold, 255, cv2.THRESH_BINARY)

    # OpenCV 3.x returns (image, contours, hierarchy) while 2.x and 4.x return
    # (contours, hierarchy); the contours are always second from the end.
    contours = cv2.findContours(
        binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]
    boxes = sorted(
        (cv2.boundingRect(contour) for contour in contours),
        key=lambda box: box[0],  # left to right by x coordinate
    )

    digits = []
    for x, y, w, h in boxes:
        if w <= min_size or h <= min_size:
            continue  # too small to be a digit; treat as noise
        side = max(w, h)
        # Pad with 0 in the inverted image; after the final inversion this
        # padding becomes the white background around the centered digit.
        square = np.zeros((side, side), dtype=np.uint8)
        top = (side - h) // 2
        left = (side - w) // 2
        square[top:top + h, left:left + w] = inverted[y:y + h, x:x + w]
        digits.append(cv2.bitwise_not(square))
    return digits
# ======================
# ✅ 5. 绘图窗口
# ======================
# Shared UI state: a white 560x280 (width x height) BGR drawing surface,
# whether the left mouse button is currently held, the previous stroke
# point, and the most recent recognition result (None until one exists).
canvas = np.full((280, 560, 3), 255, dtype=np.uint8)
drawing = False
last_point = None
prediction_text = None
def draw(event, x, y, flags, param):
    """Mouse callback that paints black strokes onto the global canvas.

    Pressing the left button starts a stroke, moving while it is held draws
    a 12-pixel-thick segment from the previous point to the current one,
    and releasing the button ends the stroke. ``flags`` and ``param`` are
    unused.
    """
    global drawing, last_point

    if event == cv2.EVENT_LBUTTONDOWN:
        drawing, last_point = True, (x, y)
        return

    if event == cv2.EVENT_LBUTTONUP:
        drawing, last_point = False, None
        return

    if event == cv2.EVENT_MOUSEMOVE and drawing:
        current = (x, y)
        cv2.line(canvas, last_point, current, (0, 0, 0), 12)
        last_point = current
# Create the resizable drawing window, attach the mouse handler, and print
# the keyboard shortcuts once at startup.
WINDOW_TITLE = "🖌 MNIST Multi-Digit Board"
cv2.namedWindow(WINDOW_TITLE, cv2.WINDOW_NORMAL)
cv2.setMouseCallback(WINDOW_TITLE, draw)
print("\n".join((
    "🎨 用鼠标左键画多个数字(例如:1234)",
    "✅ 按 's' 识别所有数字",
    "🧹 按 'c' 清空画布",
    "❌ 按 'q' 退出",
)))
# ======================
# ✅ 6. 主循环
# ======================
# Event loop: show the canvas, react to key presses, and always tear the
# window down on exit - including when segmentation or inference raises.
try:
    while True:
        # Overlay the result on a copy so the text never becomes part of the
        # strokes that are segmented on the next recognition.
        display = canvas.copy()
        if prediction_text is not None:
            cv2.putText(display, f"RES: {prediction_text}", (10, 60),
                        cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 3)
        cv2.imshow("🖌 MNIST Multi-Digit Board", display)

        key = cv2.waitKey(10) & 0xFF
        if key == ord('q'):
            break
        if key == ord('c'):
            canvas[:] = 255
            prediction_text = None
        elif key == ord('s'):
            digits = segment_digits(canvas.copy())
            if not digits:
                print("⚠️ 没检测到数字,请重试!")
                continue
            # Stack every digit crop into one batch so the model runs a
            # single forward pass instead of one per digit.
            batch = torch.cat([
                preprocess_image_for_mnist(
                    cv2.cvtColor(digit, cv2.COLOR_GRAY2BGR)
                )
                for digit in digits
            ])
            with torch.no_grad():
                predicted = torch.max(model(batch), 1)[1]
            prediction_text = "".join(str(p) for p in predicted.tolist())
            print(f"🧠 Predicted digits: {prediction_text}")
finally:
    cv2.destroyAllWindows()
效果
效果还行。下一步可以尝试重新训练一个直接支持多字符识别的模型。
