博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
双向循环神经网络(BiRNN)MNIST手写体识别(tensorflow)
阅读量:2135 次
发布时间:2019-04-30

本文共 3115 字,大约阅读时间需要 10 分钟。

 

tf.contrib.rnn.static_bidirectional_rnn() 函数

#1、将mnist每个字体的像素序列当做时间序列,喂给网络#2、实现一个双向RNN网络,其中的cell可以是LSTM,也可以是GRUimport tensorflow as tfimport numpy as nplearning_rate = 0.01    # 因为优化器是adam,所以学习速率较低max_samples = 400000    # 最大训练样本数为40万batch_size = 128display_step = 10       # 每间隔10次训练,就展示一次训练情况n_input = 28 # 图像的宽度为28,因此设置输入为28n_steps = 28 # 图像的高度为28,因此设置LSTM的展开步数(unrolled steps of LSTM)也设置为28n_hidden = 256 # 定义一个方向的cell的数量n_classes = 10 # 0-9,共有10个分类.from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=True)# endregion# region 定义softmax层的权重# 因为要合成两个LSTM的输入,所以第一个维度是2*n_hiddenweights = {    # Hidden layer weights => 2*n_hidden because of foward + backward cells    'out': tf.Variable(tf.random_normal([2*n_hidden, n_classes]))}biases = {    'out': tf.Variable(tf.random_normal([n_classes]))}# endregion# region 构建计算图# x是一个二维结构,但是和卷积网络中的空间二维结构不同,# 这里的二维被理解成第一个维度是时间序列n_steps,第二维度是每个时间点下的数据n_inputx = tf.placeholder("float", [None, n_steps, n_input])y = tf.placeholder("float", [None, n_classes])# 将x拆成一个长度为n_steps的列表,每个元素tensor的尺寸为[batch_size,n_input]x_unstack = tf.unstack(x, axis=1)# lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)lstm_fw_cell = tf.contrib.rnn.GRUCell(n_hidden)# lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)lstm_bw_cell = tf.contrib.rnn.GRUCell(n_hidden)outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x_unstack,                                           dtype=tf.float32)pred=tf.matmul(outputs[-1], weights['out']) + biases['out']cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)# tf.argmax(pred,1)求每行中最大的元素的角标(即列号)correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))# tf.cast(correct_pred, tf.float32),将correct_pred转化为浮点型accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))init = tf.global_variables_initializer()# endregion# region 执行计算图with tf.Session() as sess:    sess.run(init)    step = 1    while step * batch_size < max_samples:# step*128<400000        # 直接读出来的batch_x的尺寸为[batch_size,784]        batch_x, batch_y = mnist.train.next_batch(batch_size)        # batch_x经过reshape后,尺寸变为(batch_size, n_steps, n_input)        batch_x = batch_x.reshape((batch_size, n_steps, n_input))        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})        if step % display_step == 0:            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})            print("Iter " + str(step*batch_size) + ", Minibatch Loss= " + \                  "{:.6f}".format(loss) + ", Training Accuracy= " + \                  "{:.5f}".format(acc))        step += 1    print("Optimization Finished!")    test_len = 10000    test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))    test_label = mnist.test.labels[:test_len]    print("Testing Accuracy:", \        sess.run(accuracy, feed_dict={x: test_data, y: test_label}))# endregion

 

 

参考:

转载地址:http://lxygf.baihongyu.com/

你可能感兴趣的文章
CMake 入门实战
查看>>
绑定CPU逻辑核心的利器——taskset
查看>>
Linux下perf性能测试火焰图只显示函数地址不显示函数名的问题
查看>>
c结构体、c++结构体和c++类的区别以及错误纠正
查看>>
Linux下查看根目录各文件内存占用情况
查看>>
A星算法详解(个人认为最详细,最通俗易懂的一个版本)
查看>>
利用栈实现DFS
查看>>
逆序对的数量(递归+归并思想)
查看>>
数的范围(二分查找上下界)
查看>>
算法导论阅读顺序
查看>>
Windows程序设计:直线绘制
查看>>
linux之CentOS下文件解压方式
查看>>
Django字段的创建并连接MYSQL
查看>>
div标签布局的使用
查看>>
HTML中表格的使用
查看>>
(模板 重要)Tarjan算法解决LCA问题(PAT 1151 LCA in a Binary Tree)
查看>>
(PAT 1154) Vertex Coloring (图的广度优先遍历)
查看>>
(PAT 1115) Counting Nodes in a BST (二叉查找树-统计指定层元素个数)
查看>>
(PAT 1143) Lowest Common Ancestor (二叉查找树的LCA)
查看>>
(PAT 1061) Dating (字符串处理)
查看>>