16.xx - wwj-2017-1117/graph GitHub Wiki

# 实战:人工神经网络逼近股票收盘均价V2

import numpy as np import tensorflow as tf import matplotlib.pyplot as plt

y轴表示日期

date = np.linspace(1, 15, 15)

收盘价

endPrice = np.array( [2511.90, 2538.26, 2510.68, 2591.66, 2732.98, 2701.69, 2701.29, 2678.67, 2726.50, 2681.50, 2739.17, 2715.07, 2823.58, 2864.90, 2919.08])

开盘价

startPrice = np.array( [2438.71, 2500.88, 2534.95, 2512.52, 2594.04, 2743.26, 2697.47, 2695.24, 2678.23, 2722.13, 2674.93, 2744.13, 2717.46, 2832.73, 2877.40])

绘图

plt.figure() for i in range(15): # 柱状图 dataOne = np.zeros([2]) dataOne[0] = i dataOne[1] = i priceOne = np.zeros([2]) priceOne[0] = startPrice[i] priceOne[1] = endPrice[i] if endPrice[i] > startPrice[i]: plt.plot(dataOne, priceOne, 'r', lw=8) else: plt.plot(dataOne, priceOne, 'g', lw=8)

搭建神经网络 : 输入日期,输出股价

A(151) * w1(110) + b1(110) = B(1510)

B(1510) * w2(101) + b2(151) = C(151)

为了计算方便,将日期进行归一化

dateNormal = np.zeros([15, 1]) priceNormal = np.zeros([15, 1]) for i in range(15): dateNormal[i] = i / 14.0 priceNormal[i] = endPrice[i] / 3000.0

定义输入层

x = tf.placeholder(tf.float32, [None, 1]) # 表面n行1列 y = tf.placeholder(tf.float32, [None, 1]) # 表面n行1列

定义隐藏层

w1 = tf.Variable(tf.random_uniform([1, 10], 0, 1)) # 创建w1, 初始数据在0~1范围内, 因为神经网络需要对其进行更新,所以是变量 b1 = tf.Variable(tf.zeros([1, 10]))

定义一个操作

wb1 = tf.matmul(x, w1) + b1 layer1 = tf.nn.relu(wb1) # 激励函数

定义输出层

w2 = tf.Variable(tf.random_uniform([10, 1], 0, 1)) b2 = tf.Variable(tf.zeros([15, 1])) wb2 = tf.matmul(layer1, w2) + b2 layer2 = tf.nn.relu(wb2)

定义神经网络输出与实际值的差异,为了调整循环次数

loss = tf.reduce_mean(tf.square(y - layer2)) # y是真实值 这行代码实际是运算了标准差 train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 梯度下降,步长,最小化loss

运行神经网络

with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 变量初始化 for i in range(10000): sess.run(train_step, feed_dict={x: dateNormal, y: priceNormal}) pred = sess.run(layer2, feed_dict={x: dateNormal}) predPrice = np.zeros([15, 1]) for i in range(15): predPrice[i] = pred[i] * 3000

plt.plot(date, predPrice, 'b', lw=2) plt.show() plt.plot(date, endPrice, 'g', lw=10) plt.show()