用Python进行AI数据分析进阶教程68:
训练神经网络并进行预测
关键词:神经网络、MNIST、TensorFlow、模型训练、预测分析
摘要:本文以MNIST手写数字识别为例,详细讲解了使用Python和TensorFlow/Keras库训练神经网络并进行预测的过程。内容涵盖数据准备、模型构建、编译、训练、评估与预测六大关键步骤,并强调了数据归一化、过拟合处理及超参数调优等注意事项。通过示例代码展示了从加载数据、预处理、搭建网络结构到最终预测的完整流程,并对训练输出结果进行了分析,帮助理解模型性能评估指标如损失值、准确率等。全文旨在提供一个端到端的神经网络实践指南。
👉 欢迎订阅🔗
《用Python进行AI数据分析进阶教程》专栏
《AI大模型应用实践进阶教程》专栏
《Python编程知识集锦》专栏
《字节跳动旗下AI制作抖音视频》专栏
《智能辅助驾驶》专栏
《工具软件及IT技术集锦》专栏
下面将详细讲解使用 Python 进行神经网络的训练与预测,以 MNIST 手写数字识别为例,借助TensorFlow和Keras库实现。
一、关键点
- 数据准备:加载、划分并预处理数据,例如归一化操作。
- 模型构建:搭建神经网络的结构,确定层数、神经元数量、激活函数等。
- 模型编译:指定损失函数、优化器和评估指标。
- 模型训练:利用训练数据对模型进行训练。
- 模型评估:使用测试数据评估模型的性能。
- 模型预测:使用训练好的模型进行预测。
二、注意点
- 数据归一化:归一化能加快模型收敛速度,避免梯度消失或爆炸。
- 过拟合问题:可采用正则化、Dropout 等方法防止过拟合。
- 超参数调整:如学习率、批次大小、训练轮数等,需要进行调优。
三、示例代码
Python脚本
# 导入 TensorFlow 库,这是构建和训练神经网络的核心框架
import tensorflow as tf
# 从 Keras 模块中导入 MNIST 数据集加载函数,用于获取手写数字图像数据
from tensorflow.keras.datasets import mnist
# 从 Keras 模块中导入 Sequential 模型类,用于创建按顺序堆叠的神经网络层
from tensorflow.keras.models import Sequential
# 从 Keras 模块中导入 Dense 和 Flatten 层类
# Dense:全连接层,每个神经元与前一层所有神经元相连
# Flatten:展平层,将多维数据转换为一维向量
from tensorflow.keras.layers import Dense, Flatten
# 从 Keras 模块中导入 to_categorical 函数,用于将整数标签转换为 one-hot 编码格式
from tensorflow.keras.utils import to_categorical
# 导入 NumPy 库,用于进行数值计算和数组操作
import numpy as np
# 加载 MNIST 数据集,这是一个包含手写数字图像的标准数据集
# 该函数返回两个元组:训练集和测试集
# 每个元组包含图像数据和对应的标签
(train_images, train_labels), \
(test_images, test_labels) = mnist.load_data()
# 对训练图像进行数据预处理,将像素值从 0-255 范围归一化到 0-1 范围
# 这样做可以加快模型收敛速度并提高训练稳定性
train_images = train_images / 255.0
# 对测试图像进行相同的数据预处理,确保训练和测试数据在同一数值范围内
test_images = test_images / 255.0
# 将训练标签从整数格式转换为 one-hot 编码格式
# 例如:标签 3 转换为 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
train_labels = to_categorical(train_labels)
# 将测试标签从整数格式转换为 one-hot 编码格式
test_labels = to_categorical(test_labels)
# 创建一个顺序模型,用于构建神经网络
model = Sequential([
# 添加展平层,将 28x28 的二维图像数据转换为 784 维的一维向量
# input_shape 参数指定输入数据的形状,不包含样本数量维度
Flatten(input_shape=(28, 28)),
# 添加全连接隐藏层,包含 128 个神经元
# 使用 ReLU 激活函数,引入非线性特性,增强模型的表达能力
Dense(128, activation='relu'),
# 添加输出层,包含 10 个神经元,对应 0-9 这 10 个数字类别
# 使用 softmax 激活函数,将输出转换为概率分布,所有输出值之和为 1
Dense(10, activation='softmax')
])
# 编译模型,配置训练过程所需的各种参数
# optimizer='adam':使用 Adam 优化器,它能自适应调整学习率
# loss='categorical_crossentropy':使用分类交叉熵损失函数,适用于多分类问题
# metrics=['accuracy']:监控准确率指标,用于评估模型性能
model.compile(
optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy']
)
# 开始训练模型
# train_images 和 train_labels:训练数据和对应的标签
# epochs=5:训练轮数,模型将遍历整个训练集 5 次
# batch_size=64:每个批次包含 64 个样本,模型参数会按批次更新
model.fit(
train_images,
train_labels,
epochs=5,
batch_size=64
)
# 在测试集上评估已训练好的模型性能
# 返回测试损失值和测试准确率
test_loss, test_accuracy = model.evaluate(
test_images,
test_labels
)
# 打印模型在测试集上的准确率
# 使用 f-string 格式化字符串,显示测试准确率的数值
print(f"Test accuracy: {test_accuracy}")
# 使用训练好的模型对测试集中前 5 个样本进行预测
# 返回每个样本属于各个类别的概率分布
predictions = model.predict(test_images[:5])
# 遍历前 5 个预测结果
for i in range(5):
# 使用 np.argmax 函数找到概率分布中最大值的索引
# 该索引即为模型预测的数字类别
predicted_digit = np.argmax(predictions[i])
# 打印每个样本的预测结果
# i + 1 是为了使样本编号从 1 开始而不是从 0 开始
print(f"Predicted digit for sample {i + 1}: {predicted_digit}")
输出 / 打印结果分析
(1)模型训练阶段
在训练阶段,模型会输出每个轮次的训练信息,示例如下:
plaintext
Epoch 1/5
938/938 [==============================] - 3s 3ms/step - loss: 0.2552 - accuracy: 0.9246
Epoch 2/5
938/938 [==============================] - 3s 3ms/step - loss: 0.1058 - accuracy: 0.9682
Epoch 3/5
938/938 [==============================] - 3s 3ms/step - loss: 0.0723 - accuracy: 0.9776
Epoch 4/5
938/938 [==============================] - 3s 3ms/step - loss: 0.0535 - accuracy: 0.9836
Epoch 5/5
938/938 [==============================] - 3s 3ms/step - loss: 0.0415 - accuracy: 0.9871
- Epoch:表示训练的轮次,这里总共训练 5 个轮次。
- 938/938:表示每个轮次中训练的批次数量,训练集总共有 60000 个样本,每个批次 64 个样本,所以大约有 938 个批次。
- loss:表示训练集的损失值,损失值越小说明模型的预测结果与真实标签越接近。
- accuracy:表示训练集的准确率,即模型正确预测的样本数占总样本数的比例。
(2)模型评估阶段
plaintext
313/313 [==============================] - 1s 2ms/step - loss: 0.0713 - accuracy: 0.9788
Test accuracy: 0.9788
- 313/313:表示测试集的批次数量,测试集总共有 10000 个样本,每个批次 64 个样本,所以大约有 313 个批次。
- loss:表示测试集的损失值。
- accuracy:表示测试集的准确率,这里的准确率是 0.9788,说明模型在测试集上的表现较好。
(3)模型预测阶段
plaintext
Predicted digit: 7
Predicted digit: 2
Predicted digit: 1
Predicted digit: 0
Predicted digit: 4
- 这是模型对测试集前 5 个样本的预测结果,分别预测为 7、2、1、0、4。实际的预测结果可能会因模型训练的随机性而有所不同。
四、重点语句解读
1. 数据加载与结构理解
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
解读分析:
- MNIST数据集是机器学习领域的"Hello World",包含70,000张28×28像素的手写数字图像
- 数据自动分为训练集(60,000张)和测试集(10,000张)
- 图像数据形状为(样本数, 28, 28),标签为一维数组,值域为0-9
2. 数据预处理的核心操作
train_images = train_images / 255.0
test_images = test_images / 255.0
解读分析:
- 像素值归一化是深度学习的关键预处理步骤
- 将0-255的整数值转换为0-1的浮点数,符合神经网络对输入数据范围的偏好
- 归一化后梯度更新更稳定,避免了因数值过大导致的梯度爆炸问题
3. 标签编码转换
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
解读分析:
- 多分类问题需要将整数标签转换为one-hot编码
- 例如:数字"3" → [0,0,0,1,0,0,0,0,0,0]
- 这种编码方式使模型输出与标签在数学上可以直接比较计算损失
4. 神经网络架构设计
model = Sequential([
Flatten(input_shape=(28, 28)),
Dense(128, activation='relu'),
Dense(10, activation='softmax')
])
解读分析:
- 输入层:Flatten将28×28图像展平为784维向量,保持信息完整性
- 隐藏层:128个神经元的全连接层,ReLU激活函数引入非线性,使网络能学习复杂模式
- 输出层:10个神经元对应10个类别,Softmax确保输出为概率分布
5. 模型编译参数选择
model.compile(
optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy']
)
解读分析:
- Adam优化器:结合了动量和自适应学习率的优点,收敛速度快且稳定
- 分类交叉熵损失:专门用于多分类问题,当预测错误时损失值大,正确时损失值小
- 准确率监控:直观反映模型分类正确的比例
6. 训练过程控制
model.fit(
train_images,
train_labels,
epochs=5,
batch_size=64
)
解读分析:
- epochs=5:完整遍历训练集5次,在小数据集上通常足够
- batch_size=64:平衡了内存使用和训练稳定性,批次太小梯度噪声大,太大收敛慢
- 批量训练使参数更新更频繁,加速收敛过程
7. 模型评估与预测
test_loss, test_accuracy = model.evaluate(test_images, test_labels)
predictions = model.predict(test_images[:5])
解读分析:
- evaluate():在未参与训练的测试集上评估,真实反映模型泛化能力
- predict():输出概率分布,每个样本对应10个类别的概率值
- 测试准确率是衡量模型是否过拟合的重要指标
8. 预测结果解析
predicted_digit = np.argmax(predictions[i])
解读分析:
- np.argmax()找到概率分布中最大值的索引,即模型最确信的分类结果
- 这是从概率空间到具体分类标签的关键转换步骤
- 体现了神经网络"软分类"到"硬分类"的决策过程
这些重点语句构成了完整的深度学习工作流程:数据准备→模型构建→训练→评估→预测,每个环节都体现了深度学习的核心思想和最佳实践。
——The END——
🔗 欢迎订阅专栏
| 序号 | 专栏名称 | 说明 |
|---|---|---|
| 1 | 用Python进行AI数据分析进阶教程 | 《用Python进行AI数据分析进阶教程》专栏 |
| 2 | AI大模型应用实践进阶教程 | 《AI大模型应用实践进阶教程》专栏 |
| 3 | Python编程知识集锦 | 《Python编程知识集锦》专栏 |
| 4 | 字节跳动旗下AI制作抖音视频 | 《字节跳动旗下AI制作抖音视频》专栏 |
| 5 | 智能辅助驾驶 | 《智能辅助驾驶》专栏 |
| 6 | 工具软件及IT技术集锦 | 《工具软件及IT技术集锦》专栏 |
👉 关注我 @理工男大辉郎 获取实时更新
欢迎关注、收藏或转发。
敬请关注 我的
微信搜索公众号:cnFuJH
CSDN博客:理工男大辉郎
抖音号:31580422589

3万+

被折叠的 条评论
为什么被折叠?



