KNN 近邻算法

算法描述

http://img.janhen.com/202104160847096mLEpK.jpg

KNN 算法描述

KNN(k-NearestNeighbor)又被称为最近邻算法。

思路是:若一个样本在特征空间中的 k 个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。

KNN 算法是机器学习中最简单的方法之一。KNN 是一种分类算法,KNN 没有显式的学习过程,即没有训练阶段,待收到新样本后直接进行处理。

距离计算

计算待测案例与训练样本之间的距离,常用的距离有欧式距离、曼哈顿距离、余弦距离等。

在欧几里得空间中,点x =(x1,…,xn)和 y =(y1,…,yn)之间的欧氏距离为

http://img.janhen.com/20210416090249gfm1go.png

欧几里得距离

算法实现流程:

  1. 读文件中的测试数据、训练数据集,形成数据集 X,Y
  2. 求数据集 Y 中的每个点到数据集 X 中每个点的位置,得到数据集 D
  3. 找到数据集 D 中最小的 K 个点
  4. 求 K 个点的分布情况
  5. 返回前 K 个点中出现频率最高的类别作为测试数据的预测分类

鸢尾花数据集

数据集内包含 3 类共 150 条 记录,每类各 50 个数据,

记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,可以通过这 4个 特征预测鸢尾花卉属于哪一品种(iris-setosa, iris-versicolour, iris-virginica)。

原始的数据集:Iris.csv

1
2
3
4
5
6
7
8
9
10
11
Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species
1,5.1,3.5,1.4,0.2,Iris-setosa
2,4.9,3.0,1.4,0.2,Iris-setosa
3,4.7,3.2,1.3,0.2,Iris-setosa
..
52,6.4,3.2,4.5,1.5,Iris-versicolor
53,6.9,3.1,4.9,1.5,Iris-versicolor
..
102,5.8,2.7,5.1,1.9,Iris-virginica
103,7.1,3.0,5.9,2.1,Iris-virginica
..

未知的数据集:unknown_iris.csv

1
2
3
8888,5.7,4.4,1.5,0.4,Iris-setosa22
7777,5.5,2.4,4.0,1.4,Iris-versicolor22
6666,6.8,3.2,5.1,2.3,Iris-virginica22

算法实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object KNNDriver {
private val K = 15

def main(args: Array[String]): Unit = {
val conf: SparkConf = new SparkConf().setAppName("knn-test").setMaster("local[4]")
val sc = new SparkContext(conf)
sc.setLogLevel("WARN")

// 读取实际的数据 id, (x1,y1,z1,e1)
val realRdd: RDD[(String, Array[Double])] = sc.textFile("data/Iris.csv").map(line => {
val fields: Array[String] = line.split(",")
if (fields(0).equals("Id")) {
("unknown", Array(-1.0))
} else {
(fields.last, fields.init.tail.map(_.toDouble))
}
})
val realBc: Broadcast[Array[(String, Array[Double])]] = sc.broadcast(realRdd.collect)

// 读取待验证的数据
val toValidRdd: RDD[(String, Array[Double])] = sc.textFile("data/unknown_iris.csv").map(line => {
val fields: Array[String] = line.split(",")
if (fields(0).equals("Id")) {
("unknown", Array(-1.0))
} else {
(fields.last, fields.init.tail.map(_.toDouble))
}
})

val varData: Array[(String, Array[Double])] = toValidRdd.collect()
varData.foreach(elem => {
val res: Array[(Double, String)] = realBc.value.map(point => (distance(point._2, elem._2), point._1))
val kNeastNeighbor: Array[(Double, String)] = res.sortBy(_._1).take(K)
val labels: Array[String] = kNeastNeighbor.map(_._2)
print(s"TestData: ${elem._2.toBuffer}, NearestNeighbor: ")
labels.groupBy(x => x).mapValues(_.length).foreach(print)
println()
})
sc.stop()
}

// 多个点之间的欧式距离
def distance(x: Array[Double], y: Array[Double]): Double = {
math.sqrt(x.zip(y).map(z => math.pow(z._1 - z._2, 2)).sum)
}
}

算法输出:

1
2
3
TestData: ArrayBuffer(5.7, 4.4, 1.5, 0.4), NearestNeighbor: (Iris-setosa,15)
TestData: ArrayBuffer(5.5, 2.4, 4.0, 1.4), NearestNeighbor: (Iris-versicolor,15)
TestData: ArrayBuffer(6.8, 3.2, 5.1, 2.3), NearestNeighbor: (Iris-virginica,14)(Iris-versicolor,1)

优缺点

优点:

训练时间复杂度低,为O(n);

简单,易于理解;

可用于非线性分类;

缺点:

使用懒散学习方法,基本上不学习,导致预测时速度比起逻辑回归之类的算法慢;

KNN 模型可解释性不强。

分类的时候,未考虑权重等因素,仅根据投票数量来决定分类结果。

Ref

k-nearest neighbors algorithm - Wikipedia

欧几里得距离 - 维基百科,自由的百科全书