機器學習--knn手寫數字識別系統
阿新 • 發佈:2019-01-03
0.k近鄰演算法
剛接觸Java,並且正在學習機器學習的相關演算法,而KNN又非常易於實現,於是就有了這個小系統。
1.knn演算法簡介:
存在一個樣本資料集合,也稱為訓練樣本集,並且樣本集中的每一個資料都有標籤,即我們知道樣本集中每一個資料的特徵和對應的類別。當輸入沒有標籤的新資料時,將新資料的每一個特徵和樣本集中每一個資料的對應特徵進行比較(計算兩個樣本的特徵之間的距離),然後提取樣本集中與新資料特徵最相似(距離最近)的資料的類別標籤。通常我們只關心前k個最相似的資料,這就是k近鄰演算法中k的出處。最後,選擇這k個最相似的資料中出現次數最多的類別,作為新資料的分類。
2.該程式主要有以下幾個功能:
功能1:可以在面板上手寫輸入數字
功能2:可以對特定的區域進行截圖:獲取使用者手寫的數字並儲存為影像,以便之後使用演算法進行分析。
功能3:可以對圖片進行縮放,要保證圖片的大小(維度)和資料集中的大小一樣。
功能4:可以將彩色圖片轉化為二值圖片。
功能5:對圖片中的手寫數字使用KNN演算法進行識別,也可以在測試集上計算演算法的準確性。
(演示)
功能1的實現程式碼:手寫板
建立一個JPanel類的子類,通過監聽mouseDragged事件,呼叫Graphics在面板上繪圖來實現手寫板的功能。
/**
 * Drawing surface for handwritten digits.
 *
 * <p>The writing area is a white {@code boardWidth} x {@code boardHeight} rectangle whose top-left
 * corner is at ({@code boardX}, {@code boardY}) inside the panel. Dragging the mouse over it stamps
 * filled black circles of diameter {@code pencilWidth} centred on the pointer. Strokes are painted
 * directly through {@link #getGraphics()} and are not remembered, so any repaint of the panel
 * clears the drawing.
 */
class Board extends JPanel implements MouseMotionListener {
    final private int boardWidth = 320;
    final private int boardHeight = 320;
    final private int boardX = 1;
    final private int boardY = 1;
    private int pencilWidth = 40;

    /** Draws the black frame and fills the writing area with white. */
    @Override
    public void paint(Graphics graphics) {
        super.paint(graphics);
        graphics.setColor(Color.BLACK);
        graphics.draw3DRect(this.boardX - 1, this.boardY - 1, this.boardWidth + 1, this.boardHeight + 1, true);
        graphics.setColor(Color.WHITE);
        graphics.fill3DRect(this.boardX, this.boardY, this.boardWidth, this.boardHeight, true);
    }

    /**
     * Stamps one pen dot while the pointer is inside the writing area.
     *
     * <p>{@code fillOval} takes the top-left corner of the dot's bounding box, so the dot is shifted
     * by half the pen width to centre it on the pointer, then clamped so it never spills over the
     * frame.
     */
    @Override
    public void mouseDragged(MouseEvent e) {
        int x = e.getX();
        int y = e.getY();
        boolean insideBoard = x >= boardX && x < boardX + boardWidth
                && y >= boardY && y < boardY + boardHeight;
        if (!insideBoard) {
            return;
        }
        Graphics graphics = this.getGraphics();
        if (graphics == null) {
            // The panel is not displayable yet, so there is nothing to draw on.
            return;
        }
        try {
            int radius = pencilWidth / 2;
            int left = clamp(x - radius, boardX, boardX + boardWidth - pencilWidth);
            int top = clamp(y - radius, boardY, boardY + boardHeight - pencilWidth);
            graphics.setColor(Color.BLACK);
            graphics.fillOval(left, top, pencilWidth, pencilWidth);
        } finally {
            // A Graphics obtained from getGraphics() must be released by the caller.
            graphics.dispose();
        }
    }

    /** Moving the mouse without a pressed button does not draw anything. */
    @Override
    public void mouseMoved(MouseEvent e) {
    }

    /** Limits {@code value} to the closed range [{@code min}, {@code max}]. */
    private static int clamp(int value, int min, int max) {
        return Math.max(min, Math.min(max, value));
    }
}
功能2的實現:可以對特定的區域進行截圖
/**
 * Captures a rectangular region of the screen and saves it as a PNG file in the current working
 * directory.
 */
class ScreenShot {
    /** Left edge of the capture rectangle, in screen coordinates. */
    private int startX;
    /** Top edge of the capture rectangle, in screen coordinates. */
    private int startY;
    /** Capture width in pixels. */
    private int width;
    /** Capture height in pixels. */
    private int height;
    /** Path the PNG is written to. */
    private String saveTo;

    /**
     * @param startX   left edge of the capture rectangle (screen coordinates)
     * @param startY   top edge of the capture rectangle (screen coordinates)
     * @param width    capture width in pixels
     * @param height   capture height in pixels
     * @param filename base name of the output file; ".png" is appended and the file is placed in
     *                 the current working directory
     */
    public ScreenShot(int startX, int startY, int width, int height, String filename) {
        this.startX = startX;
        this.startY = startY;
        this.width = width;
        this.height = height;
        // File.separator instead of a hard-coded "\\": on Linux/macOS a backslash is an ordinary
        // file-name character, so ".\\maggie.png" would not be the "./maggie.png" read back later.
        this.saveTo = "." + File.separator + filename + ".png";
    }

    /**
     * Grabs the configured screen rectangle with {@link Robot} and writes it to {@code saveTo}.
     * Failures (e.g. no screen access, I/O errors) are reported with a stack trace on stderr.
     */
    public void capture() {
        File file = new File(saveTo);
        try {
            Rectangle area = new Rectangle(startX, startY, width, height);
            BufferedImage bufferedImage = new Robot().createScreenCapture(area);
            ImageIO.write(bufferedImage, "png", file);
            System.out.println("capture image has finish...");
        } catch (AWTException | IOException e) {
            e.printStackTrace();
        }
    }
}
功能3:可以對圖片進行縮放,要保證圖片的大小(維度)和資料集中的大小一樣。
/**
 * Rescales an image by a constant factor and writes the result to {@code zoominMaggie.png} in the
 * current working directory.
 */
class ZoomImage {
    /** Path of the source image. */
    private String filename;
    /** Factor applied to both width and height; e.g. 0.1f turns a 320x320 image into 32x32. */
    private float scaling;

    public ZoomImage(String filename, float scaling) {
        this.filename = filename;
        this.scaling = scaling;
    }

    /**
     * Reads the source image, draws it scaled into a new image and saves that as PNG.
     * Problems (unreadable input, a scale that leaves no pixels, I/O errors) are reported on stderr.
     */
    public void zoom() {
        try {
            BufferedImage source = ImageIO.read(new File(this.filename));
            if (source == null) {
                // ImageIO.read returns null (instead of throwing) when no reader understands the file.
                System.err.println("not a readable image: " + this.filename);
                return;
            }
            int scaledWidth = (int) (this.scaling * source.getWidth());
            int scaledHeight = (int) (this.scaling * source.getHeight());
            if (scaledWidth <= 0 || scaledHeight <= 0) {
                System.err.println("scaling " + this.scaling + " leaves no pixels for " + this.filename);
                return;
            }
            BufferedImage scaled = new BufferedImage(scaledWidth, scaledHeight, BufferedImage.TYPE_INT_BGR);
            Graphics graphics = scaled.createGraphics();
            try {
                graphics.drawImage(source, 0, 0, scaledWidth, scaledHeight, null);
            } finally {
                graphics.dispose();
            }
            // File.separator keeps the path valid on every OS; "\\" is only a separator on Windows.
            ImageIO.write(scaled, "png", new File("." + File.separator + "zoominMaggie.png"));
            System.out.println("image has been zoomed...");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
功能4:可以將彩色圖片轉化為二值圖片
/**
 * Converts an image file into a binary (0/1) pixel vector.
 *
 * <p>Each pixel's grey level is compared with a fixed threshold: light pixels (grey > 128) become
 * 0 and dark pixels become 1. The result is stored row by row (index = row * width + column), and
 * the matrix is also echoed to standard output.
 */
class RGB2binary {
    /** Path of the image to convert. */
    private String filename;
    /** Binarized pixels; 32 x 32 zeros until {@link #rgb2binary()} has run. */
    private short[] userInputDigit = new short[32 * 32];

    /**
     * Returns the binarized pixels produced by the last {@link #rgb2binary()} call, one entry per
     * pixel of that image.
     */
    public short[] getUserInputDigit() {
        return this.userInputDigit;
    }

    public RGB2binary(String filename) {
        this.filename = filename;
    }

    /**
     * Reads the image and fills {@link #getUserInputDigit()}. If the image has a different pixel
     * count than the current array, a new array of matching size is allocated. Read failures are
     * reported on stderr and leave the previous contents untouched.
     */
    public void rgb2binary() {
        System.out.println(this.filename);
        File file = new File(this.filename);
        try {
            BufferedImage bufferedImage = ImageIO.read(file);
            if (bufferedImage == null) {
                // ImageIO.read returns null (instead of throwing) when no reader understands the file.
                System.err.println("not a readable image: " + this.filename);
                return;
            }
            int startX = bufferedImage.getMinX();
            int startY = bufferedImage.getMinY();
            int width = bufferedImage.getWidth();
            int height = bufferedImage.getHeight();
            System.out.println("x = " + startX + " y = " + startY + " width = " + width + " height = " + height);
            if (userInputDigit.length != width * height) {
                userInputDigit = new short[width * height];
            }
            // Rows on the outside, columns inside: getRGB takes (x, y) = (column, row).
            for (int row = 0; row < height; row++) {
                for (int col = 0; col < width; col++) {
                    int pixel = bufferedImage.getRGB(startX + col, startY + row);
                    int r = (pixel >> 16) & 0xff;
                    int g = (pixel >> 8) & 0xff;
                    int b = pixel & 0xff;
                    // Weighted grey level of the pixel; thresholding it at 128 yields the binary value.
                    float gray = r * 0.3f + g * 0.59f + b * 0.11f;
                    short bit = gray > 128 ? (short) 0 : (short) 1;
                    System.out.print(bit + "");
                    userInputDigit[row * width + col] = bit;
                }
                System.out.println();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
功能5:對圖片中的手寫數字使用KNN演算法進行識別,也可以在測試集上計算演算法的準確性。
/**
 * k-nearest-neighbour classifier for 32 x 32 binary digit bitmaps.
 *
 * <p>Every sample is a text file with 32 lines of 32 {@code '0'}/{@code '1'} characters. The part
 * of the file name before the first {@code '_'} is the digit label (e.g. {@code 7_12.txt}). Pixels
 * are stored row-major as a {@code 32 * 32} feature vector. The training and test directories are
 * the package-private fields {@link #trainingSetDir} and {@link #testSetDir}.
 */
class Knn {
    /** Side length of the square digit bitmaps. */
    private static final int DIGIT_SIDE = 32;
    /** Number of features (pixels) per sample. */
    private final int featureSize = DIGIT_SIDE * DIGIT_SIDE;
    String trainingSetDir = "./trainingDigits";
    String testSetDir = "./testDigits1";
    /** Number of training samples actually loaded by {@link #readTrainingSet()}. */
    private int trainingSetSize;
    /** Number of test samples actually loaded by {@link #readTestSet()}. */
    private int testSetSize;
    private short[][] trainingData = null;
    private short[] trainingSetLabel = null;
    private short[][] testData = null;
    private short[] testSetLabel = null;

    public Knn() {
    }

    /**
     * Loads all sample files from {@link #trainingSetDir}.
     *
     * <p>Files that cannot be read, or whose name does not start with a numeric label, are reported
     * on stderr and skipped; {@link #getTrainingSetSize()} counts only the loaded samples.
     *
     * @throws IllegalStateException if the directory does not exist or cannot be listed
     */
    public void readTrainingSet() {
        File[] files = listSampleFiles(trainingSetDir);
        System.out.println("total file number: " + files.length);
        trainingData = new short[files.length][];
        trainingSetLabel = new short[files.length];
        trainingSetSize = loadSamples(files, trainingData, trainingSetLabel);
    }

    /**
     * Loads all sample files from {@link #testSetDir}; same rules as {@link #readTrainingSet()}.
     *
     * @throws IllegalStateException if the directory does not exist or cannot be listed
     */
    public void readTestSet() {
        File[] files = listSampleFiles(testSetDir);
        System.out.println("total number of test file" + files.length);
        testData = new short[files.length][];
        testSetLabel = new short[files.length];
        testSetSize = loadSamples(files, testData, testSetLabel);
    }

    /**
     * Classifies a feature vector by majority vote among its k nearest training samples
     * (Euclidean distance).
     *
     * <p>When several labels receive the same number of votes, the label of the nearest neighbour
     * among them wins.
     *
     * @param feature feature vector to classify; must have the same length as the training samples
     * @param k       number of voting neighbours; values above the training-set size use all samples
     * @return the predicted label
     * @throws IllegalArgumentException if {@code k < 1} or the feature length does not match
     * @throws IllegalStateException    if no training samples are loaded
     */
    public int knn(short[] feature, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be positive: " + k);
        }
        if (trainingSetSize == 0) {
            throw new IllegalStateException("training set is empty; call readTrainingSet() first");
        }
        double[] distances = new double[trainingSetSize];
        for (int i = 0; i < trainingSetSize; i++) {
            distances[i] = calculateDistance(feature, trainingData[i]);
        }
        int[] nearest = arg_sort(distances);
        int neighbours = Math.min(k, trainingSetSize);

        HashMap<Short, Integer> vote = new HashMap<>();
        for (int i = 0; i < neighbours; i++) {
            short label = trainingSetLabel[nearest[i]];
            Integer count = vote.get(label);
            vote.put(label, count == null ? 1 : count + 1);
        }

        // Visit the neighbours nearest-first and only replace on a strictly higher count, so ties
        // go to the closest neighbour instead of depending on HashMap iteration order.
        int result = trainingSetLabel[nearest[0]];
        int maxVote = 0;
        for (int i = 0; i < neighbours; i++) {
            short label = trainingSetLabel[nearest[i]];
            int votes = vote.get(label);
            if (votes > maxVote) {
                maxVote = votes;
                result = label;
            }
        }
        return result;
    }

    /**
     * Reloads both sets, classifies every test sample with k = 3 and returns the fraction that
     * matches its label (NaN when no test samples were loaded).
     */
    public double knnPrecise() {
        System.out.println("reading trainingSet...");
        this.readTrainingSet();
        System.out.println("reading trainingSet over");
        System.out.println("reading testSet...");
        this.readTestSet();
        System.out.println("reading testSet end");
        int success = 0;
        for (int i = 0; i < testSetSize; i++) {
            if (testSetLabel[i] == knn(testData[i], 3)) {
                success++;
            }
        }
        return (double) success / testSetSize;
    }

    /**
     * Euclidean distance between two feature vectors.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    public double calculateDistance(short[] sequence1, short[] sequence2) {
        if (sequence1.length != sequence2.length) {
            throw new IllegalArgumentException(
                    "feature length mismatch: " + sequence1.length + " vs " + sequence2.length);
        }
        // long: the square of a difference of two shorts can exceed Integer.MAX_VALUE.
        long distance = 0;
        for (int i = 0; i < sequence1.length; i++) {
            long diff = sequence1[i] - sequence2[i];
            distance += diff * diff;
        }
        return Math.sqrt(distance);
    }

    /**
     * Returns the indices of {@code sequence} ordered by ascending value; equal values keep their
     * original relative order. The input array is not modified.
     */
    public int[] arg_sort(final double[] sequence) {
        Integer[] order = new Integer[sequence.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Arrays.sort on an object array is guaranteed stable, and O(n log n) instead of the
        // previous O(n^2) selection sort.
        Arrays.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer left, Integer right) {
                return Double.compare(sequence[left], sequence[right]);
            }
        });
        int[] indices = new int[order.length];
        for (int i = 0; i < order.length; i++) {
            indices[i] = order[i];
        }
        return indices;
    }

    public int getTrainingSetSize() {
        return trainingSetSize;
    }

    public int getTestSetSize() {
        return testSetSize;
    }

    /** Lists the entries of {@code dir}, failing clearly instead of returning null. */
    private static File[] listSampleFiles(String dir) {
        File[] files = new File(dir).listFiles();
        if (files == null) {
            throw new IllegalStateException("cannot list sample directory " + new File(dir).getAbsolutePath());
        }
        return files;
    }

    /**
     * Parses each regular file into {@code data}/{@code labels}, packing successful samples at the
     * front, and returns how many were loaded.
     */
    private int loadSamples(File[] files, short[][] data, short[] labels) {
        int loaded = 0;
        for (File file : files) {
            if (!file.isFile()) {
                continue;
            }
            try {
                short label = Short.parseShort(file.getName().split("_")[0]);
                data[loaded] = readSample(file);
                labels[loaded] = label;
                loaded++;
            } catch (IOException | NumberFormatException e) {
                System.err.println("skipping " + file + ": " + e);
            }
        }
        return loaded;
    }

    /**
     * Reads up to 32 non-blank lines of up to 32 digit characters each. Line-based reading accepts
     * both CRLF and LF files, unlike a fixed 34-character buffer.
     */
    private short[] readSample(File file) throws IOException {
        short[] pixels = new short[featureSize];
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            int row = 0;
            String line;
            while (row < DIGIT_SIDE && (line = reader.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) {
                    continue;
                }
                int columns = Math.min(DIGIT_SIDE, line.length());
                for (int col = 0; col < columns; col++) {
                    pixels[row * DIGIT_SIDE + col] = Short.parseShort(String.valueOf(line.charAt(col)));
                }
                row++;
            }
        }
        return pixels;
    }
}
3.結果:
在測試集上的準確性很高,但是實際應用中卻遠沒有那麼高。
完整程式碼
import java.awt.AWTException;
import java.awt.Color;
import java.awt.Font;
import java.awt.Graphics;
import java.awt.Rectangle;
import java.awt.Robot;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.MouseEvent;
import java.awt.event.MouseMotionListener;
import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import javax.imageio.ImageIO;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JTextArea;
import javax.swing.SwingUtilities;
/**
 * Main window of the handwriting demo.
 *
 * <p>Left: a {@link Board} to draw a digit on. Right: "Rewrite" clears the board, "Recognition"
 * screenshots the board, shrinks it by 10x, binarizes it and classifies it with k-NN (k = 3), and
 * "testPrecise" reports the classifier's accuracy on the test directory.
 */
public class Recognition extends JFrame implements ActionListener {
    final private int windowWidth = 493;
    final private int windowHeight = 380;
    final private int windowX = 100;
    final private int windowY = 100;
    Board board = null;
    JButton reWriteButton = null;
    JButton recognitionButton = null;
    JButton testButton = null;
    JTextArea showResult = null;
    private int contentPaneX;
    private int contentPaneY;
    /** Classifier with the training set loaded; created on the first recognition, then reused. */
    private Knn recognizer = null;

    /** Builds and shows the window. Call on the Swing event dispatch thread. */
    public Recognition() {
        board = new Board();
        this.setLayout(null);
        this.add(board);
        board.setBounds(8, 8, 332, 332);
        board.addMouseMotionListener(board);

        reWriteButton = new JButton("Rewrite");
        this.add(reWriteButton);
        reWriteButton.setBounds(340, 10, 130, 30);
        reWriteButton.addActionListener(this);

        recognitionButton = new JButton("Recognition");
        this.add(recognitionButton);
        recognitionButton.setBounds(340, 40, 130, 30);
        recognitionButton.addActionListener(this);

        testButton = new JButton("testPrecise");
        this.add(testButton);
        testButton.setBounds(340, 80, 130, 30);
        testButton.addActionListener(this);

        showResult = new JTextArea();
        showResult.setOpaque(true);
        showResult.setBackground(Color.CYAN);
        showResult.setForeground(Color.BLACK);
        // English family name of Microsoft YaHei, which Java resolves under any locale; the
        // traditional-character spelling used before likely matched no installed font and
        // silently fell back to the default logical font.
        showResult.setFont(new Font("Microsoft YaHei", Font.BOLD, 12));
        showResult.setLineWrap(true);
        this.add(showResult);
        showResult.setBounds(340, 180, 130, 150);
        showResult.setVisible(false);

        this.setTitle("HandWriting Recognition");
        this.setSize(windowWidth, windowHeight);
        this.setLocation(windowX, windowY);
        this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        this.setVisible(true);
    }

    public static void main(String[] args) {
        // Swing components must be created and used on the event dispatch thread.
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                new Recognition();
            }
        });
    }

    @Override
    public void actionPerformed(ActionEvent e) {
        Object source = e.getSource();
        if (source == reWriteButton) {
            // Strokes are drawn straight onto the screen, so a repaint wipes the board.
            repaint();
        } else if (source == recognitionButton) {
            recognizeDrawing();
        } else if (source == testButton) {
            showTestAccuracy();
        }
    }

    /** Captures the writing area, converts it to a 0/1 vector and shows the k-NN prediction. */
    private void recognizeDrawing() {
        this.contentPaneX = (int) this.getContentPane().getLocationOnScreen().getX();
        this.contentPaneY = (int) this.getContentPane().getLocationOnScreen().getY();
        // The board sits at (8, 8) in the content pane and its white area starts 1 px further in.
        ScreenShot screenShot = new ScreenShot(contentPaneX + 9, contentPaneY + 9, 320, 320, "maggie");
        screenShot.capture();
        ZoomImage zoomImage = new ZoomImage("./maggie.png", 0.1f);
        zoomImage.zoom();
        RGB2binary rgb2binary = new RGB2binary("./zoominMaggie.png");
        rgb2binary.rgb2binary();
        short[] userInput = rgb2binary.getUserInputDigit();
        int recognitionResult = trainedRecognizer().knn(userInput, 3);
        System.out.println("recognitionResult:" + recognitionResult);
        showResult.setText("Your input is \r\n" + String.valueOf(recognitionResult));
        showResult.setVisible(true);
    }

    /**
     * Returns the shared classifier, reading the training set only on first use instead of on
     * every click. Changes to the training directory after that are not picked up.
     */
    private Knn trainedRecognizer() {
        if (recognizer == null) {
            Knn knn = new Knn();
            System.out.println("reading trainingSet...");
            knn.readTrainingSet();
            System.out.println("reading trainingSet over");
            recognizer = knn;
        }
        return recognizer;
    }

    /** Evaluates the classifier on the test directory and shows set sizes and accuracy. */
    private void showTestAccuracy() {
        Knn knn = new Knn();
        double precise = knn.knnPrecise();
        String string = "Training Set Size is :\r\n" + knn.getTrainingSetSize()
                + "\r\nTest Set Size is :\r\n" + knn.getTestSetSize()
                + "\r\nAccuracy is \r\n" + String.valueOf(precise);
        showResult.setText(string);
        showResult.setVisible(true);
    }
}
/**
 * Drawing surface for handwritten digits.
 *
 * <p>The writing area is a white {@code boardWidth} x {@code boardHeight} rectangle whose top-left
 * corner is at ({@code boardX}, {@code boardY}) inside the panel. Dragging the mouse over it stamps
 * filled black circles of diameter {@code pencilWidth} centred on the pointer. Strokes are painted
 * directly through {@link #getGraphics()} and are not remembered, so any repaint of the panel
 * clears the drawing.
 */
class Board extends JPanel implements MouseMotionListener {
    final private int boardWidth = 320;
    final private int boardHeight = 320;
    final private int boardX = 1;
    final private int boardY = 1;
    private int pencilWidth = 40;

    /** Draws the black frame and fills the writing area with white. */
    @Override
    public void paint(Graphics graphics) {
        super.paint(graphics);
        graphics.setColor(Color.BLACK);
        graphics.draw3DRect(this.boardX - 1, this.boardY - 1, this.boardWidth + 1, this.boardHeight + 1, true);
        graphics.setColor(Color.WHITE);
        graphics.fill3DRect(this.boardX, this.boardY, this.boardWidth, this.boardHeight, true);
    }

    /**
     * Stamps one pen dot while the pointer is inside the writing area.
     *
     * <p>{@code fillOval} takes the top-left corner of the dot's bounding box, so the dot is shifted
     * by half the pen width to centre it on the pointer, then clamped so it never spills over the
     * frame.
     */
    @Override
    public void mouseDragged(MouseEvent e) {
        int x = e.getX();
        int y = e.getY();
        boolean insideBoard = x >= boardX && x < boardX + boardWidth
                && y >= boardY && y < boardY + boardHeight;
        if (!insideBoard) {
            return;
        }
        Graphics graphics = this.getGraphics();
        if (graphics == null) {
            // The panel is not displayable yet, so there is nothing to draw on.
            return;
        }
        try {
            int radius = pencilWidth / 2;
            int left = clamp(x - radius, boardX, boardX + boardWidth - pencilWidth);
            int top = clamp(y - radius, boardY, boardY + boardHeight - pencilWidth);
            graphics.setColor(Color.BLACK);
            graphics.fillOval(left, top, pencilWidth, pencilWidth);
        } finally {
            // A Graphics obtained from getGraphics() must be released by the caller.
            graphics.dispose();
        }
    }

    /** Moving the mouse without a pressed button does not draw anything. */
    @Override
    public void mouseMoved(MouseEvent e) {
    }

    /** Limits {@code value} to the closed range [{@code min}, {@code max}]. */
    private static int clamp(int value, int min, int max) {
        return Math.max(min, Math.min(max, value));
    }
}
/**
 * Captures a rectangular region of the screen and saves it as a PNG file in the current working
 * directory.
 */
class ScreenShot {
    /** Left edge of the capture rectangle, in screen coordinates. */
    private int startX;
    /** Top edge of the capture rectangle, in screen coordinates. */
    private int startY;
    /** Capture width in pixels. */
    private int width;
    /** Capture height in pixels. */
    private int height;
    /** Path the PNG is written to. */
    private String saveTo;

    /**
     * @param startX   left edge of the capture rectangle (screen coordinates)
     * @param startY   top edge of the capture rectangle (screen coordinates)
     * @param width    capture width in pixels
     * @param height   capture height in pixels
     * @param filename base name of the output file; ".png" is appended and the file is placed in
     *                 the current working directory
     */
    public ScreenShot(int startX, int startY, int width, int height, String filename) {
        this.startX = startX;
        this.startY = startY;
        this.width = width;
        this.height = height;
        // File.separator instead of a hard-coded "\\": on Linux/macOS a backslash is an ordinary
        // file-name character, so ".\\maggie.png" would not be the "./maggie.png" read back later.
        this.saveTo = "." + File.separator + filename + ".png";
    }

    /**
     * Grabs the configured screen rectangle with {@link Robot} and writes it to {@code saveTo}.
     * Failures (e.g. no screen access, I/O errors) are reported with a stack trace on stderr.
     */
    public void capture() {
        File file = new File(saveTo);
        try {
            Rectangle area = new Rectangle(startX, startY, width, height);
            BufferedImage bufferedImage = new Robot().createScreenCapture(area);
            ImageIO.write(bufferedImage, "png", file);
            System.out.println("capture image has finish...");
        } catch (AWTException | IOException e) {
            e.printStackTrace();
        }
    }
}
/**
 * Rescales an image by a constant factor and writes the result to {@code zoominMaggie.png} in the
 * current working directory.
 */
class ZoomImage {
    /** Path of the source image. */
    private String filename;
    /** Factor applied to both width and height; e.g. 0.1f turns a 320x320 image into 32x32. */
    private float scaling;

    public ZoomImage(String filename, float scaling) {
        this.filename = filename;
        this.scaling = scaling;
    }

    /**
     * Reads the source image, draws it scaled into a new image and saves that as PNG.
     * Problems (unreadable input, a scale that leaves no pixels, I/O errors) are reported on stderr.
     */
    public void zoom() {
        try {
            BufferedImage source = ImageIO.read(new File(this.filename));
            if (source == null) {
                // ImageIO.read returns null (instead of throwing) when no reader understands the file.
                System.err.println("not a readable image: " + this.filename);
                return;
            }
            int scaledWidth = (int) (this.scaling * source.getWidth());
            int scaledHeight = (int) (this.scaling * source.getHeight());
            if (scaledWidth <= 0 || scaledHeight <= 0) {
                System.err.println("scaling " + this.scaling + " leaves no pixels for " + this.filename);
                return;
            }
            BufferedImage scaled = new BufferedImage(scaledWidth, scaledHeight, BufferedImage.TYPE_INT_BGR);
            Graphics graphics = scaled.createGraphics();
            try {
                graphics.drawImage(source, 0, 0, scaledWidth, scaledHeight, null);
            } finally {
                graphics.dispose();
            }
            // File.separator keeps the path valid on every OS; "\\" is only a separator on Windows.
            ImageIO.write(scaled, "png", new File("." + File.separator + "zoominMaggie.png"));
            System.out.println("image has been zoomed...");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
/**
 * Converts an image file into a binary (0/1) pixel vector.
 *
 * <p>Each pixel's grey level is compared with a fixed threshold: light pixels (grey > 128) become
 * 0 and dark pixels become 1. The result is stored row by row (index = row * width + column), and
 * the matrix is also echoed to standard output.
 */
class RGB2binary {
    /** Path of the image to convert. */
    private String filename;
    /** Binarized pixels; 32 x 32 zeros until {@link #rgb2binary()} has run. */
    private short[] userInputDigit = new short[32 * 32];

    /**
     * Returns the binarized pixels produced by the last {@link #rgb2binary()} call, one entry per
     * pixel of that image.
     */
    public short[] getUserInputDigit() {
        return this.userInputDigit;
    }

    public RGB2binary(String filename) {
        this.filename = filename;
    }

    /**
     * Reads the image and fills {@link #getUserInputDigit()}. If the image has a different pixel
     * count than the current array, a new array of matching size is allocated. Read failures are
     * reported on stderr and leave the previous contents untouched.
     */
    public void rgb2binary() {
        System.out.println(this.filename);
        File file = new File(this.filename);
        try {
            BufferedImage bufferedImage = ImageIO.read(file);
            if (bufferedImage == null) {
                // ImageIO.read returns null (instead of throwing) when no reader understands the file.
                System.err.println("not a readable image: " + this.filename);
                return;
            }
            int startX = bufferedImage.getMinX();
            int startY = bufferedImage.getMinY();
            int width = bufferedImage.getWidth();
            int height = bufferedImage.getHeight();
            System.out.println("x = " + startX + " y = " + startY + " width = " + width + " height = " + height);
            if (userInputDigit.length != width * height) {
                userInputDigit = new short[width * height];
            }
            // Rows on the outside, columns inside: getRGB takes (x, y) = (column, row).
            for (int row = 0; row < height; row++) {
                for (int col = 0; col < width; col++) {
                    int pixel = bufferedImage.getRGB(startX + col, startY + row);
                    int r = (pixel >> 16) & 0xff;
                    int g = (pixel >> 8) & 0xff;
                    int b = pixel & 0xff;
                    // Weighted grey level of the pixel; thresholding it at 128 yields the binary value.
                    float gray = r * 0.3f + g * 0.59f + b * 0.11f;
                    short bit = gray > 128 ? (short) 0 : (short) 1;
                    System.out.print(bit + "");
                    userInputDigit[row * width + col] = bit;
                }
                System.out.println();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
/**
 * k-nearest-neighbour classifier for 32 x 32 binary digit bitmaps.
 *
 * <p>Every sample is a text file with 32 lines of 32 {@code '0'}/{@code '1'} characters. The part
 * of the file name before the first {@code '_'} is the digit label (e.g. {@code 7_12.txt}). Pixels
 * are stored row-major as a {@code 32 * 32} feature vector. The training and test directories are
 * the package-private fields {@link #trainingSetDir} and {@link #testSetDir}.
 */
class Knn {
    /** Side length of the square digit bitmaps. */
    private static final int DIGIT_SIDE = 32;
    /** Number of features (pixels) per sample. */
    private final int featureSize = DIGIT_SIDE * DIGIT_SIDE;
    String trainingSetDir = "./trainingDigits";
    String testSetDir = "./testDigits1";
    /** Number of training samples actually loaded by {@link #readTrainingSet()}. */
    private int trainingSetSize;
    /** Number of test samples actually loaded by {@link #readTestSet()}. */
    private int testSetSize;
    private short[][] trainingData = null;
    private short[] trainingSetLabel = null;
    private short[][] testData = null;
    private short[] testSetLabel = null;

    public Knn() {
    }

    /**
     * Loads all sample files from {@link #trainingSetDir}.
     *
     * <p>Files that cannot be read, or whose name does not start with a numeric label, are reported
     * on stderr and skipped; {@link #getTrainingSetSize()} counts only the loaded samples.
     *
     * @throws IllegalStateException if the directory does not exist or cannot be listed
     */
    public void readTrainingSet() {
        File[] files = listSampleFiles(trainingSetDir);
        System.out.println("total file number: " + files.length);
        trainingData = new short[files.length][];
        trainingSetLabel = new short[files.length];
        trainingSetSize = loadSamples(files, trainingData, trainingSetLabel);
    }

    /**
     * Loads all sample files from {@link #testSetDir}; same rules as {@link #readTrainingSet()}.
     *
     * @throws IllegalStateException if the directory does not exist or cannot be listed
     */
    public void readTestSet() {
        File[] files = listSampleFiles(testSetDir);
        System.out.println("total number of test file" + files.length);
        testData = new short[files.length][];
        testSetLabel = new short[files.length];
        testSetSize = loadSamples(files, testData, testSetLabel);
    }

    /**
     * Classifies a feature vector by majority vote among its k nearest training samples
     * (Euclidean distance).
     *
     * <p>When several labels receive the same number of votes, the label of the nearest neighbour
     * among them wins.
     *
     * @param feature feature vector to classify; must have the same length as the training samples
     * @param k       number of voting neighbours; values above the training-set size use all samples
     * @return the predicted label
     * @throws IllegalArgumentException if {@code k < 1} or the feature length does not match
     * @throws IllegalStateException    if no training samples are loaded
     */
    public int knn(short[] feature, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be positive: " + k);
        }
        if (trainingSetSize == 0) {
            throw new IllegalStateException("training set is empty; call readTrainingSet() first");
        }
        double[] distances = new double[trainingSetSize];
        for (int i = 0; i < trainingSetSize; i++) {
            distances[i] = calculateDistance(feature, trainingData[i]);
        }
        int[] nearest = arg_sort(distances);
        int neighbours = Math.min(k, trainingSetSize);

        HashMap<Short, Integer> vote = new HashMap<>();
        for (int i = 0; i < neighbours; i++) {
            short label = trainingSetLabel[nearest[i]];
            Integer count = vote.get(label);
            vote.put(label, count == null ? 1 : count + 1);
        }

        // Visit the neighbours nearest-first and only replace on a strictly higher count, so ties
        // go to the closest neighbour instead of depending on HashMap iteration order.
        int result = trainingSetLabel[nearest[0]];
        int maxVote = 0;
        for (int i = 0; i < neighbours; i++) {
            short label = trainingSetLabel[nearest[i]];
            int votes = vote.get(label);
            if (votes > maxVote) {
                maxVote = votes;
                result = label;
            }
        }
        return result;
    }

    /**
     * Reloads both sets, classifies every test sample with k = 3 and returns the fraction that
     * matches its label (NaN when no test samples were loaded).
     */
    public double knnPrecise() {
        System.out.println("reading trainingSet...");
        this.readTrainingSet();
        System.out.println("reading trainingSet over");
        System.out.println("reading testSet...");
        this.readTestSet();
        System.out.println("reading testSet end");
        int success = 0;
        for (int i = 0; i < testSetSize; i++) {
            if (testSetLabel[i] == knn(testData[i], 3)) {
                success++;
            }
        }
        return (double) success / testSetSize;
    }

    /**
     * Euclidean distance between two feature vectors.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    public double calculateDistance(short[] sequence1, short[] sequence2) {
        if (sequence1.length != sequence2.length) {
            throw new IllegalArgumentException(
                    "feature length mismatch: " + sequence1.length + " vs " + sequence2.length);
        }
        // long: the square of a difference of two shorts can exceed Integer.MAX_VALUE.
        long distance = 0;
        for (int i = 0; i < sequence1.length; i++) {
            long diff = sequence1[i] - sequence2[i];
            distance += diff * diff;
        }
        return Math.sqrt(distance);
    }

    /**
     * Returns the indices of {@code sequence} ordered by ascending value; equal values keep their
     * original relative order. The input array is not modified.
     */
    public int[] arg_sort(final double[] sequence) {
        Integer[] order = new Integer[sequence.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Arrays.sort on an object array is guaranteed stable, and O(n log n) instead of the
        // previous O(n^2) selection sort.
        Arrays.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer left, Integer right) {
                return Double.compare(sequence[left], sequence[right]);
            }
        });
        int[] indices = new int[order.length];
        for (int i = 0; i < order.length; i++) {
            indices[i] = order[i];
        }
        return indices;
    }

    public int getTrainingSetSize() {
        return trainingSetSize;
    }

    public int getTestSetSize() {
        return testSetSize;
    }

    /** Lists the entries of {@code dir}, failing clearly instead of returning null. */
    private static File[] listSampleFiles(String dir) {
        File[] files = new File(dir).listFiles();
        if (files == null) {
            throw new IllegalStateException("cannot list sample directory " + new File(dir).getAbsolutePath());
        }
        return files;
    }

    /**
     * Parses each regular file into {@code data}/{@code labels}, packing successful samples at the
     * front, and returns how many were loaded.
     */
    private int loadSamples(File[] files, short[][] data, short[] labels) {
        int loaded = 0;
        for (File file : files) {
            if (!file.isFile()) {
                continue;
            }
            try {
                short label = Short.parseShort(file.getName().split("_")[0]);
                data[loaded] = readSample(file);
                labels[loaded] = label;
                loaded++;
            } catch (IOException | NumberFormatException e) {
                System.err.println("skipping " + file + ": " + e);
            }
        }
        return loaded;
    }

    /**
     * Reads up to 32 non-blank lines of up to 32 digit characters each. Line-based reading accepts
     * both CRLF and LF files, unlike a fixed 34-character buffer.
     */
    private short[] readSample(File file) throws IOException {
        short[] pixels = new short[featureSize];
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            int row = 0;
            String line;
            while (row < DIGIT_SIDE && (line = reader.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) {
                    continue;
                }
                int columns = Math.min(DIGIT_SIDE, line.length());
                for (int col = 0; col < columns; col++) {
                    pixels[row * DIGIT_SIDE + col] = Short.parseShort(String.valueOf(line.charAt(col)));
                }
                row++;
            }
        }
        return pixels;
    }
}