神经网络作业归纳(作业3) - zLulus/My_Note GitHub Wiki
问题:识别0~9的手写数字
多级分类方案
1.加载数据,20*20像素的数字图片与数字结果(0~9),其中0在计算中表示为10
并进行可视化
2.确定方程:cost function和 ∂J/∂theta (=gradient)
3.使用fmincg函数进行数据训练,得到训练结果all_theta
4.使用all_theta去计算数据,得到预测结果
神经网络方案
1.该作业中是一个已经训练好的神经网络,需要写调用的方法
2.按照之前教的,一层层往下递推
function p = predict(Theta1, Theta2, X)
%PREDICT Predict the label of an input given a trained neural network
% p = PREDICT(Theta1, Theta2, X) outputs the predicted label of X given the
% trained weights of a neural network (Theta1, Theta2)
% Useful values
m = size(X, 1);
num_labels = size(Theta2, 1);
% You need to return the following variables correctly
p = zeros(size(X, 1), 1);
% ====================== YOUR CODE HERE ======================
% Instructions: Complete the following code to make predictions using
% your learned neural network. You should set p to a
% vector containing labels between 1 to num_labels.
%
% Hint: The max function might come in useful. In particular, the max
% function can also return the index of the max element, for more
% information see 'help max'. If your examples are in rows, then, you
% can use max(A, [], 2) to obtain the max for each row.
%
% Add ones to the X data matrix -jin
X = [ones(m, 1) X];
a2 = sigmoid(X * Theta1'); % 第二层激活函数输出
a2 = [ones(m, 1) a2]; % 第二层加入b 第一列加入1
a3 = sigmoid(a2 * Theta2');
[aa,p] = max(a3,[],2); % 返回每行最大值的索引位置,也就是预测的数字
% =========================================================================
end