首页
搜索 搜索
当前位置:聚焦 > 正文

百度飞桨(PaddlePaddle)-数字识别|全球最新

2023-05-10 04:49:07 博客园

手写数字识别任务 用于对 0 ~ 9 的十类数字进行分类,即输入手写数字的图片,可识别出这个图片中的数字。


(相关资料图)

使用 pip 工具安装 matplotlib 和 numpy

python -m pip install matplotlib numpy -i https://mirror.baidu.com/pypi/simplepython -m pip install paddlepaddle==2.4.2 -i https://pypi.tuna.tsinghua.edu.cn/simple

D:\OpenSource\PaddlePaddle>python -m pip install matplotlib numpy -i https://mirror.baidu.com/pypi/simpleLooking in indexes: https://mirror.baidu.com/pypi/simpleCollecting matplotlib  Downloading https://mirror.baidu.com/pypi/packages/92/01/2c04d328db6955d77f8f60c17068dde8aa66f153b2c599ca03c2cb0d5567/matplotlib-3.7.1-cp38-cp38-win_amd64.whl (7.6 MB)     |████████████████████████████████| 7.6 MB ...Requirement already satisfied: numpy in d:\program files\python38\lib\site-packages (1.24.3)Collecting packaging>=20.0  Downloading https://mirror.baidu.com/pypi/packages/ab/c3/57f0601a2d4fe15de7a553c00adbc901425661bf048f2a22dfc500caf121/packaging-23.1-py3-none-any.whl (48 kB)     |████████████████████████████████| 48 kB 1.2 MB/sCollecting cycler>=0.10  Downloading https://mirror.baidu.com/pypi/packages/5c/f9/695d6bedebd747e5eb0fe8fad57b72fdf25411273a39791cde838d5a8f51/cycler-0.11.0-py3-none-any.whl (6.4 kB)Requirement already satisfied: pillow>=6.2.0 in d:\program files\python38\lib\site-packages (from matplotlib) (9.5.0)Collecting python-dateutil>=2.7  Downloading https://mirror.baidu.com/pypi/packages/36/7a/87837f39d0296e723bb9b62bbb257d0355c7f6128853c78955f57342a56d/python_dateutil-2.8.2-py2.py3-none-any.whl (247 kB)     |████████████████████████████████| 247 kB ...Collecting importlib-resources>=3.2.0  Downloading https://mirror.baidu.com/pypi/packages/38/71/c13ea695a4393639830bf96baea956538ba7a9d06fcce7cef10bfff20f72/importlib_resources-5.12.0-py3-none-any.whl (36 kB)Collecting fonttools>=4.22.0  Downloading https://mirror.baidu.com/pypi/packages/16/07/1c7547e27f559ec078801d522cc4d5127cdd4ef8e831c8ddcd9584668a07/fonttools-4.39.3-py3-none-any.whl (1.0 MB)     |████████████████████████████████| 1.0 MB ...Collecting pyparsing>=2.3.1  Downloading https://mirror.baidu.com/pypi/packages/6c/10/a7d0fa5baea8fe7b50f448ab742f26f52b80bfca85ac2be9d35cdd9a3246/pyparsing-3.0.9-py3-none-any.whl (98 kB)     |████████████████████████████████| 98 kB 862 kB/sCollecting contourpy>=1.0.1  Downloading https://mirror.baidu.com/pypi/packages/08/ce/9bfe9f028cb5a8ee97898da52f4905e0e2d9ca8203ffdcdbe80e1769b549/contourpy-1.0.7-cp38-cp38-win_amd64.whl (162 kB)     |████████████████████████████████| 162 kB ...Collecting kiwisolver>=1.0.1  Downloading https://mirror.baidu.com/pypi/packages/4f/05/59b34e788bf2b45c7157c3d898d567d28bc42986c1b6772fb1af329eea0d/kiwisolver-1.4.4-cp38-cp38-win_amd64.whl (55 kB)     |████████████████████████████████| 55 kB 784 kB/sCollecting zipp>=3.1.0  Downloading https://mirror.baidu.com/pypi/packages/5b/fa/c9e82bbe1af6266adf08afb563905eb87cab83fde00a0a08963510621047/zipp-3.15.0-py3-none-any.whl (6.8 kB)Requirement already satisfied: six>=1.5 in d:\program files\python38\lib\site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)Installing collected packages: zipp, python-dateutil, pyparsing, packaging, kiwisolver, importlib-resources, fonttools, cycler, contourpy, matplotlibSuccessfully installed contourpy-1.0.7 cycler-0.11.0 fonttools-4.39.3 importlib-resources-5.12.0 kiwisolver-1.4.4 matplotlib-3.7.1 packaging-23.1 pyparsing-3.0.9 python-dateutil-2.8.2 zipp-3.15.0WARNING: You are using pip version 21.1.1; however, version 23.1.2 is available.You should consider upgrading via the "D:\Program Files\Python38\python.exe -m pip install --upgrade pip" command.D:\OpenSource\PaddlePaddle>
创建 DigitalRecognition.py

官网代码少了 plt.show()# 要加上这句,才会显示图片

import paddleimport numpy as npfrom paddle.vision.transforms import Normalizetransform = Normalize(mean=[127.5], std=[127.5], data_format="CHW")# 下载数据集并初始化 DataSet"""飞桨在 paddle.vision.datasets 下内置了计算机视觉(Computer Vision,CV)领域常见的数据集,如 MNIST、Cifar10、Cifar100、FashionMNIST 和 VOC2012 等。在本任务中,先后加载了 MNIST 训练集(mode="train")和测试集(mode="test"),训练集用于训练模型,测试集用于评估模型效果。"""train_dataset = paddle.vision.datasets.MNIST(mode="train", transform=transform)test_dataset = paddle.vision.datasets.MNIST(mode="test", transform=transform)# 打印数据集里图片数量 60000 images in train_dataset, 10000 images in test_dataset# print("{} images in train_dataset, {} images in test_dataset".format(len(train_dataset), len(test_dataset)))# 模型组网并初始化网络lenet = paddle.vision.models.LeNet(num_classes=10)model = paddle.Model(lenet)# 模型训练的配置准备,准备损失函数,优化器和评价指标model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),              paddle.nn.CrossEntropyLoss(),              paddle.metric.Accuracy())# 模型训练model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)# 模型评估model.evaluate(test_dataset, batch_size=64, verbose=1)# 保存模型model.save("./output/mnist")# 加载模型model.load("output/mnist")# 从测试集中取出一张图片img, label = test_dataset[0]# 将图片shape从1*28*28变为1*1*28*28,增加一个batch维度,以匹配模型输入格式要求img_batch = np.expand_dims(img.astype("float32"), axis=0)# 执行推理并打印结果,此处predict_batch返回的是一个list,取出其中数据获得预测结果out = model.predict_batch(img_batch)[0]pred_label = out.argmax()print("true label: {}, pred label: {}".format(label[0], pred_label))# 可视化图片from matplotlib import pyplot as pltplt.imshow(img[0])plt.show()  # 要加上这句,才会显示图片
PyCharm运行(推荐,有错误能显示出来)

Python MatplotlibDeprecationWarning Matplotlib 3.6 and will be removed two minor releases laterFile -> Settings -> Tools -> Python Scientific -> 取消 Show plots in tool window,取消后,将不会看到红字警告提示

CMD 运行

D:\OpenSource\PaddlePaddle>python DigitalRecognition.py

如果碰到下列错误,需要加上 plt.show()Python MatplotlibDeprecationWarning Matplotlib 3.6 and will be removed two minor releases later

MatplotlibDeprecationWarning: Support for FigureCanvases without a required_interactive_framework attribute was deprecated in Matplotlib 3.6 and will be removed two minor releases later.  plt.imshow(img[0])
数据集定义与加载

飞桨在 paddle.vision.datasets 下内置了计算机视觉(Computer Vision,CV)领域常见的数据集,如 MNIST、Cifar10、Cifar100、FashionMNIST 和 VOC2012 等。在本任务中,先后加载了 MNIST 训练集(mode="train")和测试集(mode="test"),训练集用于训练模型,测试集用于评估模型效果。飞桨除了内置了 CV 领域常见的数据集,还在 paddle.text 下内置了自然语言处理(Natural Language Processing,NLP)领域常见的数据集,并提供了自定义数据集与加载功能的 paddle.io.Dataset 和 paddle.io.DataLoader API,详细使用方法可参考『数据集定义与加载』 章节。

另外在 paddle.vision.transforms 下提供了一些常用的图像变换操作,如对图像的翻转、裁剪、调整亮度等处理,可实现数据增强,以增加训练样本的多样性,提升模型的泛化能力。本任务在初始化 MNIST 数据集时通过 transform 字段传入了 Normalize 变换对图像进行归一化,对图像进行归一化可以加快模型训练的收敛速度。该功能的具体使用方法可参考『数据预处理』 章节。

数据集定义与加载数据预处理模型组网

https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/beginner/quick_start_cn.html#moxingzuwang

模型训练与评估

https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/beginner/quick_start_cn.html#moxingxunlianyupinggu

模型推理

https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/beginner/quick_start_cn.html#moxingtuili