机器学习 – 结课作业 – 鸢尾花分类问题

MachineLearning – Term Project

一、问题分析

1.1 问题背景

    本课题给出鸢尾花分类问题,一共三个品种,四个特征,数据有缺失。五列数据分别对应花萼长度、花萼宽度、花瓣长度、花瓣宽度和种类,其中种类分别为山鸢尾、变色鸢尾和维吉尼亚鸢尾三个类别。

1.2 模型选择

    本问题是一个典型的分类问题,解决此问题的方法多种多样,基础的办法有逻辑回归、贝叶斯等方法。但是逻辑回归对于多分类问题处理起来较为麻烦,同时计算其代价函数的偏导数过程也很复杂,这里不考虑使用逻辑回归。而贝叶斯分类器在一定程度上来说表现效果不佳。
    同时,贝叶斯分类器已经由小组成员完成,这里使用神经网络对数据进行处理。下面展示神经网络的模型示意图:

图1 神经网络示意图

    观察给出的数据,本身有5列数据,但是作为特征性的数据只有4种,所以我们的输入层只需要4个节点,同时加上一个偏置项,共5个节点,即输入层a1共有5个节点。
    输出层的节点数量和分类的数量保持一致,共3个节点,这里不含偏置项。
    隐藏层的节点数量可以是任意多的,但是这里为了计算简单,令隐藏层和输入层的节点数量一样,保持为5个节点(含偏置项)。
    对于结果而言,我们只需要两个权重矩阵的内容,即小delta的内容。θ存放代价函数对每一个节点的偏导数,用于更新我们的权值矩阵。最后的大DELTA矩阵是更新之后的偏导数矩阵。

1.3 数据补充

正如前面所提到的,一共有150组数据,其中缺失数据项共有27项。我们先删除这27项,保留123项进行数据可视化,然后分析数据。

% 数据可视化

Data = load('Data-DeleteMissingData.txt')
% 此处删除不完整变量,下面进行可视化

Setosa = Data([1:39], :)
Versicolor = Data([40:82], :)
Virginica = Data([83:123], :)

##### 绘制第一特征 #####
Feature1_Setosa = Setosa(:, 1)
Feature1_Versicolor = Versicolor(:, 1)
Feature1_Virginica = Virginica(:, 1)

figure(1)
hist(Feature1_Setosa, 50)
xlabel('数值')
ylabel('数量')
title('Setosa Feature 1')

figure(2)
hist(Feature1_Versicolor, 50)
xlabel('数值')
ylabel('数量')
title('Versicolor Feature 1')

figure(3)
hist(Feature1_Virginica, 50)
xlabel('数值')
ylabel('数量')
title('Virginica Feature 1')

代码块1 数据可视化部分内容

    这里我们使用Octave对数据进行可视化,导入的数据为删除缺失项之后的全部数据。如上所示,我们对特征1进行可视化。

图2 特征1


图3 特征2


图4 特征3


图5 特征4

    对删除缺失项的数据观察得到:所有数据都趋近于平均值附近,所以对缺失项使用平均值填充,对于个严重偏平均值的数据采用删除办法。
    综上,处理后的数据共146组。

二、构建模型

2.1 准备部分

clear; close all; clc

DataSize = 146;
INIT_EPSILON = 1;
% 数据大小、随机初始化参数

theta1 = rand(4, 5) * (2 * INIT_EPSILON) - INIT_EPSILON
theta2 = rand(3, 5) * (2 * INIT_EPSILON) - INIT_EPSILON
% 权重矩阵的初始化

lambda = 0.05
% 学习速率

Etotal = 50;
% 总误差

DELTA1 = 0;
DELTA2 = 0;
%theta的偏导矩阵

Data = load('Data-iris.txt')
% 载入原始数据
XFeature = Data(:, [1:4])
Y_ini = Data(:, 5);
Y = zeros(DataSize, 1);
for i=1:DataSize
  col = Y_ini(i, 1);
  Y(i, col) = 1;
endfor
Y
% 输出层向量化

function g = sigmoid(z)
g = zeros(size(z));
g = 1./(1+e.^(-z));
end
% Sigmoid函数

代码块2 模型准备部分

    这里是模型的准备阶段,不涉及神经网络和反向传播算法。
    第一行我们对程序进行初始化,避免不必要的错误。然后选择数据的大小,数据大小与我们经过处理的数据保持一致,即:146组数据。

theta1 = rand(4, 5) * (2 * INIT_EPSILON) - INIT_EPSILON
theta2 = rand(3, 5) * (2 * INIT_EPSILON) - INIT_EPSILON
% 权重矩阵的初始化

    这段代码是对我们所需要的权值矩阵进行初始化,这里随机生成一个4 * 5和3 * 5的矩阵,矩阵中的每个元素是[-1, 1]之间产生的随机值。注意:元素值全部相同的矩阵将会导致对每个节点的更新一样。
    在下面我们添加了一个误差计算项,这个误差是**平均误差**而非代价函数的**均方误差**,之所以这样选择的原因是:如果使用反向传播算法则无需表示和计算代价函数就可以得到他的偏导矩阵,而使用平均误差可以减少计算量,提高程序运行速度。
    之后是两个DELTA矩阵,这是用于存放偏导的矩阵。其后是载入数据。

function g = sigmoid(z)
g = zeros(size(z));
g = 1./(1+e.^(-z));
end
% Sigmoid函数

    这里我们添加了一个sigmoid函数,用于之后的计算。

2.2 反向传播算法

    如上面所提到的,为了提高模型的计算效率,我们使用反向传播算法。计算步骤分为以下5步:

·1:把输入数据添加到输入层。

·2:使用正向传播算法计算每一层的节点。

·3:计算最后一层的偏导矩阵。

·4:反向计算隐藏层的偏导矩阵。

·5:计算出每个偏导矩阵的误差项,并更新权重矩阵。

while Etotal>0.1
  DELTA1 = 0;
  DELTA2 = 0;
  Etotal = 0;
  for i=1:DataSize
   
  ################################## 正向传播 ##################################
    % ---输入层 a1---
    a1 = XFeature(i, :);
    a1 = a1';               % 4x1
    a1 = [ones(1,1); a1];   % 添加偏置项后为 5x1
   
    % ---隐藏层 a2---
    z2 = theta1 * a1;
    a2 = sigmoid(z2);       % 4x1
    a2 = [ones(1,1); a2];   % 添加偏置项后为 5x1
   
    % ---输出层 a3---
    z3 = theta2 * a2;
    a3 = sigmoid(z3);       % 4x1

  ################################## 计算误差项 ################################
    y = Y(i, :);            %提取输出y
    y = y';
    delta3 = a3 - y;
    delta2 = ((theta2') * delta3).*a2.*(1 - a2);
   
    Etotal = Etotal + delta3(1) + delta3(2) + delta3(3);
   
  ################################## 计算偏导数 ################################

    DELTA1 = DELTA1 + delta2*(a1');
    DELTA2 = DELTA2 + delta3*(a2');
  endfor
  Etotal
  DELTA1 = DELTA1./DataSize();
  DELTA2 = DELTA2./DataSize();
  theta1 = theta1 .- (lambda .* DELTA1([2:5], :));
  theta2 = theta2 .- (lambda .* DELTA2);
endwhile
Etotal
theta1
theta2

代码块3 反向传播算法

    这里我们使用了一个双重循环,内部循环是反向传播算法的主体,外部循环用于检测误差大小,如果误差满足我们的需要则输出权重矩阵。

  DELTA1 = 0;
  DELTA2 = 0;
  Etotal = 0;

    在每次的反向传播算法执行之前,我们都对偏导矩阵和误差进行一次初始化,将其初始化为全0的矩阵。

2.2.1 正向传播
    % ---输入层 a1---
    a1 = XFeature(i, :);
    a1 = a1';               % 4x1
    a1 = [ones(1,1); a1];   % 添加偏置项后为 5x1
   
    % ---隐藏层 a2---
    z2 = theta1 * a1;
    a2 = sigmoid(z2);       % 4x1
    a2 = [ones(1,1); a2];   % 添加偏置项后为 5x1
   
    % ---输出层 a3---
    z3 = theta2 * a2;
    a3 = sigmoid(z3);       % 4x1

    这里使用正向传播算法更新我们的所有节点。

    a是当前层的所有节点的矩阵(向量),z是sigmoid函数中的所有参数。

    我们将上式重写,就得到了如图所示的表达方式。g函数是sigmoid函数。

    这里将a向量化,就可以使用权重矩阵和a的乘积来表示z。

    于是可以直接表示a=g(z)。
    至此,正向传播算法的表示已经和代码部分一致,之后便可以计算误差项。

2.2.2 计算误差项
    y = Y(i, :);            %提取输出y
    y = y';
    delta3 = a3 - y;
    delta2 = ((theta2') * delta3).*a2.*(1 - a2);
   
    Etotal = Etotal + delta3(1) + delta3(2) + delta3(3);

    这里我们提取出结果Y,然后用输出层a3减去结果y得到输出层的误差delta3。然后通过反向传播计算隐藏层a2的误差delta2,输出层不存在误差所以这里没有delta1。
    最后计算平均误差Etotal,这里没有使用均方误差代价函数的原因上面已经提到。

图6 误差矩阵的更新

2.2.3 计算偏导数
    DELTA1 = DELTA1 + delta2*(a1');
    DELTA2 = DELTA2 + delta3*(a2');

    这里使用误差矩阵来计算我们的偏导数矩阵。

图7 偏导矩阵的计算

2.2.4 更新权重矩阵
  Etotal
  DELTA1 = DELTA1./DataSize();
  DELTA2 = DELTA2./DataSize();
  theta1 = theta1 .- (lambda .* DELTA1([2:5], :));
  theta2 = theta2 .- (lambda .* DELTA2);

    这里第一行输出误差便于观察,第二行更新大DELTA矩阵,再下面是更新权重矩阵。
    至此,完成了一次反向传播算法,之后只需要重复计算以降低误差Etotal即可。

2.3 结果

    我们经过三次运行,得到了3个不同的权值矩阵。

2.3.1 第一次计算
theta1 = [  
   0.3634220  -0.5785249  -0.6768396  -0.9181471  -0.1386722;
   0.1001102  -0.4271788  -2.1368118   2.8189427   1.6372495;
   1.2574465   0.0056068   2.3774486  -2.7294670  -1.5477634;
  -6.0360525  -2.9948475  -2.2951499   4.6916875   5.0640912
  ]

theta2 = [
   0.22857  -0.21957  -5.99995   5.64679  -2.87037;
   0.18200   0.32620   5.21638  -5.89776  -9.70556;
  -4.79493  -0.45691  -0.77220  -2.57438   9.93093;
  ]
2.3.2 第二次计算
theta1 = [  
  -13.57609   -1.01151   -5.04825    2.99442   10.95433;
    1.06531    0.64667    1.27045   -2.76158   -1.02987;
   -0.82416   -0.20056   -2.41945    2.95448    2.02394;
    0.74033   -0.72004   -2.04824    3.03701    1.74662;
  ]

theta2 = [
    2.38448   -2.38319    5.91602   -5.71440   -5.34631;
   -3.42144  -16.76785   -4.71169    6.35865    5.23593;
   -5.20791   16.71897   -4.29564   -2.21458   -0.72748;
   ]
2.3.3 第三次计算
theta1 = [  
   14.53745    1.06695    5.41210   -3.19412  -11.71895;
   -0.10527    0.55710    0.99034    1.08439    1.00968;
   -1.42449   -0.48922   -2.83584    4.16115    1.87471;
    0.82783    0.88992    0.56318    0.86599   -0.45926;
  ]

theta2 = [
    0.64984    4.01768    0.55466  -14.63997    2.01987;
   -7.43639   15.33393   -6.76862   14.13434   -7.85377;
    1.62750  -17.23315    0.66101    4.67994    1.76113;
  ]

三、结果检验

3.1 结果统计


图8 结果统计

矩阵\类别 Setosa Versicolor Virginica
第一组 46 50 50
第二组 46 54 46
第三组 46 54 46
标准结果 46 50 50

    统计使用c++完成,代码将放在最后的附录当中。

3.2 精度检测

3.2.1 第一组测试结果

setosa

预测\实际
TP = 46 FP = 0
FN = 0 TN = 100

    精度 = 100.00%    召回率 = 100.00%

versicolor

预测\实际
TP = 50 FP = 0
FN = 0 TN = 96

    精度 = 100.00%    召回率 = 100.00%

virginica

预测\实际
TP = 50 FP = 0
FN = 0 TN = 96

    精度 = 100.00%    召回率 = 100.00%

总体平均

    精度 = 100.00%    召回率 = 100.00%

3.2.2 第二组测试结果

setosa

预测\实际
TP = 46 FP = 0
FN = 0 TN = 100

    精度 = 100.00%    召回率 = 100.00%

versicolor

预测\实际
TP = 50 FP = 4
FN = 0 TN = 92

    精度 = 92.59%    召回率 = 100.00%

virginica

预测\实际
TP = 46 FP = 0
FN = 4 TN = 96

    精度 = 100.00%    召回率 = 92.00%

总体平均

    精度 = 97.53%    召回率 = 97.33%

3.2.3 第三组测试结果

setosa

预测\实际
TP = 46 FP = 0
FN = 0 TN = 100

    精度 = 100.00%    召回率 = 100.00%

versicolor

预测\实际
TP = 50 FP = 4
FN = 0 TN = 92

    精度 = 92.59%    召回率 = 100.00%

virginica

预测\实际
TP = 46 FP = 0
FN = 4 TN = 96

    精度 = 100.00%    召回率 = 92.00%

总体平均

    精度 = 97.53%    召回率 = 97.33%

3.2.4 平均统计结果

    精度 = 98.35%    召回率 = 98.22%

四、附录

4.1 数据处理

    下列为数据处理后的输入值,前四列为特征,最后一列用1、2、3表示三个品种。

5.1,3.5,1.4,0.2,1
4.9,3,1.4,0.2,1
4.7,3.2,1.3,0.2,1
4.6,3.1,1.5,0.2,1
5,3.6,1.4,0.244186047,1
5.4,3.9,1.7,0.4,1
4.6,3.4,1.4,0.3,1
5,3.4,1.5,0.2,1
4.4,2.9,1.4,0.2,1
5.034090909,3.1,1.5,0.1,1
5.4,3.7,1.5,0.2,1
4.8,3.4,1.6,0.2,1
4.8,3,1.4,0.1,1
5.8,4,1.481395349,0.2,1
5.7,4.4,1.5,0.4,1
5.4,3.9,1.3,0.4,1
5.1,3.5,1.4,0.3,1
5.7,3.8,1.481395349,0.3,1
5.1,3.8,1.5,0.3,1
5.4,3.4,1.7,0.2,1
5.1,3.7,1.5,0.4,1
5.1,3.3,1.7,0.5,1
4.8,3.4,1.9,0.2,1
5,3,1.6,0.2,1
5,3.4,1.6,0.4,1
5.2,3.5,1.5,0.2,1
5.2,3.4,1.4,0.2,1
4.7,3.2,1.481395349,0.2,1
4.8,3.1,1.6,0.2,1
5.4,3.4,1.5,0.4,1
5.2,4.1,1.5,0.244186047,1
5.034090909,4.2,1.4,0.2,1
4.9,3.1,1.5,0.1,1
5,3.2,1.2,0.2,1
5.5,3.5,1.3,0.2,1
4.9,3.1,1.5,0.1,1
4.4,3,1.3,0.2,1
5.1,3.4,1.5,0.2,1
5,3.5,1.3,0.3,1
4.4,3.2,1.3,0.2,1
5.1,3.8,1.9,0.4,1
4.8,3,1.4,0.3,1
5.1,3.8,1.6,0.2,1
4.6,3.2,1.4,0.244186047,1
5.3,3.7,1.5,0.2,1
5,3.3,1.4,0.2,1
7,3.2,4.7,1.325,2
6.4,2.7625,4.5,1.5,2
6.9,3.1,4.238297872,1.5,2
5.5,2.3,4,1.3,2
6.5,2.8,4.6,1.5,2
5.7,2.8,4.5,1.3,2
6.3,3.3,4.7,1.6,2
5.963829787,2.4,3.3,1,2
6.6,2.9,4.6,1.3,2
5.2,2.7,3.9,1.4,2
5,2,3.5,1,2
5.9,3,4.2,1.5,2
6,2.2,4,1,2
6.1,2.9,4.7,1.4,2
5.6,2.9,3.6,1.3,2
6.7,3.1,4.4,1.4,2
5.6,3,4.5,1.5,2
5.8,2.7,4.1,1,2
6.2,2.2,4.5,1.5,2
5.6,2.5,3.9,1.1,2
5.9,3.2,4.238297872,1.8,2
6.1,2.8,4,1.3,2
6.3,2.5,4.9,1.5,2
6.1,2.8,4.7,1.2,2
6.4,2.9,4.3,1.3,2
6.6,3,4.4,1.4,2
6.8,2.8,4.8,1.4,2
6.7,3,5,1.7,2
5.963829787,2.9,4.5,1.5,2
5.7,2.6,3.5,1,2
5.5,2.4,3.8,1.1,2
5.5,2.4,3.7,1,2
5.8,2.7,3.9,1.2,2
6,2.7,5.1,1.6,2
5.4,3,4.5,1.5,2
6,3.4,4.5,1.6,2
6.7,3.1,4.7,1.5,2
6.3,2.3,4.4,1.3,2
5.6,3,4.1,1.3,2
5.5,2.5,4,1.3,2
5.5,2.6,4.4,1.2,2
6.1,3,4.6,1.4,2
5.8,2.6,4,1.2,2
5,2.3,3.3,1,2
5.963829787,2.7625,4.2,1.3,2
5.7,3,4.2,1.2,2
5.7,2.9,4.2,1.3,2
6.2,2.9,4.3,1.3,2
5.1,2.5,3,1.1,2
5.7,2.8,4.238297872,1.325,2
6.3,3.3,6,2.5,3
5.8,2.7,5.1,1.9,3
7.1,3,5.9,2.1,3
6.3,2.9,5.6,1.8,3
6.5,3,5.8,2.2,3
7.6,3,6.6,2.1,3
4.9,2.5,4.5,1.7,3
7.3,2.9,6.3,1.8,3
6.7,2.5,5.8,1.8,3
7.2,3.6,6.1,2.5,3
6.5,3.2,5.1,2,3
6.4,2.7,5.3,1.9,3
6.8,3,5.5,2.1,3
5.7,2.5,5.580851064,2,3
5.8,2.8,5.1,2.4,3
6.4,3.2,5.3,2.3,3
6.5,3,5.5,1.8,3
7.7,3.8,6.7,2.2,3
7.7,2.6,6.9,2.3,3
6,2.2,5,1.5,3
6.9,3.2,5.7,2.3,3
5.6,2.8,4.9,2,3
7.7,2.8,6.7,2,3
6.3,2.7,5.580851064,1.8,3
6.7,3.3,5.7,2.1,3
7.2,3.2,6,1.8,3
6.2,2.8,4.8,2.029787234,3
6.1,3,4.9,1.8,3
6.4,2.8,5.6,2.1,3
7.2,3,5.8,1.6,3
7.4,2.8,6.1,1.9,3
7.9,3.8,6.4,2,3
6.4,2.8,5.6,2.2,3
6.3,2.8,5.1,1.5,3
6.1,2.981632653,5.6,1.4,3
7.7,3,6.1,2.029787234,3
6.3,3.4,5.6,2.4,3
6.4,3.1,5.5,1.8,3
6,3,4.8,2.029787234,3
6.9,3.1,5.580851064,2.1,3
6.7,3.1,5.6,2.4,3
6.9,3.1,5.1,2.3,3
5.8,2.7,5.1,1.9,3
6.8,3.2,5.9,2.3,3
6.585714286,3.3,5.7,2.5,3
6.7,3,5.2,2.3,3
6.3,2.5,5,1.9,3
6.5,3,5.2,2,3
6.2,3.4,5.4,2.3,3
5.9,3,5.1,1.8,3

4.2 输出结果

4.2.1 标准输出
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   1   0   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   1   0
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
   0   0   1
4.2.2 第一组输出

权值矩阵:

theta1 = [
   0.3634220  -0.5785249  -0.6768396  -0.9181471  -0.1386722;
   0.1001102  -0.4271788  -2.1368118   2.8189427   1.6372495;
   1.2574465   0.0056068   2.3774486  -2.7294670  -1.5477634;
  -6.0360525  -2.9948475  -2.2951499   4.6916875   5.0640912
  ]

theta2 = [
   0.22857  -0.21957  -5.99995   5.64679  -2.87037;
   0.18200   0.32620   5.21638  -5.89776  -9.70556;
  -4.79493  -0.45691  -0.77220  -2.57438   9.93093;
  ]

输出结果:
0.99704446 0.00345585 0.00063367
0.99667543 0.00386404 0.00064348
0.99694166 0.00356879 0.00063534
0.99659812 0.00394397 0.00064263
0.99706350 0.00343430 0.00063286
0.99701309 0.00348998 0.00063475
0.99693737 0.00357184 0.00063475
0.99693331 0.00357826 0.00063642
0.99643458 0.00412154 0.00064502
0.99674750 0.00378517 0.00064224
0.99707872 0.00341814 0.00063319
0.99682472 0.00369601 0.00063809
0.99674575 0.00378588 0.00064113
0.99714691 0.00334253 0.00063144
0.99716799 0.00331866 0.00063063
0.99713904 0.00335103 0.00063128
0.99701701 0.00348610 0.00063438
0.99709968 0.00339527 0.00063295
0.99707755 0.00341850 0.00063257
0.99676902 0.00376222 0.00064273
0.99701912 0.00348299 0.00063407
0.99619819 0.00438446 0.00065436
0.99627669 0.00429159 0.00064898
0.99626288 0.00431772 0.00065444
0.99670112 0.00383280 0.00064186
0.99699872 0.00350664 0.00063520
0.99700923 0.00349545 0.00063502
0.99676345 0.00376375 0.00063945
0.99642906 0.00413093 0.00064793
0.99686368 0.00365743 0.00063989
0.99714448 0.00334454 0.00063095
0.99716708 0.00331957 0.00063027
0.99673177 0.00380121 0.00064181
0.99701975 0.00348405 0.00063429
0.99709325 0.00340276 0.00063301
0.99673177 0.00380121 0.00064181
0.99675345 0.00377399 0.00063835
0.99694001 0.00357141 0.00063662
0.99705838 0.00344028 0.00063302
0.99691981 0.00359126 0.00063462
0.99674498 0.00378243 0.00063998
0.99656063 0.00398859 0.00064539
0.99706206 0.00343552 0.00063299
0.99681910 0.00370220 0.00063764
0.99707572 0.00342124 0.00063309
0.99694751 0.00356320 0.00063621
0.00318023 0.99517701 0.00393508
0.00268906 0.99220123 0.00657291
0.00332403 0.99509338 0.00384230
0.00252524 0.99000291 0.00838960
0.00264819 0.99186987 0.00687694
0.00213566 0.98282310 0.01468462
0.00246374 0.98912700 0.00914179
0.00443051 0.99369011 0.00364272
0.00305439 0.99476312 0.00433779
0.00261215 0.98950170 0.00843033
0.00317609 0.99438590 0.00440769
0.00295040 0.99330898 0.00538243
0.00316322 0.99510414 0.00399666
0.00210989 0.98258600 0.01500369
0.00405659 0.99400002 0.00372021
0.00323327 0.99507666 0.00394004
0.00145504 0.93916965 0.05371314
0.00329963 0.99492709 0.00395097
0.00139847 0.93568993 0.05800231
0.00323669 0.99474860 0.00412468
0.00234425 0.98590655 0.01162703
0.00332687 0.99491202 0.00393612
0.00077357 0.66354015 0.32274145
0.00258517 0.99112966 0.00750449
0.00320232 0.99500383 0.00401566
0.00317046 0.99499264 0.00405866
0.00280682 0.99341361 0.00556457
0.00137949 0.93259529 0.06082118
0.00213917 0.98296654 0.01457362
0.00419571 0.99396579 0.00365329
0.00323808 0.99470958 0.00414584
0.00342192 0.99482773 0.00388121
0.00336180 0.99481160 0.00395353
0.00026317 0.04902736 0.95204684
0.00099348 0.80977114 0.17509013
0.00267096 0.99077626 0.00750071
0.00292760 0.99402447 0.00497881
0.00272079 0.99267806 0.00619967
0.00319773 0.99403743 0.00458164
0.00277308 0.99226988 0.00634704
0.00207278 0.98111301 0.01620480
0.00256168 0.99057308 0.00792306
0.00319009 0.99472774 0.00419268
0.00387910 0.99419895 0.00376616
0.00306840 0.99437339 0.00455992
0.00322447 0.99440298 0.00433758
0.00300824 0.99366035 0.00507541
0.00315079 0.99477123 0.00421430
0.00622758 0.99115407 0.00337902
0.00279179 0.99241818 0.00621668
0.00017670 0.01332115 0.98739999
0.00018663 0.01591843 0.98486184
0.00018479 0.01546074 0.98532279
0.00018776 0.01628083 0.98451888
0.00017789 0.01361953 0.98710990
0.00017792 0.01363812 0.98709472
0.00019690 0.01876223 0.98201623
0.00018486 0.01549040 0.98529604
0.00018230 0.01478309 0.98598248
0.00017885 0.01386287 0.98687336
0.00034269 0.11116134 0.88893387
0.00020023 0.02014297 0.98074475
0.00019872 0.01964856 0.98123036

4.2.3 第二组输出

权值矩阵:

theta1 = [  
  -13.57609   -1.01151   -5.04825    2.99442   10.95433;
    1.06531    0.64667    1.27045   -2.76158   -1.02987;
   -0.82416   -0.20056   -2.41945    2.95448    2.02394;
    0.74033   -0.72004   -2.04824    3.03701    1.74662;
  ]
   
theta2 = [
    2.38448   -2.38319    5.91602   -5.71440   -5.34631;
   -3.42144  -16.76785   -4.71169    6.35865    5.23593;
   -5.20791   16.71897   -4.29564   -2.21458   -0.72748;
   ]

输出结果:
0.999728062 0.000318616 0.000076658
0.999684353 0.000366934 0.000078102
0.999713638 0.000334320 0.000077367
0.999659410 0.000392552 0.000080109
0.999729205 0.000317261 0.000076688
0.999719565 0.000327396 0.000077562
0.999707247 0.000340725 0.000078128
0.999711145 0.000336741 0.000077781
0.999638207 0.000416152 0.000080261
0.999692120 0.000357901 0.000078272
0.999732889 0.000313375 0.000076388
0.999689024 0.000359926 0.000079458
0.999690864 0.000359300 0.000078251
0.999743162 0.000302436 0.000075536
0.999745131 0.000300277 0.000075418
0.999741788 0.000303956 0.000075592
0.999724337 0.000322734 0.000076805
0.999737197 0.000308918 0.000075927
0.999729813 0.000316461 0.000076789
0.999692629 0.000357050 0.000078562
0.999721670 0.000325359 0.000077196
0.999606086 0.000452068 0.000080726
0.999581426 0.000472279 0.000085507
0.999629245 0.000426989 0.000079990
0.999675759 0.000375243 0.000079281
0.999721680 0.000325467 0.000077101
0.999725128 0.000321978 0.000076678
0.999684722 0.000365216 0.000079098
0.999638669 0.000415254 0.000080690
0.999708232 0.000340856 0.000077170
0.999740040 0.000305520 0.000076017
0.999743877 0.000301489 0.000075636
0.999686975 0.000363205 0.000078698
0.999728173 0.000318915 0.000076266
0.999737780 0.000308426 0.000075751
0.999686975 0.000363205 0.000078698
0.999684867 0.000365477 0.000078681
0.999713540 0.000334288 0.000077543
0.999730143 0.000316479 0.000076417
0.999705870 0.000342308 0.000078088
0.999665007 0.000384503 0.000081615
0.999666957 0.000386125 0.000078550
0.999726122 0.000320191 0.000077257
0.999692567 0.000356859 0.000078618
0.999731732 0.000314552 0.000076527
0.999715840 0.000331987 0.000077239
0.000184475 0.999696572 0.000280303
0.000170664 0.999615437 0.000372572
0.000205206 0.999663933 0.000269196
0.000176559 0.999645248 0.000334525
0.000169583 0.999616457 0.000374615
0.000175354 0.999693535 0.000299438
0.000175150 0.999672909 0.000316680
0.000420977 0.999379256 0.000199165
0.000179489 0.999699546 0.000286480
0.000187851 0.999667583 0.000295927
0.000200874 0.999671235 0.000269944
0.000186501 0.999675078 0.000292809
0.000188524 0.999690495 0.000276443
0.000172255 0.999676422 0.000319724
0.000275749 0.999567556 0.000241930
0.000191920 0.999684319 0.000276594
0.000171973 0.999649113 0.000342875
0.000197141 0.999679419 0.000272717
0.000093296 0.978215110 0.021538840
0.000197143 0.999678430 0.000272147
0.000169834 0.999447362 0.000507319
0.000204451 0.999667053 0.000268653
0.000109968 0.993432130 0.006583443
0.000175034 0.999704574 0.000291025
0.000190156 0.999687133 0.000277300
0.000187358 0.999688227 0.000281120
0.000171985 0.999672752 0.000323076
0.000126868 0.997495212 0.002501921
0.000171726 0.999638794 0.000351506
0.000316036 0.999513776 0.000228913
0.000199102 0.999675148 0.000270971
0.000215313 0.999653882 0.000261807
0.000207412 0.999663931 0.000266807
0.000068799 0.851651170 0.148095315
0.000170290 0.999635427 0.000358104
0.000182147 0.999683815 0.000294925
0.000176730 0.999686026 0.000302567
0.000169694 0.999623517 0.000368293
0.000195998 0.999678847 0.000275394
0.000183180 0.999677277 0.000296575
0.000175462 0.999696398 0.000296825
0.000176306 0.999692448 0.000298454
0.000193458 0.999682036 0.000275302
0.000267889 0.999580145 0.000241159
0.000186217 0.999688868 0.000282654
0.000193717 0.999683412 0.000275965
0.000187341 0.999689359 0.000281075
0.000187732 0.999689853 0.000279552
0.000605839 0.999128331 0.000184692
0.000182337 0.999690338 0.000288712
0.000015750 0.000185486 0.999809585
0.000017650 0.000405222 0.999583487
0.000016071 0.000212955 0.999781324
0.000021502 0.001641463 0.998323124
0.000015827 0.000191518 0.999803323
0.000015806 0.000190462 0.999804548
0.000030937 0.019728583 0.979808288
0.000017661 0.000415125 0.999574836
0.000016364 0.000242332 0.999751334
0.000015800 0.000189110 0.999805770
0.000031991 0.024894900 0.974551031
0.000017657 0.000407842 0.999581073
0.000016555 0.000259802 0.999732884
0.000015794 0.000188995 0.999805962
0.000015780 0.000185820 0.999808830
0.000016078 0.000209942 0.999783738
0.000031552 0.023647823 0.976003265
0.000016759 0.000287181 0.999705570
0.000015739 0.000185079 0.999810092
0.000035717 0.054423082 0.944861846
0.000015905 0.000197193 0.999797314
0.000017348 0.000354164 0.999634971
0.000015798 0.000189885 0.999805154
0.000017949 0.000462192 0.999526264
0.000017542 0.000391518 0.999598181
0.000035288 0.051004291 0.948462433
0.000018644 0.000569676 0.999410377
0.000070133 0.862472426 0.136371803
0.000015894 0.000196829 0.999797806
0.000094654 0.982231883 0.017909584
0.000016522 0.000259409 0.999733885
0.000042876 0.174732939 0.824277480
0.000015798 0.000188834 0.999806023
0.000140430 0.998827445 0.001183145
0.000147007 0.999182776 0.000831156
0.000016446 0.000250647 0.999742771
0.000015859 0.000193081 0.999801506
0.000039072 0.098125534 0.900906422
0.000022173 0.001904439 0.998032230
0.000016900 0.000300458 0.999691240
0.000015790 0.000187546 0.999807235
0.000016373 0.000232390 0.999759544
0.000017650 0.000405222 0.999583487
0.000015822 0.000190986 0.999803853
0.000015774 0.000186515 0.999808351
0.000015991 0.000201040 0.999792682
0.000017321 0.000352064 0.999637404
0.000020123 0.001005051 0.998967219
0.000016266 0.000228433 0.999764876
0.000043608 0.187558726 0.810570535

4.2.4 第三组输出

权值矩阵:

theta1 = [  

   14.53745    1.06695    5.41210   -3.19412  -11.71895;
   -0.10527    0.55710    0.99034    1.08439    1.00968;
   -1.42449   -0.48922   -2.83584    4.16115    1.87471;
    0.82783    0.88992    0.56318    0.86599   -0.45926;
  ]
   
theta2 = [
    0.64984    4.01768    0.55466  -14.63997    2.01987;
   -7.43639   15.33393   -6.76862   14.13434   -7.85377;
    1.62750  -17.23315    0.66101    4.67994    1.76113;
  ]

输出结果:
0.9992791911 0.0012123671 0.0000018844
0.9992607714 0.0012454972 0.0000018987
0.9992743611 0.0012239042 0.0000018871
0.9992531206 0.0012580571 0.0000019048
0.9992798966 0.0012110744 0.0000018838
0.9992773767 0.0012127593 0.0000018868
0.9992736286 0.0012233861 0.0000018882
0.9992735278 0.0012217806 0.0000018891
0.9992438236 0.0012771008 0.0000019110
0.9992636095 0.0012393814 0.0000018969
0.9992806583 0.0012083929 0.0000018838
0.9992662523 0.0012337513 0.0000018950
0.9992637343 0.0012416580 0.0000018960
0.9992832894 0.0012030857 0.0000018819
0.9992838873 0.0012016428 0.0000018816
0.9992830340 0.0012044140 0.0000018817
0.9992781499 0.0012138653 0.0000018853
0.9992817527 0.0012059042 0.0000018831
0.9992804105 0.0012090172 0.0000018838
0.9992641132 0.0012353433 0.0000018975
0.9992780139 0.0012130319 0.0000018857
0.9992284325 0.0012936154 0.0000019261
0.9992202067 0.0013070860 0.0000019327
0.9992326045 0.0012893993 0.0000019221
0.9992605253 0.0012421012 0.0000019000
0.9992770223 0.0012152270 0.0000018865
0.9992778597 0.0012146005 0.0000018855
0.9992636864 0.0012399219 0.0000018964
0.9992414971 0.0012752796 0.0000019148
0.9992715476 0.0012236161 0.0000018912
0.9992829829 0.0012041245 0.0000018819
0.9992836644 0.0012033688 0.0000018812
0.9992621385 0.0012423318 0.0000018979
0.9992785615 0.0012163110 0.0000018840
0.9992815841 0.0012078315 0.0000018827
0.9992621385 0.0012423318 0.0000018979
0.9992639844 0.0012444246 0.0000018944
0.9992740765 0.0012205960 0.0000018887
0.9992799344 0.0012117145 0.0000018835
0.9992726282 0.0012285335 0.0000018878
0.9992591164 0.0012424966 0.0000019018
0.9992544294 0.0012559360 0.0000019037
0.9992794312 0.0012104403 0.0000018847
0.9992673934 0.0012348239 0.0000018930
0.9992804478 0.0012089198 0.0000018839
0.9992747734 0.0012206723 0.0000018877
0.0006144429 0.9993909343 0.0002029522
0.0005829440 0.9992654872 0.0002515275
0.0006203274 0.9993809727 0.0002035204
0.0005950154 0.9993160667 0.0002317071
0.0005824539 0.9992643550 0.0002520851
0.0006080755 0.9993741651 0.0002100339
0.0006021838 0.9993482514 0.0002196057
0.0007236545 0.9992890458 0.0001921369
0.0006120461 0.9993874967 0.0002048018
0.0006114021 0.9993610855 0.0002123988
0.0006200589 0.9993832527 0.0002031221
0.0006115816 0.9993683036 0.0002103056
0.0006147826 0.9993898013 0.0002031700
0.0006009333 0.9993475291 0.0002203242
0.0006729605 0.9993362157 0.0001968709
0.0006165889 0.9993867663 0.0002033111
0.0005935454 0.9993127500 0.0002332549
0.0006201255 0.9993871005 0.0002019372
0.0002257973 0.9734175100 0.0144417091
0.0006195899 0.9993860429 0.0002024369
0.0005492594 0.9990534285 0.0003319502
0.0006221763 0.9993821484 0.0002025344
0.0003048026 0.9913927224 0.0040264356
0.0006108579 0.9993867798 0.0002054512
0.0006161006 0.9993875924 0.0002032681
0.0006143404 0.9993853472 0.0002045392
0.0005993498 0.9993418459 0.0002225702
0.0003840924 0.9964184250 0.0014978640
0.0005902418 0.9992985891 0.0002387257
0.0006855319 0.9993252095 0.0001954901
0.0006201060 0.9993848427 0.0002025870
0.0006277606 0.9993796941 0.0002012126
0.0006245368 0.9993814842 0.0002018654
0.0001358426 0.8404635127 0.1146743631
0.0005888489 0.9992922988 0.0002411472
0.0006111988 0.9993722896 0.0002093425
0.0006065187 0.9993670799 0.0002126057
0.0005840538 0.9992732391 0.0002487975
0.0006213448 0.9993836419 0.0002024341
0.0006089434 0.9993665672 0.0002118319
0.0006089078 0.9993778425 0.0002086960
0.0006083652 0.9993742819 0.0002098823
0.0006175700 0.9993857675 0.0002032444
0.0006622940 0.9993473414 0.0001977245
0.0006141970 0.9993842967 0.0002048933
0.0006200739 0.9993864427 0.0002021303
0.0006160418 0.9993857861 0.0002037969
0.0006154764 0.9993873470 0.0002035688
0.0009256105 0.9990988590 0.0001775669
0.0006122215 0.9993806020 0.0002066618
0.0000110259 0.0003621594 0.9998380502
0.0000128664 0.0006526332 0.9996860314
0.0000112921 0.0003966510 0.9998206096
0.0000175947 0.0021516790 0.9987990290
0.0000110824 0.0003692944 0.9998344593
0.0000110710 0.0003678319 0.9998351923
0.0000324765 0.0218766626 0.9836119973
0.0000129085 0.0006607778 0.9996816032
0.0000115703 0.0004352675 0.9998008656
0.0000110595 0.0003663644 0.9998359310
0.0000341718 0.0264263556 0.9797025222
0.0000128690 0.0006530976 0.9996857677
0.0000117324 0.0004589645 0.9997886328
0.0000110578 0.0003662350 0.9998360117
0.0000110306 0.0003628003 0.9998377415
0.0000112686 0.0003934708 0.9998222220
0.0000337192 0.0251655251 0.9807979937
0.0000119715 0.0004956799 0.9997695316
0.0000110223 0.0003616927 0.9998382811
0.0000416826 0.0548163057 0.9536377684
0.0000111387 0.0003764879 0.9998308270
0.0000125141 0.0005869713 0.9997213091
0.0000110653 0.0003671119 0.9998355548
0.0000132221 0.0007241557 0.9996470817
0.0000127683 0.0006337884 0.9996961834
0.0000410457 0.0518426385 0.9564748084
0.0000138416 0.0008617917 0.9995707513
0.0001391432 0.8522429797 0.1047050905
0.0000111346 0.0003759794 0.9998310889
0.0002381005 0.9782050744 0.0115258639
0.0000117216 0.0004573687 0.9997894595
0.0000581318 0.1710817109 0.8316064557
0.0000110568 0.0003660583 0.9998360914
0.0004570652 0.9981548651 0.0007101551
0.0004954222 0.9986435349 0.0005024456
0.0000116424 0.0004456865 0.9997954940
0.0000110995 0.0003714475 0.9998333720
0.0000493756 0.0996470737 0.9086622912
0.0000182376 0.0024641379 0.9986007243
0.0000120786 0.0005128018 0.9997605630
0.0000110454 0.0003646020 0.9998368208
0.0000114880 0.0004234010 0.9998069345
0.0000128664 0.0006526332 0.9996860314
0.0000110773 0.0003686365 0.9998347887
0.0000110358 0.0003633922 0.9998374291
0.0000111802 0.0003818208 0.9998281276
0.0000124769 0.0005803815 0.9997248220
0.0000156971 0.0013926141 0.9992637181
0.0000114507 0.0004182834 0.9998095705
0.0000600164 0.1890009047 0.8116075054

4.3 源码

4.3.1 MainSource.m

    此代码由Octave编写,是代码的核心部分,主要实现神经网络模型。

###################################################################################
################################## 第一部分:初始化 ##################################
###################################################################################
clear; close all; clc

DataSize = 146;
INIT_EPSILON = 1;
% 数据大小、随机初始化参数

theta1 = rand(4, 5) * (2 * INIT_EPSILON) - INIT_EPSILON
theta2 = rand(3, 5) * (2 * INIT_EPSILON) - INIT_EPSILON
% 权重矩阵的初始化

lambda = 0.05
% 学习速率

Etotal = 50;
% 总误差

DELTA1 = 0;
DELTA2 = 0;
%theta的偏导矩阵

Data = load('Data-iris.txt')
% 载入原始数据
XFeature = Data(:, [1:4])
Y_ini = Data(:, 5);
Y = zeros(DataSize, 1);
for i=1:DataSize
  col = Y_ini(i, 1);
  Y(i, col) = 1;
endfor
Y
% 输出层向量化

function g = sigmoid(z)
g = zeros(size(z));
g = 1./(1+e.^(-z));
end
% Sigmoid函数

###################################################################################
################################## 第二部分:反向传播 ################################
###################################################################################
while Etotal>0.1
  DELTA1 = 0;
  DELTA2 = 0;
  Etotal = 0;
  for i=1:DataSize
   
  ################################## 正向传播 ##################################
    % ---输入层 a1---
    a1 = XFeature(i, :);
    a1 = a1';               % 4x1
    a1 = [ones(1,1); a1];   % 添加偏置项后为 5x1
   
    % ---隐藏层 a2---
    z2 = theta1 * a1;
    a2 = sigmoid(z2);       % 4x1
    a2 = [ones(1,1); a2];   % 添加偏置项后为 5x1
   
    % ---输出层 a3---
    z3 = theta2 * a2;
    a3 = sigmoid(z3);       % 4x1

  ################################## 计算误差项 ################################
    y = Y(i, :);            %提取输出y
    y = y';
    delta3 = a3 - y;
    delta2 = ((theta2') * delta3).*a2.*(1 - a2);
   
    Etotal = Etotal + delta3(1) + delta3(2) + delta3(3);
   
  ################################## 计算偏导数 ################################

    DELTA1 = DELTA1 + delta2*(a1');
    DELTA2 = DELTA2 + delta3*(a2');
  endfor
  Etotal
  DELTA1 = DELTA1./DataSize();
  DELTA2 = DELTA2./DataSize();
  theta1 = theta1 .- (lambda .* DELTA1([2:5], :));
  theta2 = theta2 .- (lambda .* DELTA2);
endwhile
Etotal
theta1
theta2
4.3.2 DataView.m

    此代码使用Octave编写,主要功能是将数据可视化。

% 数据可视化

Data = load('Data-DeleteMissingData.txt')
% 此处删除不完整变量,下面进行可视化

Setosa = Data([1:39], :)
Versicolor = Data([40:82], :)
Virginica = Data([83:123], :)

##### 绘制第一特征 #####
Feature1_Setosa = Setosa(:, 1)
Feature1_Versicolor = Versicolor(:, 1)
Feature1_Virginica = Virginica(:, 1)

figure(1)
hist(Feature1_Setosa, 50)
xlabel('数值')
ylabel('数量')
title('Setosa Feature 1')

figure(2)
hist(Feature1_Versicolor, 50)
xlabel('数值')
ylabel('数量')
title('Versicolor Feature 1')

figure(3)
hist(Feature1_Virginica, 50)
xlabel('数值')
ylabel('数量')
title('Virginica Feature 1')

##### 绘制第二特征 #####
Feature2_Setosa = Setosa(:, 2)
Feature2_Versicolor = Versicolor(:, 2)
Feature2_Virginica = Virginica(:, 2)

figure(4)
hist(Feature2_Setosa, 50)
xlabel('数值')
ylabel('数量')
title('Setosa Feature 2')

figure(5)
hist(Feature2_Versicolor, 50)
xlabel('数值')
ylabel('数量')
title('Versicolor Feature 2')

figure(6)
hist(Feature2_Virginica, 50)
xlabel('数值')
ylabel('数量')
title('Virginica Feature 2')

##### 绘制第三特征 #####
Feature3_Setosa = Setosa(:, 3)
Feature3_Versicolor = Versicolor(:, 3)
Feature3_Virginica = Virginica(:, 3)

figure(7)
hist(Feature2_Setosa, 50)
xlabel('数值')
ylabel('数量')
title('Setosa Feature 3')

figure(8)
hist(Feature2_Versicolor, 50)
xlabel('数值')
ylabel('数量')
title('Versicolor Feature 3')

figure(9)
hist(Feature2_Virginica, 50)
xlabel('数值')
ylabel('数量')
title('Virginica Feature 3')

##### 绘制第四特征 #####
Feature4_Setosa = Setosa(:, 4)
Feature4_Versicolor = Versicolor(:, 4)
Feature4_Virginica = Virginica(:, 4)

figure(10)
hist(Feature2_Setosa, 50)
xlabel('数值')
ylabel('数量')
title('Setosa Feature 4')

figure(11)
hist(Feature2_Versicolor, 50)
xlabel('数值')
ylabel('数量')
title('Versicolor Feature 4')

figure(12)
hist(Feature2_Virginica, 50)
xlabel('数值')
ylabel('数量')
title('Virginica Feature 4')
4.3.3 AnsTest.m

    此代码使用Octave编写,目的是将原数据放入其中,经过权值矩阵输出结果。这里的theata1和theta2根据MainSource.m的计算结果手动替换。

###################################################################################
################################## 第一部分:初始化 ##################################
###################################################################################
clear; close all; clc

DataSize = 146;
INIT_EPSILON = 1;
% 数据大小、随机初始化参数

theta1 = [  

   14.53745    1.06695    5.41210   -3.19412  -11.71895;
   -0.10527    0.55710    0.99034    1.08439    1.00968;
   -1.42449   -0.48922   -2.83584    4.16115    1.87471;
    0.82783    0.88992    0.56318    0.86599   -0.45926;
  ]
   
theta2 = [
    0.64984    4.01768    0.55466  -14.63997    2.01987;
   -7.43639   15.33393   -6.76862   14.13434   -7.85377;
    1.62750  -17.23315    0.66101    4.67994    1.76113;
  ]
% 权重矩阵的初始化

lambda = 0.1
% 学习速率

Etotal = 0;
% 总误差

DELTA1 = 0;
DELTA2 = 0;
%theta的偏导矩阵

A = zeros(DataSize, 3)

Data = load('Data-iris.txt')
% 载入原始数据
XFeature = Data(:, [1:4])
Y_ini = Data(:, 5);
Y = zeros(DataSize, 1);
for i=1:DataSize
  col = Y_ini(i, 1);
  Y(i, col) = 1;
endfor
Y
% 输出层向量化

function g = sigmoid(z)
g = zeros(size(z));
g = 1./(1+e.^(-z));
end
% Sigmoid函数

for i=1:DataSize

   
  ################################## 正向传播 ##################################
    % ---输入层 a1---
    a1 = XFeature(i, :);
    a1 = a1';               % 4x1
    a1 = [ones(1,1); a1];   % 添加偏置项后为 5x1
   
    % ---隐藏层 a2---
    z2 = theta1 * a1;
    a2 = sigmoid(z2);       % 4x1
    a2 = [ones(1,1); a2];   % 添加偏置项后为 5x1
   
    % ---输出层 a3---
    z3 = theta2 * a2;
    a3 = sigmoid(z3)        % 4x1

    A(i,1) = a3(1);
    A(i,2) = a3(2);
    A(i,3) = a3(3);
    % ANS
  endfor
 Y
 A
4.3.4 main.cpp

    此代码使用c++语言,目的是统计输出结果。

#include <iostream>
#include <vector>
#include <algorithm>

const int DataSize = 146;

//用于存放一个答案向量组
class Vec3
{
private:
    std::vector<double> nums;

    void SetZero()
    {
        this->nums[0] = 0;
        this->nums[1] = 0;
        this->nums[2] = 0;
    }
public:
    Vec3(const double x = 0, const double y = 0, const double z = 0)
    {
        nums.clear();
        nums.push_back(x);
        nums.push_back(y);
        nums.push_back(z);
        nums.shrink_to_fit();
    }

    void DataCollation()
    {
        auto it = std::max_element(nums.begin(), nums.end());
        SetZero();
        *it = 1;
    }

    bool operator==(Vec3 comp)
    {
        return nums == comp.nums ? true : false;
    }
};

//输入数据
void DataInput(std::vector<Vec3>& _ans1, std::vector<Vec3>& _ans2, std::vector<Vec3>& _ans3, std::vector<Vec3>& _Y);
//数据修正
void DataCorrection(std::vector<Vec3>& _ans1, std::vector<Vec3>& _ans2, std::vector<Vec3>& _ans3, std::vector<Vec3>& _Y);
//统计结果
void ClassificationStatistics(std::vector<Vec3>& _ans1, std::vector<Vec3>& _ans2, std::vector<Vec3>& _ans3, std::vector<Vec3>& _Y);

int main()
{
    std::vector<Vec3> ans1;
    std::vector<Vec3> ans2;
    std::vector<Vec3> ans3;
    std::vector<Vec3> Y;
    DataInput(ans1, ans2, ans3, Y);
    //数据输入

    DataCorrection(ans1, ans2, ans3, Y);
    ClassificationStatistics(ans1, ans2, ans3, Y);

    return 0;
}

void DataInput(std::vector<Vec3>& _ans1, std::vector<Vec3>& _ans2, std::vector<Vec3>& _ans3, std::vector<Vec3>& _Y)
{
    std::cout << "input ans1" << std::endl;
    for (int i(0); i < DataSize; ++i)
    {
        double x(0), y(0), z(0);
        std::cin >> x >> y >> z;
        Vec3 m(x, y, z);
        //输入中介
        _ans1.push_back(m);
    }

    std::cout << "input ans2" << std::endl;
    for (int i(0); i < DataSize; ++i)
    {
        double x(0), y(0), z(0);
        std::cin >> x >> y >> z;
        Vec3 m(x, y, z);
        //输入中介
        _ans2.push_back(m);
    }

    std::cout << "input ans3" << std::endl;
    for (int i(0); i < DataSize; ++i)
    {
        double x(0), y(0), z(0);
        std::cin >> x >> y >> z;
        Vec3 m(x, y, z);
        //输入中介
        _ans3.push_back(m);
    }

    std::cout << "input Y" << std::endl;
    for (int i(0); i < DataSize; ++i)
    {
        double x(0), y(0), z(0);
        std::cin >> x >> y >> z;
        Vec3 m(x, y, z);
        //输入中介
        _Y.push_back(m);
    }
}

void DataCorrection(std::vector<Vec3>& _ans1, std::vector<Vec3>& _ans2, std::vector<Vec3>& _ans3, std::vector<Vec3>& _Y)
{
    for (std::vector<Vec3>::iterator it = _ans1.begin(); it != _ans1.end(); ++it)
    {
        it->DataCollation();
    }

    for (std::vector<Vec3>::iterator it = _ans2.begin(); it != _ans2.end(); ++it)
    {
        it->DataCollation();
    }

    for (std::vector<Vec3>::iterator it = _ans3.begin(); it != _ans3.end(); ++it)
    {
        it->DataCollation();
    }
}

void ClassificationStatistics(std::vector<Vec3>& _ans1, std::vector<Vec3>& _ans2, std::vector<Vec3>& _ans3, std::vector<Vec3>& _Y)
{
    Vec3 setosa(1, 0, 0);
    Vec3 versicolor(0, 1, 0);
    Vec3 virginica(0, 0, 1);

    std::vector<int> ans_setosa(4);
    std::vector<int> ans_versicolor(4);
    std::vector<int> ans_virginica(4);
    //存放分类数组,前三项为ans的分类情况,最后的为Y的标准答案

    for (auto i : _ans1)
    {
        if (i == setosa)
        {
            ans_setosa[0]++;
        }
        else if (i == versicolor)
        {
            ans_versicolor[0]++;
        }
        else if (i == virginica)
        {
            ans_virginica[0]++;
        }
    }

    for (auto i : _ans2)
    {
        if (i == setosa)
        {
            ans_setosa[1]++;
        }
        else if (i == versicolor)
        {
            ans_versicolor[1]++;
        }
        else if (i == virginica)
        {
            ans_virginica[1]++;
        }
    }

    for (auto i : _ans3)
    {
        if (i == setosa)
        {
            ans_setosa[2]++;
        }
        else if (i == versicolor)
        {
            ans_versicolor[2]++;
        }
        else if (i == virginica)
        {
            ans_virginica[2]++;
        }
    }

    for (auto i : _Y)
    {
        if (i == setosa)
        {
            ans_setosa[3]++;
        }
        else if (i == versicolor)
        {
            ans_versicolor[3]++;
        }
        else if (i == virginica)
        {
            ans_virginica[3]++;
        }
    }

    std::cout << "ans_setosa" << std::endl;
    for (auto i : ans_setosa)
    {
        std::cout << i << ' ';
    }
    std::cout << std::endl;

    std::cout << "ans_versicolor" << std::endl;
    for (auto i : ans_versicolor)
    {
        std::cout << i << ' ';
    }
    std::cout << std::endl;

    std::cout << "ans_virginica" << std::endl;
    for (auto i : ans_virginica)
    {
        std::cout << i << ' ';
    }
    std::cout << std::endl;
}