16.xx - wwj-2017-1117/graph GitHub Wiki
# 实战:人工神经网络逼近股票收盘均价V2
import numpy as np import tensorflow as tf import matplotlib.pyplot as plt
y轴表示日期
date = np.linspace(1, 15, 15)
收盘价
endPrice = np.array( [2511.90, 2538.26, 2510.68, 2591.66, 2732.98, 2701.69, 2701.29, 2678.67, 2726.50, 2681.50, 2739.17, 2715.07, 2823.58, 2864.90, 2919.08])
开盘价
startPrice = np.array( [2438.71, 2500.88, 2534.95, 2512.52, 2594.04, 2743.26, 2697.47, 2695.24, 2678.23, 2722.13, 2674.93, 2744.13, 2717.46, 2832.73, 2877.40])
绘图
plt.figure() for i in range(15): # 柱状图 dataOne = np.zeros([2]) dataOne[0] = i dataOne[1] = i priceOne = np.zeros([2]) priceOne[0] = startPrice[i] priceOne[1] = endPrice[i] if endPrice[i] > startPrice[i]: plt.plot(dataOne, priceOne, 'r', lw=8) else: plt.plot(dataOne, priceOne, 'g', lw=8)
搭建神经网络 : 输入日期,输出股价
A(151) * w1(110) + b1(110) = B(1510)
B(1510) * w2(101) + b2(151) = C(151)
为了计算方便,将日期进行归一化
dateNormal = np.zeros([15, 1]) priceNormal = np.zeros([15, 1]) for i in range(15): dateNormal[i] = i / 14.0 priceNormal[i] = endPrice[i] / 3000.0
定义输入层
x = tf.placeholder(tf.float32, [None, 1]) # 表面n行1列 y = tf.placeholder(tf.float32, [None, 1]) # 表面n行1列
定义隐藏层
w1 = tf.Variable(tf.random_uniform([1, 10], 0, 1)) # 创建w1, 初始数据在0~1范围内, 因为神经网络需要对其进行更新,所以是变量 b1 = tf.Variable(tf.zeros([1, 10]))
定义一个操作
wb1 = tf.matmul(x, w1) + b1 layer1 = tf.nn.relu(wb1) # 激励函数
定义输出层
w2 = tf.Variable(tf.random_uniform([10, 1], 0, 1)) b2 = tf.Variable(tf.zeros([15, 1])) wb2 = tf.matmul(layer1, w2) + b2 layer2 = tf.nn.relu(wb2)
定义神经网络输出与实际值的差异,为了调整循环次数
loss = tf.reduce_mean(tf.square(y - layer2)) # y是真实值 这行代码实际是运算了标准差 train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 梯度下降,步长,最小化loss
运行神经网络
with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 变量初始化 for i in range(10000): sess.run(train_step, feed_dict={x: dateNormal, y: priceNormal}) pred = sess.run(layer2, feed_dict={x: dateNormal}) predPrice = np.zeros([15, 1]) for i in range(15): predPrice[i] = pred[i] * 3000
plt.plot(date, predPrice, 'b', lw=2) plt.show() plt.plot(date, endPrice, 'g', lw=10) plt.show()