博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow-2 数字识别
阅读量:4147 次
发布时间:2019-05-25

本文共 2149 字,大约阅读时间需要 7 分钟。

步骤准备:

 
1  数据准备:可以直接从
'/tmp/tensorflow/mnist/input_data'
中获取
  2 创建模型:x W b
  3 定义损失函数和优化形式(采用softmax分类器)
  4 启动会话(Session)
  5 训练模型
  6 测试并计算输出
  
其中softmax分类器的原理请参考:

程序:

#TensorFlow手写数目识别import tensorflow.examples.tutorials.mnist as inputDataimport tensorflow as tf#导入命令行解析模块import argparseimport sys#Import datadata_dir='/tmp/tensorflow/mnist/input_data'mnist=inputData.input_data.read_data_sets(data_dir,one_hot=True)#Create model (x,W,b, Loss)# The image size is of 28*28#None means any lengthx=tf.placeholder(tf.float32,[None, 784])#W bW=tf.Variable(tf.zeros([784,10]))b=tf.Variable(tf.zeros([10]))#Predicty=tf.matmul(x,W)+b#Define loss and optimizery_=tf.placeholder(tf.float32,[None,10])  # The raw formulation of cross-entropy,  #  #   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),  #                                 reduction_indices=[1]))                # tf.reduce_sum adds the elements in the second dimension of y,                # due to the reduction_indices=[1] parameter.                # tf.reduce_mean computes the mean over all the examples in the batch.  #  # can be numerically unstable.  #  # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw  # outputs of 'y', and then average across the batch.loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_,logits=y))optimizer=tf.train.GradientDescentOptimizer(0.5).minimize(loss)#Start Sessionsess=tf.InteractiveSession()#initial variablestf.global_variables_initializer().run()#Train --stochastic trainingfor _ in range(100):    # 100 data points are randomly selected from the training data set    batch_x,batch_y=mnist.train.next_batch(100)    #trraining    sess.run(optimizer,feed_dict={x:batch_x,y_:batch_y})#Test#tf.equal:  check if the pridction equals the label, if equal, True; else, False#tf.argmax: obtain the opsition of the max value in row(1)correct_prdiction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))#tf.cast: convert correct_prdiction to tf.float32accuracy=tf.reduce_mean(tf.cast(correct_prdiction,tf.float32))print(sess.run(accuracy,feed_dict={x:mnist.test.images,                                   y_:mnist.test.labels}))
运行结果:0.8868

转载地址:http://qmjti.baihongyu.com/

你可能感兴趣的文章
java杂记
查看>>
RunTime.getRuntime().exec()
查看>>
Oracle 分组排序函数
查看>>
删除weblogic 域
查看>>
VMware Workstation 14中文破解版下载(附密钥)(笔记)
查看>>
日志框架学习
查看>>
日志框架学习2
查看>>
SVN-无法查看log,提示Want to go offline,时间显示1970问题,error主要是 url中 有一层的中文进行了2次encode
查看>>
NGINX
查看>>
Qt文件夹选择对话框
查看>>
1062 Talent and Virtue (25 分)
查看>>
1061 Dating (20 分)
查看>>
1060 Are They Equal (25 分)
查看>>
83. Remove Duplicates from Sorted List(easy)
查看>>
88. Merge Sorted Array(easy)
查看>>
leetcode刷题191 位1的个数 Number of 1 Bits(简单) Python Java
查看>>
leetcode刷题198 打家劫舍 House Robber(简单) Python Java
查看>>
NG深度学习第一门课作业2 通过一个隐藏层的神经网络来做平面数据的分类
查看>>
leetcode刷题234 回文链表 Palindrome Linked List(简单) Python Java
查看>>
NG深度学习第二门课作业1-1 深度学习的实践
查看>>