博客
关于我
(三)Tensorflow的逻辑回归模型
阅读量:797 次
发布时间:2023-03-28

本文共 2331 字,大约阅读时间需要 7 分钟。

使用逻辑回归模型进行MNIST数字识别是一个经典的机器学习问题。本文将详细介绍模型的构建过程、训练方法以及结果分析。

代码优化与解释

首先,我们需要加载MNIST数据集,并进行数据预处理。代码如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
def logistic_regression():
# 加载数据集
mnist = input_data.read_data_sets(r'C:\Users\Administrator\Desktop\AI_project\tensorflow\MNIST_data', one_hot=True)
# 定义批次大小
batch_size = 128
# 定义输入占位符
x = tf.placeholder(tf.float32, [batch_size, 784], name="x_data")
y = tf.placeholder(tf.int32, [batch_size, 10], name="y_data")
# 初始化权重和偏置
w = tf.Variable(tf.random_normal([784, 10], stddev=0.1))
b = tf.Variable(tf.zeros([10]))
# 计算预测结果
val = tf.add(tf.matmul(x, w), b)
# 定义损失函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=val))
# 定义优化器并训练模型
optimizer = tf.train.AdamOptimizer()
train_op = optimizer.minimize(loss)
# 预测准确率
pred = tf.nn.softmax(val)
acc = tf.reduce_mean(tf.cast(tf.argmax(pred, axis=1) == tf.argmax(y, axis=1), tf.float32))
# 初始化变量
init_op = tf.global_variables_initializer()
# 训练和测试过程
with tf.Session() as sess:
sess.run(init_op)
# 训练过程
n_batch = math.ceil(mnist.train.num_examples / batch_size)
for i in range(50):
loss_total = 0
for _ in range(n_batch):
x_input, y_input = mnist.train.next_batch(batch_size)
_, loss_val = sess.run([train_op, loss], feed_dict={x: x_input, y: y_input})
loss_total += loss_val
avg_loss = loss_total / n_batch
print(f"Interation {i}, Loss: {avg_loss}")
# 测试过程
acc_total = 0
for _ in range(n_batch):
x_input, y_input = mnist.test.next_batch(batch_size)
acc = sess.run(acc, feed_dict={x: x_input, y: y_input})
acc_total += acc
avg_acc = acc_total / n_batch
print(f"Accuracy: {avg_acc}")

优化点总结

  • 数据处理优化:确保数据加载和批次划分的正确性,避免内存泄漏。
  • 优化器选择:Adam优化器通常表现优异,但可以尝试其他优化器以进行对比。
  • 计算图优化:使用更高效的计算图结构,减少内存占用和加速训练。
  • 损失函数优化:确保损失函数的维度和计算方式正确,避免维度不匹配。
  • 训练过程改进:单独执行测试过程,避免资源竞争,提高效率。
  • 结果展示

    训练过程中,损失函数值逐渐下降,准确率逐渐上升,显示出模型在训练过程中的有效性。随着训练次数的增加,模型性能得到显著提升。

    总结

    通过以上优化,模型在MNIST数据集上的表现更加稳定和可靠。建议在实际应用中根据具体情况调整批次大小和优化器参数,以获得最佳效果。

    转载地址:http://wphfk.baihongyu.com/

    你可能感兴趣的文章
    Objective-C实现找出三角形从上到下的最大路径算法(附完整源码)
    查看>>
    Objective-C实现找出买卖股票的最大利润算法(附完整源码)
    查看>>
    Objective-C实现找出二维数组中的鞍点(附完整源码)
    查看>>
    Objective-C实现找出由两个 3 位数字的乘积构成的最大回文数的算法 (附完整源码)
    查看>>
    Objective-C实现找出矩阵的最大最小值(附完整源码)
    查看>>
    Objective-C实现找到一个数字数组的中值算法(附完整源码)
    查看>>
    Objective-C实现找到具有 500 个除数的第一个三角形数算法(附完整源码)
    查看>>
    Objective-C实现找到最近的点对之间的距离算法(附完整源码)
    查看>>
    Objective-C实现抓包实例(附完整源码)
    查看>>
    Objective-C实现抽签抓阄(附完整源码)
    查看>>
    Objective-C实现抽象工厂模式(附完整源码)
    查看>>
    Objective-C实现拉格朗日插值法(附完整源码)
    查看>>
    Objective-C实现拓扑排序算法(附完整源码)
    查看>>
    Objective-C实现拷贝二进制文件(附完整源码)
    查看>>
    Objective-C实现指定内存空间获取时间的函数(附完整源码)
    查看>>
    Objective-C实现按位倒序(附完整源码)
    查看>>
    Objective-C实现按位运算符乘以无符号数multiplyUnsigned算法(附完整源码)
    查看>>
    Objective-C实现排队叫号系统(附完整源码)
    查看>>
    Objective-C实现控制NRP8S功率计读取功率 (附完整源码)
    查看>>
    Objective-C实现控制程控电源2306读取电流 (附完整源码)
    查看>>