使用LeNet在MNIST数据集实现图像分类¶

1,773次阅读
没有评论

https://www.paddlepaddle.org.cn/documentation/docs/zh/practices/cv/image_classification.html

一、环境配置

本教程基于PaddlePaddle 2.3.0 编写,如果你的环境不是本版本,请先参考官网安装 PaddlePaddle 2.3.0。import paddle print(paddle.__version__) 2.3.0

二、数据加载

手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28×28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist 。

我们使用飞桨框架自带的 paddle.vision.datasets.MNIST 完成mnist数据集的加载。from paddle.vision.transforms import Compose, Normalize transform = Compose([Normalize(mean=[127.5], std=[127.5], data_format=’CHW’)]) # 使用transform对数据集做归一化 print(‘download training data and load training data’) train_dataset = paddle.vision.datasets.MNIST(mode=’train’, transform=transform) test_dataset = paddle.vision.datasets.MNIST(mode=’test’, transform=transform) print(‘load finished’)

取训练集中的一条数据看一下。import numpy as np import matplotlib.pyplot as plt train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1] train_data0 = train_data0.reshape([28,28]) plt.figure(figsize=(2,2)) plt.imshow(train_data0, cmap=plt.cm.binary) print(‘train_data0 label is: ‘ + str(train_label_0)) train_data0labelis: [5]

使用LeNet在MNIST数据集实现图像分类¶

三、组网

用paddle.nn下的API,如Conv2DMaxPool2DLinear完成LeNet的构建。import paddle import paddle.nn.functional as F classLeNet(paddle.nn.Layer): def__init__(self): super().__init__() self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2) self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2) self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1) self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2) self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120) self.linear2 = paddle.nn.Linear(in_features=120, out_features=84) self.linear3 = paddle.nn.Linear(in_features=84, out_features=10) defforward(self, x): x = self.conv1(x) x = F.relu(x) x = self.max_pool1(x) x = self.conv2(x) x = F.relu(x) x = self.max_pool2(x) x = paddle.flatten(x, start_axis=1,stop_axis=-1) x = self.linear1(x) x = F.relu(x) x = self.linear2(x) x = F.relu(x) x = self.linear3(x) return x

四、方式1:基于高层API,完成模型的训练与预测

通过paddle提供的Model 构建实例,使用封装好的训练与测试接口,快速完成模型训练与测试。

4.1 使用 Model.fit来训练模型

from paddle.metric import Accuracy model = paddle.Model(LeNet()) # 用Model封装模型 optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()) # 配置模型 model.prepare( optim, paddle.nn.CrossEntropyLoss(), Accuracy() ) # 训练模型 model.fit(train_dataset, epochs=2, batch_size=64, verbose=1 )

4.2 使用 Model.evaluate 来预测模型

model.evaluate(test_dataset, batch_size=64, verbose=1) Eval begin… step 157/157 [==============================] – loss: 4.2854e-04 – acc: 0.9841 – 7ms/step Eval samples: 10000 {‘loss’: [0.00042853763], ‘acc’: 0.9841}

方式一结束

以上就是方式一,可以快速、高效的完成网络模型训练与预测。

五、方式2:基于基础API,完成模型的训练与预测

5.1 模型训练

组网后,开始对模型进行训练,先构建train_loader,加载训练数据,然后定义train函数,设置好损失函数后,按batch加载数据,完成模型的训练。import paddle.nn.functional as F train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True) # 加载训练集 batch_size 设为 64deftrain(model): model.train() epochs = 2 optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()) # 用Adam作为优化函数for epoch in range(epochs): for batch_id, data in enumerate(train_loader()): x_data = data[0] y_data = data[1] predicts = model(x_data) loss = F.cross_entropy(predicts, y_data) # 计算损失 acc = paddle.metric.accuracy(predicts, y_data) loss.backward() if batch_id % 300 == 0: print(“epoch: {}, batch_id: {}, loss is: {}, acc is: {}”.format(epoch, batch_id, loss.numpy(), acc.numpy())) optim.step() optim.clear_grad() model = LeNet() train(model) epoch: 0, batch_id: 0, lossis: [2.9878871], accis: [0.140625] epoch: 0, batch_id: 300, lossis: [0.22775462], accis: [0.921875] epoch: 0, batch_id: 600, lossis: [0.06251755], accis: [0.984375] epoch: 0, batch_id: 900, lossis: [0.1097075], accis: [0.96875] epoch: 1, batch_id: 0, lossis: [0.04311676], accis: [0.984375] epoch: 1, batch_id: 300, lossis: [0.00150577], accis: [1.] epoch: 1, batch_id: 600, lossis: [0.08764459], accis: [0.96875] epoch: 1, batch_id: 900, lossis: [0.14419323], accis: [0.9375]

5.2 模型验证

训练完成后,需要验证模型的效果,此时,加载测试数据集,然后用训练好的模对测试集进行预测,计算损失与精度。test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64) # 加载测试数据集deftest(model): model.eval() batch_size = 64 for batch_id, data in enumerate(test_loader()): x_data = data[0] y_data = data[1] predicts = model(x_data) # 获取预测结果 loss = F.cross_entropy(predicts, y_data) acc = paddle.metric.accuracy(predicts, y_data) if batch_id % 20 == 0: print(“batch_id: {}, loss is: {}, acc is: {}”.format(batch_id, loss.numpy(), acc.numpy())) test(model) batch_id: 0, lossis: [0.01201783], accis: [1.] batch_id: 20, lossis: [0.09013407], accis: [0.984375] batch_id: 40, lossis: [0.07025866], accis: [0.96875] batch_id: 60, lossis: [0.08602518], accis: [0.984375] batch_id: 80, lossis: [0.00779913], accis: [1.] batch_id: 100, lossis: [0.00508764], accis: [1.] batch_id: 120, lossis: [0.00401443], accis: [1.] batch_id: 140, lossis: [0.03930391], accis: [0.96875]

方式二结束

以上就是方式二,通过底层API,可以清楚的看到训练和测试中的每一步过程。但是,这种方式比较复杂。因此,我们提供了训练方式一,使用高层API来完成模型的训练与预测。对比底层API,高层API能够更加快速、高效的完成模型的训练与测试。

六、总结

以上就是用LeNet对手写数字数据及MNIST进行分类。本示例提供了两种训练模型的方式,一种可以快速完成模型的组建与预测,非常适合新手用户上手。另一种则需要多个步骤来完成模型的训练,适合进阶用户使用。

正文完
可以使用微信扫码关注公众号(ID:xzluomor)
post-qrcode
 0
评论(没有评论)

文心AIGC

2023 年 12 月
 123
45678910
11121314151617
18192021222324
25262728293031
文心AIGC
文心AIGC
人工智能ChatGPT,AIGC指利用人工智能技术来生成内容,其中包括文字、语音、代码、图像、视频、机器人动作等等。被认为是继PGC、UGC之后的新型内容创作方式。AIGC作为元宇宙的新方向,近几年迭代速度呈现指数级爆发,谷歌、Meta、百度等平台型巨头持续布局
文章搜索
热门文章
潞晨尤洋:日常办公没必要上私有模型,这三类企业才需要 | MEET2026

潞晨尤洋:日常办公没必要上私有模型,这三类企业才需要 | MEET2026

潞晨尤洋:日常办公没必要上私有模型,这三类企业才需要 | MEET2026 Jay 2025-12-22 09...
“昆山杯”第二十七届清华大学创业大赛决赛举行

“昆山杯”第二十七届清华大学创业大赛决赛举行

“昆山杯”第二十七届清华大学创业大赛决赛举行 一水 2025-12-22 17:04:24 来源:量子位 本届...
MiniMax海螺视频团队首次开源:Tokenizer也具备明确的Scaling Law

MiniMax海螺视频团队首次开源:Tokenizer也具备明确的Scaling Law

MiniMax海螺视频团队首次开源:Tokenizer也具备明确的Scaling Law 一水 2025-12...
天下苦SaaS已久,企业级AI得靠「结果」说话

天下苦SaaS已久,企业级AI得靠「结果」说话

天下苦SaaS已久,企业级AI得靠「结果」说话 Jay 2025-12-22 13:46:04 来源:量子位 ...
最新评论
ufabet ufabet มีเกมให้เลือกเล่นมากมาย: เกมเดิมพันหลากหลาย ครบทุกค่ายดัง
tornado crypto mixer tornado crypto mixer Discover the power of privacy with TornadoCash! Learn how this decentralized mixer ensures your transactions remain confidential.
ดูบอลสด ดูบอลสด Very well presented. Every quote was awesome and thanks for sharing the content. Keep sharing and keep motivating others.
ดูบอลสด ดูบอลสด Pretty! This has been a really wonderful post. Many thanks for providing these details.
ดูบอลสด ดูบอลสด Pretty! This has been a really wonderful post. Many thanks for providing these details.
ดูบอลสด ดูบอลสด Hi there to all, for the reason that I am genuinely keen of reading this website’s post to be updated on a regular basis. It carries pleasant stuff.
Obrazy Sztuka Nowoczesna Obrazy Sztuka Nowoczesna Thank you for this wonderful contribution to the topic. Your ability to explain complex ideas simply is admirable.
ufabet ufabet Hi there to all, for the reason that I am genuinely keen of reading this website’s post to be updated on a regular basis. It carries pleasant stuff.
ufabet ufabet You’re so awesome! I don’t believe I have read a single thing like that before. So great to find someone with some original thoughts on this topic. Really.. thank you for starting this up. This website is something that is needed on the internet, someone with a little originality!
ufabet ufabet Very well presented. Every quote was awesome and thanks for sharing the content. Keep sharing and keep motivating others.
热评文章
小米大模型“杀”进第一梯队:代码能力开源第一,智商情商全在线

小米大模型“杀”进第一梯队:代码能力开源第一,智商情商全在线

小米大模型“杀”进第一梯队:代码能力开源第一,智商情商全在线 克雷西 2025-12-18 08:57:11 ...
ISC.AI 2025创新百强颁奖典礼落幕,首发智能体专家驱动产业升级

ISC.AI 2025创新百强颁奖典礼落幕,首发智能体专家驱动产业升级

ISC.AI 2025创新百强颁奖典礼落幕,首发智能体专家驱动产业升级 量子位的朋友们 2025-12-18 ...
具身智能的数据难题,终于有了可规模化的解法

具身智能的数据难题,终于有了可规模化的解法

具身智能的数据难题,终于有了可规模化的解法 思邈 2025-12-18 14:20:44 来源:量子位 成立4...
医生版ChatGPT,估值120亿美元

医生版ChatGPT,估值120亿美元

医生版ChatGPT,估值120亿美元 Jay 2025-12-18 13:45:12 来源:量子位 Jay ...
国产AI芯片看两个指标:模型覆盖+集群规模能力 | 百度智能云王雁鹏@MEET2026

国产AI芯片看两个指标:模型覆盖+集群规模能力 | 百度智能云王雁鹏@MEET2026

国产AI芯片看两个指标:模型覆盖+集群规模能力 | 百度智能云王雁鹏@MEET2026 西风 2025-12-...