利用SKL实现KNN算法
本文通过SKL
中的KNN
分类器,实现了对iris
数据的的分类预测,主要涉及的内容包含:
- KNN算法原理
- KNN算法优缺点
- K值选取
- SKL实现KNN
KNN 算法原理
找个K个和新数据最近的样本,取样本中最多的一个类别作为新数据的类别
KNN优点
- 算法简单易实现
把全部的数据当做模型本身
- 对边界不规则的数据效果好
KNN缺点
-
只适合小数据集
-
数据不平衡效果不好
-
数据必须标准化
-
不适合特征维度太多的数据
K 值选取
-
k越小越容易过拟合
-
k越小大越容易欠拟合
k值一般选择是奇数,偶数可能难以抉择
SKL实现KNN
1 | from sklearn import datasets |
1 | np.random.seed(0) # 保证每次运行的结果相同;不设置的话,默认按照时间作为参数的 |
导入数据
1 | iris = datasets.load_iris() |
1 | X = iris.data |
array([[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2]])
1 | y = iris.target |
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
分割数据
将数据分成训练集和测试集,总共是150条,分成140的训练集和10条的测试集
注意是随机分割,使用的是permutation参数
1 | # 产生该长度内的随机乱序的一维数据 |
1 | # 训练集 |
1 | X_train |
array([[5.8, 2.8, 5.1, 2.4],
[6. , 2.2, 4. , 1. ],
[5.5, 4.2, 1.4, 0.2],
[7.3, 2.9, 6.3, 1.8],
[5. , 3.4, 1.5, 0.2],
[6.3, 3.3, 6. , 2.5],
[5. , 3.5, 1.3, 0.3],
[6.7, 3.1, 4.7, 1.5],
[6.8, 2.8, 4.8, 1.4],
[6.1, 2.8, 4. , 1.3],
[6.1, 2.6, 5.6, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.5, 2.8, 4.6, 1.5],
[6.1, 2.9, 4.7, 1.4],
[4.9, 3.6, 1.4, 0.1],
[6. , 2.9, 4.5, 1.5],
[5.5, 2.6, 4.4, 1.2],
[4.8, 3. , 1.4, 0.3],
[5.4, 3.9, 1.3, 0.4],
[5.6, 2.8, 4.9, 2. ],
[5.6, 3. , 4.5, 1.5],
[4.8, 3.4, 1.9, 0.2],
[4.4, 2.9, 1.4, 0.2],
[6.2, 2.8, 4.8, 1.8],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.8, 1.9, 0.4],
[6.2, 2.9, 4.3, 1.3],
[5. , 2.3, 3.3, 1. ],
[5. , 3.4, 1.6, 0.4],
[6.4, 3.1, 5.5, 1.8],
[5.4, 3. , 4.5, 1.5],
[5.2, 3.5, 1.5, 0.2],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.2],
[5.2, 2.7, 3.9, 1.4],
[5.7, 3.8, 1.7, 0.3],
[6. , 2.7, 5.1, 1.6],
[5.9, 3. , 4.2, 1.5],
[5.8, 2.6, 4. , 1.2],
[6.8, 3. , 5.5, 2.1],
[4.7, 3.2, 1.3, 0.2],
[6.9, 3.1, 5.1, 2.3],
[5. , 3.5, 1.6, 0.6],
[5.4, 3.7, 1.5, 0.2],
[5. , 2. , 3.5, 1. ],
[6.5, 3. , 5.5, 1.8],
[6.7, 3.3, 5.7, 2.5],
[6. , 2.2, 5. , 1.5],
[6.7, 2.5, 5.8, 1.8],
[5.6, 2.5, 3.9, 1.1],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.3, 4.7, 1.6],
[5.5, 2.4, 3.8, 1.1],
[6.3, 2.7, 4.9, 1.8],
[6.3, 2.8, 5.1, 1.5],
[4.9, 2.5, 4.5, 1.7],
[6.3, 2.5, 5. , 1.9],
[7. , 3.2, 4.7, 1.4],
[6.5, 3. , 5.2, 2. ],
[6. , 3.4, 4.5, 1.6],
[4.8, 3.1, 1.6, 0.2],
[5.8, 2.7, 5.1, 1.9],
[5.6, 2.7, 4.2, 1.3],
[5.6, 2.9, 3.6, 1.3],
[5.5, 2.5, 4. , 1.3],
[6.1, 3. , 4.6, 1.4],
[7.2, 3.2, 6. , 1.8],
[5.3, 3.7, 1.5, 0.2],
[4.3, 3. , 1.1, 0.1],
[6.4, 2.7, 5.3, 1.9],
[5.7, 3. , 4.2, 1.2],
[5.4, 3.4, 1.7, 0.2],
[5.7, 4.4, 1.5, 0.4],
[6.9, 3.1, 4.9, 1.5],
[4.6, 3.1, 1.5, 0.2],
[5.9, 3. , 5.1, 1.8],
[5.1, 2.5, 3. , 1.1],
[4.6, 3.4, 1.4, 0.3],
[6.2, 2.2, 4.5, 1.5],
[7.2, 3.6, 6.1, 2.5],
[5.7, 2.9, 4.2, 1.3],
[4.8, 3. , 1.4, 0.1],
[7.1, 3. , 5.9, 2.1],
[6.9, 3.2, 5.7, 2.3],
[6.5, 3. , 5.8, 2.2],
[6.4, 2.8, 5.6, 2.1],
[5.1, 3.8, 1.6, 0.2],
[4.8, 3.4, 1.6, 0.2],
[6.5, 3.2, 5.1, 2. ],
[6.7, 3.3, 5.7, 2.1],
[4.5, 2.3, 1.3, 0.3],
[6.2, 3.4, 5.4, 2.3],
[4.9, 3. , 1.4, 0.2],
[5.7, 2.5, 5. , 2. ],
[6.9, 3.1, 5.4, 2.1],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.6, 1.4, 0.2],
[7.2, 3. , 5.8, 1.6],
[5.1, 3.5, 1.4, 0.3],
[4.4, 3. , 1.3, 0.2],
[5.4, 3.9, 1.7, 0.4],
[5.5, 2.3, 4. , 1.3],
[6.8, 3.2, 5.9, 2.3],
[7.6, 3. , 6.6, 2.1],
[5.1, 3.5, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.2],
[5.2, 3.4, 1.4, 0.2],
[5.7, 2.8, 4.5, 1.3],
[6.6, 3. , 4.4, 1.4],
[5. , 3.2, 1.2, 0.2],
[5.1, 3.3, 1.7, 0.5],
[6.4, 2.9, 4.3, 1.3],
[5.4, 3.4, 1.5, 0.4],
[7.7, 2.6, 6.9, 2.3],
[4.9, 2.4, 3.3, 1. ],
[7.9, 3.8, 6.4, 2. ],
[6.7, 3.1, 4.4, 1.4],
[5.2, 4.1, 1.5, 0.1],
[6. , 3. , 4.8, 1.8],
[5.8, 4. , 1.2, 0.2],
[7.7, 2.8, 6.7, 2. ],
[5.1, 3.8, 1.5, 0.3],
[4.7, 3.2, 1.6, 0.2],
[7.4, 2.8, 6.1, 1.9],
[5. , 3.3, 1.4, 0.2],
[6.3, 3.4, 5.6, 2.4],
[5.7, 2.8, 4.1, 1.3],
[5.8, 2.7, 3.9, 1.2],
[5.7, 2.6, 3.5, 1. ],
[6.4, 3.2, 5.3, 2.3],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 4.9, 1.5],
[6.7, 3. , 5. , 1.7],
[5. , 3. , 1.6, 0.2],
[5.5, 2.4, 3.7, 1. ],
[6.7, 3.1, 5.6, 2.4],
[5.8, 2.7, 5.1, 1.9],
[5.1, 3.4, 1.5, 0.2],
[6.6, 2.9, 4.6, 1.3]])
1 | y_train |
array([2, 1, 0, 2, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 1,
0, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 2, 2, 1, 0, 1, 1, 1, 2, 0, 2, 0,
0, 1, 2, 2, 2, 2, 1, 2, 1, 1, 2, 2, 2, 2, 1, 2, 1, 0, 2, 1, 1, 1,
1, 2, 0, 0, 2, 1, 0, 0, 1, 0, 2, 1, 0, 1, 2, 1, 0, 2, 2, 2, 2, 0,
0, 2, 2, 0, 2, 0, 2, 2, 0, 0, 2, 0, 0, 0, 1, 2, 2, 0, 0, 0, 1, 1,
0, 0, 1, 0, 2, 1, 2, 1, 0, 2, 0, 2, 0, 0, 2, 0, 2, 1, 1, 1, 2, 2,
1, 1, 0, 1, 2, 2, 0, 1])
1 | # 测试集 |
1 | X_test |
array([[5.6, 3. , 4.1, 1.3],
[5.9, 3.2, 4.8, 1.8],
[6.3, 2.3, 4.4, 1.3],
[5.5, 3.5, 1.3, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.9, 3.1, 1.5, 0.1],
[6.3, 2.9, 5.6, 1.8],
[5.8, 2.7, 4.1, 1. ],
[7.7, 3.8, 6.7, 2.2],
[4.6, 3.2, 1.4, 0.2]])
1 | y_test |
array([1, 1, 1, 0, 0, 0, 2, 1, 2, 0])
KNN训练器
1 | # 定义一个knn分类器对象 |
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=None, n_neighbors=5, p=2,
weights='uniform')
预测
1 | y_predict = knn.predict(X_test) # 传入预测的数据 |
array([1, 2, 1, 0, 0, 0, 2, 1, 2, 0])
1 | # 计算各个测试样本预测的概率值 |
array([[0. , 1. , 0. ],
[0. , 0.4, 0.6],
[0. , 1. , 0. ],
[1. , 0. , 0. ],
[1. , 0. , 0. ],
[1. , 0. , 0. ],
[0. , 0. , 1. ],
[0. , 1. , 0. ],
[0. , 0. , 1. ],
[1. , 0. , 0. ]])
1 | # 计算和最后一个测试样本距离最近的5个点,返回的是这些样本的序号组成的数组 |
(array([[0.14142136, 0.14142136, 0.2236068 , 0.2236068 , 0.2236068 ]]),
array([[ 75, 41, 96, 78, 123]]))
1 | # 调用对象的打分方法,计算出准确率 |
0.9
输出测试结果
对比knn的预测值和实际的值,准确率是90%
1 | # knn的预测值 |
array([1, 2, 1, 0, 0, 0, 2, 1, 2, 0])
1 | # 原始值 |
array([1, 1, 1, 0, 0, 0, 2, 1, 2, 0])