1. 程式人生 > >簡單的遺傳演算法(Genetic algorithms)-吃豆人

簡單的遺傳演算法(Genetic algorithms)-吃豆人

遺傳演算法簡介:

一直都在收聽《卓老闆聊科技》這個節目,最近播出了一期人工智慧的節目,主要講的是由霍蘭提出的遺傳演算法,在節目中詳細闡述了一個有趣的小實驗:吃豆人。

首先簡單介紹下遺傳演算法:
1:為了解決某個具體的問題,先隨機生成若干個解決問題的實體,每個實體解決問題的方式都用“基因”來表示,也就是說,不同的實體擁有不同的基因,那麼也對應著不同的解決問題的方案。
2:有了若干實體之後,接下來就是讓這些實體來完成這個任務,根據任務的完成情況用相同標準打分。
3:接下來是進化環節,按照得分的高低,得出每個個體被選出的概率,得分越高越容易被選出,先選出兩個個體,對其基因進行交叉,再按照設定的概率對其基因進行突變,來生成新個體,不停重複直到生成足夠數量的新個體,這便是一次進化過程。按照這個方法不停地進化,若干代之後就能得到理想的個體。

下面簡單介紹下吃豆人實驗:

吃豆人生存在一個由10*10個格子組成的矩形空間中,將50個豆子隨機放在這100個格子中,每個格子要麼為空,要麼有一顆豆子。吃豆人出生的時候隨機出現在任意一個方格中,接下來吃豆人需要通過自己的策略來吃豆子,一共只有200步:吃到一顆豆子+10分,撞牆-5分,發出吃豆子的動作卻沒吃到豆子-1分。另外,吃豆人只能看到自己所在格子和上下左右一共5個格子的情況。

整理一下
吃豆人的所有動作:上移、下移、左移、右移、吃豆、不動、隨機移動,一共7種
吃豆人所能觀察到的狀態:每個格子有「有豆子、無豆子、牆」3種狀態,而一共有5個格子,那就是3^5=243種狀態。

為此,吃豆人個體的基因可以用長度為243的序列表示,每一位分別對應243種狀態中的一種;每一位有7種取值,表示在該狀態下所採取的動作。

程式碼

Main.java

/**
 * Program entry point: evolves a population of Pac-Man strategies indefinitely,
 * printing population statistics after every fifth generation.
 */
public class Main {
    public static void main(String[] args) {
        // Generation 0: 1000 individuals with random genes.
        Population current = new Population(1000, false);
        System.out.println(current);

        // Runs until the process is killed; there is no stopping criterion.
        for (long generation = 1; ; generation++) {
            Population next = Algorithm.evolve(current);
            boolean report = generation % 5 == 0;
            if (report) {
                System.out.println("The " + generation + "'s evolve");
                System.out.println(next);
            }
            current = next;
        }
    }
}

Individual.java

/**
 * One Pac-Man strategy. The genome holds one action per observable state (243 genes).
 * Fitness is computed lazily by FitnessCalc and cached until the genome changes.
 */
public class Individual {

    // Pac-Man observes 5 cells (its own plus up/right/down/left). Each cell is a wall, a bean
    // or empty, giving 3^5 = 243 possible states, hence one gene per state.
    private static final int length = 243;
    /* Pac-Man has 7 actions:
     * 0 : up       4 : random move
     * 1 : left     5 : eat
     * 2 : down     6 : stay still
     * 3 : right
     */
    private static final byte actionNum = 7;

    private final byte[] genes;
    // Cached fitness value; only meaningful while fitnessKnown is true.
    private int fitness;
    private boolean fitnessKnown = false;

    public Individual() {
        genes = new byte[length];
    }

    /** Fills every gene with a uniformly random action in [0, actionNum). */
    public void generateGenes() {
        for (int i = 0; i < length; i++) {
            genes[i] = (byte) Math.floor(Math.random() * actionNum);
        }
        // The whole genome was replaced, so a previously cached fitness would be stale.
        fitnessKnown = false;
    }

    /**
     * Returns this individual's fitness. It is computed on first use with
     * FitnessCalc.getFitnessPall and cached until a gene changes.
     */
    public int getFitness() {
        if (!fitnessKnown) {
            fitness = FitnessCalc.getFitnessPall(this);
            fitnessKnown = true;
        }
        return fitness;
    }

    /** Number of genes (always 243). */
    public int getLength() {
        return length;
    }

    public byte getGene(int index) {
        return genes[index];
    }

    /** Overwrites one gene and invalidates the cached fitness. */
    public void setGene(int index, byte gene) {
        this.genes[index] = gene;
        fitnessKnown = false;
    }

    /**
     * Returns the action this individual takes in the given state.
     * The state is encoded as a 5-digit base-3 number, most significant digit first:
     * middle, up, right, down, left. Integer Horner evaluation avoids going through
     * floating-point Math.pow for what is purely an integer index.
     */
    public byte getActionCode(State state) {
        int stateCode = state.getMiddle();
        stateCode = stateCode * 3 + state.getUp();
        stateCode = stateCode * 3 + state.getRight();
        stateCode = stateCode * 3 + state.getDown();
        stateCode = stateCode * 3 + state.getLeft();
        return genes[stateCode];
    }

    /** The genome as a string of action digits, one character per gene. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(length);
        for (int i = 0; i < length; i++) {
            sb.append(genes[i]);
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        Individual ind = new Individual();
        ind.generateGenes();
        System.out.println(ind);
        System.out.println(ind.getFitness());
        System.out.println(FitnessCalc.getFitnessPall(ind));
    }
}

State.java

/**
 * What Pac-Man can see at one moment: its own cell and the four neighbouring cells.
 * Every cell value uses the encoding 0 = wall, 1 = bean, 2 = empty.
 */
public class State {
    private byte middle;
    private byte up;
    private byte right;
    private byte down;
    private byte left;

    public State(byte middle, byte up, byte right, byte down, byte left) {
        this.left = left;
        this.down = down;
        this.right = right;
        this.up = up;
        this.middle = middle;
    }

    // ---- accessors ----

    public byte getMiddle() {
        return this.middle;
    }

    public byte getUp() {
        return this.up;
    }

    public byte getRight() {
        return this.right;
    }

    public byte getDown() {
        return this.down;
    }

    public byte getLeft() {
        return this.left;
    }

    // ---- mutators ----

    public void setMiddle(byte middle) {
        this.middle = middle;
    }

    public void setUp(byte up) {
        this.up = up;
    }

    public void setRight(byte right) {
        this.right = right;
    }

    public void setDown(byte down) {
        this.down = down;
    }

    public void setLeft(byte left) {
        this.left = left;
    }
}

Algorithm.java

/**
 * Genetic-algorithm operators: tournament selection, uniform crossover and per-gene mutation.
 */
public class Algorithm {
    /* GA parameters */
    // Uniform crossover: probability that a child's gene is copied from the first parent
    // (otherwise it comes from the second). It is not an "apply crossover or not" rate.
    private static final double uniformRate = 0.5;
    // Per-gene probability of being replaced by a random action.
    private static final double mutationRate = 0.0001;
    // Number of individuals sampled (with replacement) for each tournament.
    private static final int tournamentSize = 3;
    // Number of distinct actions a gene can hold (0-6); must match Individual's action count.
    private static final int actionCount = 7;

    /**
     * Builds the next generation, the same size as {@code pop}. Each child is the uniform
     * crossover of two tournament winners, and is then mutated.
     */
    public static Population evolve(Population pop) {
        int size = pop.size();
        Population newPopulation = new Population(size, true);
        for (int i = 0; i < size; i++) {
            Individual parent1 = tournamentSelection(pop);
            Individual parent2 = tournamentSelection(pop);
            Individual child = crossover(parent1, parent2);
            // Mutating only affects the fresh child, never the parent population.
            mutate(child);
            newPopulation.saveIndividual(i, child);
        }
        return newPopulation;
    }

    /**
     * Samples {@code tournamentSize} individuals uniformly at random (with replacement)
     * and returns the fittest of them; on ties the later sample wins.
     */
    private static Individual tournamentSelection(Population pop) {
        Individual best = pop.getIndividual(randomIndex(pop.size()));
        for (int i = 1; i < tournamentSize; i++) {
            Individual challenger = pop.getIndividual(randomIndex(pop.size()));
            if (best.getFitness() <= challenger.getFitness()) {
                best = challenger;
            }
        }
        return best;
    }

    // Uniformly random index in [0, bound).
    private static int randomIndex(int bound) {
        return (int) (Math.random() * bound);
    }

    /**
     * Uniform crossover: each gene of the child is taken from {@code indiv1} with
     * probability {@code uniformRate}, otherwise from {@code indiv2}.
     */
    private static Individual crossover(Individual indiv1, Individual indiv2) {
        Individual newSol = new Individual();
        for (int i = 0; i < indiv1.getLength(); i++) {
            if (Math.random() <= uniformRate) {
                newSol.setGene(i, indiv1.getGene(i));
            } else {
                newSol.setGene(i, indiv2.getGene(i));
            }
        }
        return newSol;
    }

    /** Replaces each gene, with probability {@code mutationRate}, by a random action 0..6. */
    private static void mutate(Individual indiv) {
        for (int i = 0; i < indiv.getLength(); i++) {
            if (Math.random() <= mutationRate) {
                byte gene = (byte) Math.floor(Math.random() * actionCount);
                indiv.setGene(i, gene);
            }
        }
    }
}

Population.java

/**
 * A fixed-size group of individuals, with helpers for the best, worst and average fitness.
 */
public class Population {

    private final Individual[] individuals;

    /**
     * Creates a population with {@code size} slots.
     *
     * @param size number of individuals
     * @param lazy when true the slots are left null and must be filled via saveIndividual;
     *             when false every slot gets a new individual with random genes
     */
    public Population(int size, boolean lazy) {
        individuals = new Individual[size];
        if (lazy) {
            return;
        }
        for (int i = 0; i < individuals.length; i++) {
            Individual ind = new Individual();
            ind.generateGenes();
            individuals[i] = ind;
        }
    }

    public void saveIndividual(int index, Individual ind) {
        individuals[index] = ind;
    }

    public Individual getIndividual(int index) {
        return individuals[index];
    }

    /**
     * Returns the individual with the highest fitness; on ties the one with the larger index wins.
     *
     * @throws IllegalStateException if the population has no slots
     */
    public Individual getFittest() {
        requireNonEmpty();
        Individual fittest = individuals[0];
        for (int i = 1; i < size(); i++) {
            if (fittest.getFitness() <= getIndividual(i).getFitness()) {
                fittest = getIndividual(i);
            }
        }
        return fittest;
    }

    /**
     * Returns the individual with the lowest fitness; on ties the one with the smaller index wins.
     *
     * @throws IllegalStateException if the population has no slots
     */
    public Individual getLeastFittest() {
        requireNonEmpty();
        Individual least = individuals[0];
        for (int i = 1; i < size(); i++) {
            if (least.getFitness() > getIndividual(i).getFitness()) {
                least = getIndividual(i);
            }
        }
        return least;
    }

    /** Mean fitness over all slots (NaN for an empty population). */
    public double getAverageFitness() {
        double sum = 0;
        for (int i = 0; i < size(); i++) {
            sum += individuals[i].getFitness();
        }
        return sum / size();
    }

    public int size() {
        return individuals.length;
    }

    // Gives a clear error instead of an ArrayIndexOutOfBoundsException on individuals[0].
    private void requireNonEmpty() {
        if (individuals.length == 0) {
            throw new IllegalStateException("Population is empty");
        }
    }

    /** Multi-line summary: size, best, worst and average fitness. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Population size: " + size() + "\n");
        sb.append("Max Fitness: " + getFittest().getFitness() + "\n");
        sb.append("Least Fitness: " + getLeastFittest().getFitness() + "\n");
        sb.append("Average Fitness: " + getAverageFitness() + "\n");
        return sb.toString();
    }

    public static void main(String[] args) {
        Population population = new Population(8000, false);
        System.out.println(population);
    }
}

MapMgr.java

/**
 * Lazily created singleton that pre-generates a fixed pool of random maps and hands out
 * independent copies, so a simulation can eat beans without altering the shared originals.
 */
public class MapMgr {

    private static final int x = 10;        // cells along the x axis (first Map dimension)
    private static final int y = 10;        // cells along the y axis (second Map dimension)
    private static final int beanNum = 50;  // beans placed on each map
    private static final int mapNum = 100;  // number of maps in the pool

    private static MapMgr manager = null;
    private final Map[] maps;

    private MapMgr() {
        maps = new Map[mapNum];
        for (int i = 0; i < mapNum; i++) {
            Map map = new Map(x, y);
            map.setBeans(beanNum);
            maps[i] = map;
        }
    }

    /** Returns the shared instance, building the map pool on first call. */
    public static synchronized MapMgr getInstance() {
        if (manager == null) {
            manager = new MapMgr();
        }
        return manager;
    }

    /**
     * Returns a deep copy of pool map {@code index}. Any int is accepted; it is wrapped into
     * [0, mapNum), so callers can pass a running game counter.
     */
    public Map getMap(int index) {
        // A plain % keeps the sign of a negative index and would overrun the array.
        int slot = ((index % mapNum) + mapNum) % mapNum;
        try {
            return maps[slot].clone();
        } catch (CloneNotSupportedException e) {
            // Map implements Cloneable, so this cannot happen; fail loudly rather than return null.
            throw new AssertionError(e);
        }
    }

    public static void main(String[] args) {
        MapMgr mgr = MapMgr.getInstance();
        mgr.getMap(1).print();
        System.out.println("--------------");
        mgr.getMap(2).print();
    }
}

Map.java

import java.awt.Point;

/**
 * A rectangular grid where each cell is either empty or holds one bean, plus the
 * bounds and neighbourhood helpers the simulation needs.
 */
public class Map implements Cloneable {

    private final int x;      // size of the first grid dimension (columns)
    private final int y;      // size of the second grid dimension (rows)
    private final int total;  // x * y
    // mapGrid[column][row]: 0 = empty, non-zero (set to 1) = bean.
    private byte[][] mapGrid;

    public Map(int x, int y) {
        this.x = x;
        this.y = y;
        mapGrid = new byte[x][y];
        total = x * y;
    }

    /**
     * Places up to {@code num} additional beans on randomly chosen empty cells.
     * The request is capped at the number of currently empty cells, so calling this on an
     * already (partly) filled map terminates instead of searching forever for a free cell.
     */
    public void setBeans(int num) {
        int free = countEmptyCells();
        if (num > free) {
            num = free;
        }
        for (int i = 0; i < num; i++) {
            int address, xp, yp;
            do {
                address = (int) Math.floor(Math.random() * total); // random cell 0 .. total-1
                xp = address / y;
                yp = address % y;
            } while (mapGrid[xp][yp] != 0);
            mapGrid[xp][yp] = 1;
        }
    }

    // Number of cells that currently hold no bean.
    private int countEmptyCells() {
        int empty = 0;
        for (byte[] column : mapGrid) {
            for (byte cell : column) {
                if (cell == 0) {
                    empty++;
                }
            }
        }
        return empty;
    }

    /** True if (x, y) lies inside the grid. */
    public boolean isInMap(int x, int y) {
        if (x < 0 || x >= this.x) return false;
        if (y < 0 || y >= this.y) return false;
        return true;
    }

    /** True if cell (x, y) holds a bean; the coordinates must be inside the grid. */
    public boolean hasBean(int x, int y) {
        return mapGrid[x][y] != 0;
    }

    /** Removes the bean at (x, y) if there is one; returns whether a bean was eaten. */
    public boolean eatBean(int x, int y) {
        if (hasBean(x, y)) {
            mapGrid[x][y] = 0;
            return true;
        }
        return false;
    }

    /** A uniformly random cell of the grid (it may or may not hold a bean). */
    public Point getStartPoint() {
        int x = (int) Math.floor(Math.random() * this.x);
        int y = (int) Math.floor(Math.random() * this.y);
        return new Point(x, y);
    }

    /** What is visible from {@code p}: the cell itself and its up (y-1), right, down (y+1), left neighbours. */
    public State getState(Point p) {
        byte middle = stateOfPoint(p);
        byte up = stateOfPoint(new Point(p.x, p.y - 1));
        byte right = stateOfPoint(new Point(p.x + 1, p.y));
        byte down = stateOfPoint(new Point(p.x, p.y + 1));
        byte left = stateOfPoint(new Point(p.x - 1, p.y));
        return new State(middle, up, right, down, left);
    }

    // Cell encoding used by State: 0 = wall (outside the grid), 1 = bean, 2 = empty.
    private byte stateOfPoint(Point p) {
        if (!isInMap(p.x, p.y)) {
            return 0;
        }
        return mapGrid[p.x][p.y] == 0 ? (byte) 2 : (byte) 1;
    }

    /** Deep copy: the clone gets its own grid, so eating beans on it leaves this map intact. */
    @Override
    public Map clone() throws CloneNotSupportedException {
        Map m = (Map) super.clone();
        // super.clone() is shallow; copy every column so the grids are fully independent.
        m.mapGrid = new byte[x][];
        for (int i = 0; i < x; i++) {
            m.mapGrid[i] = this.mapGrid[i].clone();
        }
        return m;
    }

    /** Prints the grid row by row (one line per y value). */
    public void print() {
        for (int i = 0; i < y; i++) {
            for (int j = 0; j < x; j++) {
                System.out.print(mapGrid[j][i]);
            }
            System.out.println();
        }
    }

    public static void main(String[] args) {
        Map m = new Map(10, 5);
        Map m1 = null;
        try {
            m1 = m.clone();
        } catch (CloneNotSupportedException e) {
            e.printStackTrace();
        }
        m.setBeans(40);
        m.print();
        m1.setBeans(15);
        m1.print();
    }

}

FitnessCalc.java

import java.awt.Point;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadLocalRandom;

/**
 * Scores an Individual by letting it play simulated games and averaging the total score.
 *
 * Score of a single action:
 *   bumping into a wall:                   -5
 *   eating a bean:                         +10
 *   trying to eat on a cell with no bean:  -1
 *   anything else:                          0
 */
public class FitnessCalc {
    // Number of games simulated per fitness evaluation.
    private static int DefaultSimTimes = 1000;
    // Number of steps in each game.
    private static int simSteps = 200;
    // Worker threads used by getFitnessPall (at least 1).
    private static final int cores = Math.max(1, Runtime.getRuntime().availableProcessors());

    // One pool shared by every evaluation, instead of starting fresh threads per individual.
    // Daemon threads so an idle pool never keeps the JVM alive.
    private static final ExecutorService pool =
            Executors.newFixedThreadPool(cores, new ThreadFactory() {
                @Override
                public Thread newThread(Runnable task) {
                    Thread worker = new Thread(task, "fitness-worker");
                    worker.setDaemon(true);
                    return worker;
                }
            });

    /** Average score of {@code ind} over the default number of games, on the calling thread. */
    public static int getFitness(Individual ind) {
        return getFitness(ind, DefaultSimTimes);
    }

    /**
     * Average score of {@code ind} over {@code simTimes} games played on map indices
     * {@code 0 .. simTimes-1} (MapMgr wraps indices around its pool).
     *
     * @return the average score truncated toward zero, or 0 when {@code simTimes <= 0}
     */
    public static int getFitness(Individual ind, int simTimes) {
        if (simTimes <= 0) {
            return 0;
        }
        return (int) (simulate(ind, 0, simTimes) / simTimes);
    }

    /**
     * Computes the same average as {@link #getFitness(Individual)} (same map indices, same
     * number of games) but splits the games into contiguous index ranges run in parallel.
     * Workers return raw score sums and the division happens once at the end, so unequal
     * chunk sizes do not skew the average and integer truncation happens only once.
     *
     * @throws IllegalStateException if a simulation fails or the calling thread is interrupted
     */
    public static int getFitnessPall(Individual ind) {
        int workers = Math.min(cores, DefaultSimTimes);
        if (DefaultSimTimes < 100 || workers <= 1) {
            return getFitness(ind);
        }
        List<Future<Long>> parts = new ArrayList<Future<Long>>(workers);
        int chunk = DefaultSimTimes / workers;
        int start = 0;
        for (int i = 0; i < workers; i++) {
            // The last chunk also takes the remainder, so every game index is played exactly once.
            int count = (i == workers - 1) ? DefaultSimTimes - start : chunk;
            parts.add(pool.submit(new SimulationChunk(ind, start, count)));
            start += count;
        }
        long total = 0;
        try {
            for (Future<Long> part : parts) {
                total += part.get();
            }
        } catch (InterruptedException e) {
            cancelAll(parts);
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Fitness evaluation interrupted", e);
        } catch (ExecutionException e) {
            cancelAll(parts);
            throw new IllegalStateException("Fitness simulation failed", e.getCause());
        }
        return (int) (total / DefaultSimTimes);
    }

    // Best-effort cancellation of still-running chunks after a failure.
    private static void cancelAll(List<Future<Long>> parts) {
        for (Future<Long> part : parts) {
            part.cancel(true);
        }
    }

    /**
     * Plays games on map indices {@code startIndex .. startIndex+count-1}, each starting from a
     * random cell and lasting {@code simSteps} steps, and returns the summed score (not averaged).
     */
    private static long simulate(Individual ind, int startIndex, int count) {
        MapMgr mgr = MapMgr.getInstance();
        long total = 0;
        for (int i = startIndex; i < startIndex + count; i++) {
            Map map = mgr.getMap(i);
            Point point = map.getStartPoint();
            for (int j = 0; j < simSteps; j++) {
                State state = map.getState(point);
                byte actionCode = ind.getActionCode(state);
                total += action(point, map, actionCode);
            }
        }
        return total;
    }

    /**
     * Performs one action, updating {@code point} (Pac-Man's position) and {@code map} in place,
     * and returns its score. Codes: 0 up, 1 left, 2 down, 3 right, 4 random move, 5 eat,
     * 6 stay; any other code does nothing and scores 0.
     */
    private static int action(Point point, Map map, int actionCode) {
        switch (actionCode) {
        case 0:
            return move(point, map, 0, -1);
        case 1:
            return move(point, map, -1, 0);
        case 2:
            return move(point, map, 0, 1);
        case 3:
            return move(point, map, 1, 0);
        case 4:
            // Random move: one of the four directional actions, chosen uniformly.
            return action(point, map, ThreadLocalRandom.current().nextInt(4));
        case 5:
            return map.eatBean(point.x, point.y) ? 10 : -1;
        case 6:
        default:
            return 0;
        }
    }

    // Moves by (dx, dy) if the target cell is inside the map (score 0); otherwise stays put (-5).
    private static int move(Point point, Map map, int dx, int dy) {
        int nx = point.x + dx;
        int ny = point.y + dy;
        if (!map.isInMap(nx, ny)) {
            return -5;
        }
        point.x = nx;
        point.y = ny;
        return 0;
    }

    // Runs one contiguous range of games on the worker pool and returns its summed score.
    private static final class SimulationChunk implements Callable<Long> {
        private final Individual ind;
        private final int startIndex;
        private final int count;

        SimulationChunk(Individual ind, int startIndex, int count) {
            this.ind = ind;
            this.startIndex = startIndex;
            this.count = count;
        }

        @Override
        public Long call() {
            return simulate(ind, startIndex, count);
        }
    }
}

/**
 * Callable wrapper around FitnessCalc.getFitness(ind, simTimes), so that the average score
 * of an individual over a number of games can be computed on another thread.
 */
class FitnessPall implements Callable<Integer> {
    private final Individual ind;
    private final int simTimes;

    public FitnessPall(Individual ind, int simTimes) {
        this.simTimes = simTimes;
        this.ind = ind;
    }

    @Override
    public Integer call() throws Exception {
        int average = FitnessCalc.getFitness(ind, simTimes);
        return average;
    }
}