tfrnntest - juedaiyuer/researchNote GitHub Wiki

A small RNN test

Test source code

  • Source code location

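    import tensorflow as tf  # TensorFlow 1.x (tf.contrib was removed in 2.0)
    import numpy as np

    # Create input data: a batch of 2 sequences, 10 time steps, 8 features each
    X = np.random.randn(2, 10, 8)

    # The second example is of length 6: zero out its time steps 6..9
    X[1, 6:] = 0
    X_lengths = [10, 6]

    def make_cell():
        # A 64-unit LSTM layer with dropout applied to its outputs
        cell = tf.nn.rnn_cell.LSTMCell(num_units=64, state_is_tuple=True)
        return tf.nn.rnn_cell.DropoutWrapper(cell=cell, output_keep_prob=0.5)

    # Stack 4 layers; each layer needs its own cell object, since [cell] * 4
    # reuses one cell and later TF 1.x versions reject that
    cell = tf.nn.rnn_cell.MultiRNNCell(
        cells=[make_cell() for _ in range(4)], state_is_tuple=True)

    outputs, last_states = tf.nn.dynamic_rnn(
        cell=cell, dtype=tf.float64, sequence_length=X_lengths, inputs=X)

    # Run the graph once (n=1); run_n returns a list with one dict of results
    result = tf.contrib.learn.run_n(
        {"outputs": outputs, "last_states": last_states}, n=1, feed_dict=None)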
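What `sequence_length` changes in `tf.nn.dynamic_rnn` can be sketched in plain NumPy: for each example, time steps at or beyond its length produce zero outputs and the state is frozen, so the returned final state is the one from that example's last valid step. This is a simplified illustration, not TensorFlow's implementation; `simple_rnn_step` is a hypothetical tanh cell standing in for the LSTM stack, and its random weights are assumptions.

```python
import numpy as np


def simple_rnn_step(x, h, W, U):
    """Hypothetical vanilla RNN cell: h' = tanh(x W + h U); output equals new state."""
    h_new = np.tanh(x @ W + h @ U)
    return h_new, h_new


def dynamic_rnn_sketch(step, inputs, sequence_length, state_size, params):
    """Unroll `step` over time, masking each example past its own length."""
    batch, max_time, _ = inputs.shape
    lengths = np.asarray(sequence_length)
    state = np.zeros((batch, state_size))
    outputs = np.zeros((batch, max_time, state_size))
    for t in range(max_time):
        out, new_state = step(inputs[:, t, :], state, *params)
        valid = (t < lengths)[:, None]                  # (batch, 1) mask for step t
        outputs[:, t, :] = np.where(valid, out, 0.0)    # zero output past the length
        state = np.where(valid, new_state, state)       # freeze state past the length
    return outputs, state


rng = np.random.default_rng(0)
X = rng.standard_normal((2, 10, 8))
X[1, 6:] = 0
X_lengths = [10, 6]
W = rng.standard_normal((8, 64)) * 0.1
U = rng.standard_normal((64, 64)) * 0.1

outputs, last_state = dynamic_rnn_sketch(simple_rnn_step, X, X_lengths, 64, (W, U))
print(outputs.shape)                               # (2, 10, 64)
print(np.allclose(outputs[1, 6:], 0))              # True
print(np.allclose(last_state[1], outputs[1, 5]))   # True
```

In this sketch the zeros in `outputs[1, 6:]` come from the length mask, not from the zero padding in `X[1, 6:]`; the padded values could be anything and the result would be the same.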

Code notes

rnn_cell
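The `rnn_cell` classes used in the test (`LSTMCell`, `DropoutWrapper`, `MultiRNNCell`) compose roughly as in the following NumPy sketch. It is an illustration, not TensorFlow's code: it assumes the standard LSTM equations without peepholes, and the gate ordering, weight initialization and helper names (`lstm_step`, `dropout`, `stacked_step`) are hypothetical. `forget_bias=1.0` mirrors the default of `LSTMCell`. The sketch also shows that the four stacked layers cannot literally share one set of weights: the first layer's matrix has shape (8 + 64, 256), the later layers' (64 + 64, 256).

```python
import numpy as np


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def lstm_step(x, state, W, b, forget_bias=1.0):
    """One LSTM step with state as a (c, h) tuple, as with state_is_tuple=True.

    W has shape (input_dim + num_units, 4 * num_units); the split into
    input, candidate, forget, output gates is an assumed ordering.
    """
    c, h = state
    z = np.concatenate([x, h], axis=1) @ W + b
    i, j, f, o = np.split(z, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, (new_c, new_h)


def dropout(h, keep_prob, rng):
    """Inverted dropout: keep each unit with prob keep_prob, rescale by 1/keep_prob."""
    mask = rng.random(h.shape) < keep_prob
    return h * mask / keep_prob


def stacked_step(x, states, layers, keep_prob, rng):
    """Feed x through each layer in turn; every layer has its own weights and state."""
    new_states = []
    for (W, b), state in zip(layers, states):
        out, new_state = lstm_step(x, state, W, b)
        x = dropout(out, keep_prob, rng)   # dropout on the output only, not the state
        new_states.append(new_state)
    return x, tuple(new_states)


rng = np.random.default_rng(0)
num_units, input_dim, num_layers, batch = 64, 8, 4, 2
layers = []
for k in range(num_layers):
    in_dim = input_dim if k == 0 else num_units   # layer 0 sees 8 features, the rest 64
    layers.append((rng.standard_normal((in_dim + num_units, 4 * num_units)) * 0.1,
                   np.zeros(4 * num_units)))
states = tuple((np.zeros((batch, num_units)), np.zeros((batch, num_units)))
               for _ in range(num_layers))

x0 = rng.standard_normal((batch, input_dim))
out, states = stacked_step(x0, states, layers, keep_prob=0.5, rng=rng)
print(out.shape, len(states))   # (2, 64) 4
```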