实现思路:
通过上网查资料,KNN算法就是给定一个点确定他所属的类型,其中K是指他周围作为参考点的个数,当周围K个参考点中某一类的个数占多数,那么测试点就可以认为是属于那一个多数类点的类型。所以主要就是计算测试点周围各个点的距离,然后对各个点距离进行排序,再根据K值确定周围点的个数,通过比对多数点的类型来确定测试点的类型
实验代码:
本次试验使用的语言是HTML/JavaScript,因为可以直观的把所有点以图像的形式显示出来,这样很比较有说服力下面是代码,重点部分以加粗的形式显示出来了:
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title></title>
<style>
p {float: left; margin: 20px;}
canvas, div{float: left;}
.content {
margin-top: 30px;
margin-left: 10px;
font-family: "微软雅黑";
}
.head {
clear: right;
}
</style>
</head>
<body>
<!--
作者:1094253525@qq.com
时间:2017-05-11
描述:KNN算法
-->
<div c
<div class="head">
<h1 style="text-indent: 100px;">KNN算法及其图像显示</h1>
请输入训练元组个数:<input type="text" id="practice" value="100"/><br>
请输入测试元组个数:<input type="text" id="test" value="20"/><br />
<button type="button" id="sure">确定</button>
</div>
<canvas id="canvas" width="400" height="400"></canvas>
<div class="content">
<p>训练元组:</p>
<div id="practiceArr" style="height: 450px; width: 200px;overflow: auto;"></div>
<p>测试元组:</p>
<div id="testArr" style="height: 450px; width: 150px;overflow: auto;"></div>
<p>KNN得出的类型:</p>
<div id="type" style="height: 450px; width: 40px;overflow: auto;"></div>
</div>
<div style="clear: both; "></div>
<div style="width: 1000px; margin-left: 500px;">
<p style="text-indent: 2em; font-size: 25px;">
训练元组和测试元组都是使用Math.random()函数随机生成的,类型也是事先固定的,<br>
默认类型为“1”和“-1”,取值范围是[1-4),为了便于在图像上显示我把所有数值都乘上<br>
了100,K值为3。这里共有<span id="pNum">0</span>个训练元组,<span id="tNum">0</span>个测试元组。
</p>
</div>
<script type="text/javascript">
var canvas = document.getElementById("canvas");
var ctx = canvas.getContext("2d");
ctx.font = "bolder 14px Arial";//文字字体,大小
ctx.textAlign = "left"; //文本对齐方式
ctx.textBaseline = "middle"; //文本的基线
ctx.fillText("红色圆点表示训练元组中类型为:1的数据", 100, 30);
ctx.fillText("蓝色圆点表示训练元组中类型为:-1的数据", 100, 50);
ctx.fillText("黄色圆点表示测试元组", 100, 70);
//用来获得一组随机数据组成的训练元组,
//length:获得训练元组数据的个数
function getArr(length) {
var arr = new Array();
for(var i = 0; i < length/2; i++) {
var x = (Math.random() * 3 + 1).toFixed(4);
var y = (Math.random() * 3 + 1).toFixed(4);;
var t = {X: x, Y: y, type: "-1"};
arr.push(t);
}
for(var j = length/2; j < length; j++) {
var x = (Math.random() * 3 + 1).toFixed(4);;
var y = (Math.random() * 3 + 1).toFixed(4);;
var t = {X: x, Y: y, type: "1"};
arr.push(t);
}
return arr;
}
//计算两个点之间的距离,a训练元组,b测试元组
function getDistance(a, b) {
var distance = 0;
var t = (a.X - b.X) * (a.Y - b.Y);
distance = Math.abs(t);
return distance;
}
//比较函数,比较两个数的大小
function compare(x1, x2) {
if(x1.D < x2.D) {
return -1;
} else if(x1.D > x2.D) {
return 1;
} else {
return 0;
}
}
//执行KNN算法,根据设定的K值返回测试元组的类别
//a:训练元组, b:测试元组, k:设定的K值
function knn(a, b, k) {
var temp = new Array();
var type_b = new Array();
var practice = a;
for(var i = 0; i < b.length; i++) {
var tb = b[i];
for(var j = 0; j < practice.length; j++) {
var dis = getDistance(a[j], tb);
var _temp = {B_X: tb.X, B_Y: tb.Y,
A_X: practice[j].X, A_Y: practice[j].Y, A_style: practice[j].type,
D: dis};
temp.push(_temp);
}/*end of j, the inner circle*/
temp.sort(compare);
var type1 = 0 , type2 = 0;
for(var m = 0; m < k; m++) {
if(temp[m].A_style == "1") {
type1 ++;
} else {
type2 ++;
}
}/*end of m, the inner circle*/
if(type1 > type2) {
var testResult = {X: tb.X, Y: tb.Y, type: "1"};
type_b.push(testResult);
} else {
var testResult = {X: tb.X, Y: tb.Y, type: "-1"};
type_b.push(testResult);
}
temp.length = 0; //清空数组内容
}/*end of i, the outter circle*/
return type_b;
}/*end of function KNN*/
/*获得测试元组b的数据,len为获得数据的长度*/
function getB(len) {
var bArry = new Array();
for(var i = 0; i < len; i++) {
var x = (Math.random() * 3 + 1).toFixed(4);
var y = (Math.random() * 3 + 1).toFixed(4);;
var t = {X: x, Y: y};
bArry.push(t);
}
return bArry;
};
document.getElementById("sure").onclick = function(){
ctx.clearRect(100, 100, 300, 300); //每点击一次就清空一次内容
ctx.fillStyle = "black";
ctx.fillRect(100, 100, 300, 300); //重绘背景
var l1 = 0, l2 = 0;
l1 = document.getElementById("practice").value;
l2 = document.getElementById("test").value;
/*获得训练元组arr与测试元组b的数据*/
var arr = getArr(l1);
var b = getB(l2);
/*打印信息*/
var textPra = "",
textTes = "",
textTyp = "";
/*将数据以圆点的形式显示在页面上,同时将数据显示到页面上*/
for(var i = 0; i < arr.length; i++) {
if(arr[i].type == "1") {
ctx.fillStyle = "red";
ctx.fillText("·", arr[i].X * 100, arr[i].Y * 100);
} else {
ctx.fillStyle = "blue";
ctx.fillText("·", arr[i].X * 100, arr[i].Y * 100);
}
textPra += arr[i].X + " " + arr[i].Y + " " + arr[i].type + "<br>";
}
for(var i = 0; i < b.length; i++) {
ctx.fillStyle = "yellow";
ctx.fillText("·", b[i].X * 100, b[i].Y * 100);
textTes += b[i].X + " " + b[i].Y + "<br>";
}
//执行KNN算法并获得对应的类型type,这里把K值设定为3
var test = knn(arr, b, 3);
for(var i = 0; i < test.length; i++) {
textTyp += test[i].type + "<br>";
}
document.getElementById("practiceArr").innerHTML = textPra;
document.getElementById("testArr").innerHTML = textTes;
document.getElementById("type").innerHTML = textTyp;
document.getElementById("pNum").innerHTML = arr.length;
document.getElementById("tNum").innerHTML = b.length;
}
</script>
</body>
</html>
实验结果与数据处理
实验结果都是一次性的,并没有把每一次训练的结果加入到训练元组中,而是每一次都重新生成新的训练元组和测试元组,主要是考虑到数组太多了,如果事先定义的话会很麻烦,为了减少麻烦所以测试元组和训练元组都是使用随机函数生成的。
实验中的K值我自己设定为3,因为考虑到图像并没有那么密集,太大了可能有失精确性,所以就暂确定为3,但是K的值可以自行调整。
总结
KNN算法理解完毕之后算是一个比较简单的算法了,算法的重点就是在计算点与点之间的距离,这里我用的是两点之间的距离公式来求解,因为元组中的各个点保存了它所在的横坐标和纵坐标,方便计算。
关于各个点的储存问题,JS中的数组可以很好地解决这个问题,因为数组中可以存储一整个对象,所以把各个点的类型和坐标保存下来是一件非常容易的事情,而不像C/C++/Java那样需要使用多种数据类型,所以使用JS是一个比较偷懒的方式。
刚开始的时候KNN算法特别难以理解,查找资料后也只是一知半解,但是在看过很多别人实现的代码后对问题有了自己的认识。在实现这个问题时想着如果能把所有点都画出来可能会更直观一点,考虑到最近学习的前端知识所以就想到了用JS来实现,用HTML的canvas来绘制,所以便有了以上的结果。
但是问题还是很多的瑕疵,个人理解中每一次训练完毕后的测试元组应该要变成下一次的训练元组的,所以需要把他们加入到训练元组中,但是自己没有实现,而是每一次都重新训练一遍,可能不是太好。