tensorboard - beyondnlp/nlp GitHub Wiki

  • tensorboard 실행 방법:

    • tensorboard --logdir=로그디렉토리 (예: tensorboard --logdir=logistic_logs/)
  • 로그를 저장할 디렉토리 경로를 지정하여 FileWriter 인스턴스를 얻어온다.

    • 56 summary_write = tf.summary.FileWriter("logistic_logs/", graph_def=sess.graph_def)
  • ๊ทธ๋ž˜ํ”„๋กœ ๊ทธ๋ฆด๋ ค๊ณ  ํ•˜๋Š” ๋ณ€์ˆ˜๊ฐ’์„ ์„ค์ •ํ•œ๋‹ค.

  • summary_op = tf.summary.merge_all()()

    • 62 tf.summary.scalar("cost", cost )
    • 63 tf.summary.scalar("eval", eval_op )
    • 64 summary_op = tf.summary.merge_all()
  • ์จ๋จธ๋ฆฌ์— ์ถœ๋ ฅํ•œ ๊ฐ’์„ ๋ชจ์•„์„œ summy_write์— ์ถ”๊ฐ€

    • 88 summary_str = sess.run(summary_op, feed_dict=val_feed_dict)
    • 89 summary_write.add_summary( summary_str, epoch );
    • 90 summary_write.flush()
 44 with tf.Graph().as_default() as sess:
 45     x = tf.placeholder( "float", [None, 784])
 46     y = tf.placeholder( "float", [None,10])
 47
 48     output = inference(x)
 49     cost = get_loss(output,y)
 50
 51     global_step = tf.Variable(0,name='global_step', trainable=False)
 52     train_op = training(cost, global_step)
 53     eval_op = evaluate(output,y)
 54     #saver = tf.train.Saver()
 55     sess = tf.Session()
 56     summary_write = tf.summary.FileWriter("logistic_logs/", graph_def=sess.graph_def)
 57
 58     init_op = tf.initialize_all_variables()
 59
 60     sess.run(init_op)
 61
 62     tf.summary.scalar("cost", cost )
 63     tf.summary.scalar("eval", eval_op )
 64     summary_op = tf.summary.merge_all()
 65     for epoch in range(training_epochs):
 66         avg_cost = 0
 67         total_batch = int(mnist.train.num_examples/batch_size)
 68
 69         for i in range(total_batch):
 70             mbatch_x, mbatch_y = mnist.train.next_batch(batch_size)
 71             feed_dict = {x:mbatch_x, y:mbatch_y}
 72             sess.run(train_op, feed_dict=feed_dict)
 73             minibatch_cost = sess.run(cost, feed_dict=feed_dict)
 74             avg_cost += minibatch_cost/total_batch
 75
 76
 77
 78
 79
 80         if epoch % display_step == 0 :
 81             val_feed_dict = {
 82                 x : mnist.validation.images,
 83                 y : mnist.validation.labels
 84             }
 85             accuracy = sess.run(eval_op, feed_dict=val_feed_dict)
 86             print( "Valid Error:", (1-accuracy))
 87
 88             summary_str = sess.run(summary_op, feed_dict=val_feed_dict)
 89             summary_write.add_summary( summary_str, epoch );
 90             summary_write.flush()
 91
 92             #saver.save( sess, "logistic_logs/model-checkpoint", global_step=global_step)
 93         test_feed_dict = {
 94             x : mnist.test.images,
 95             y : mnist.test.labels
 96         }
 97
 98         accuracy = sess.run(eval_op, feed_dict=test_feed_dict)
 99         print( "Test Error:", accuracy)