1. 程式人生 > >機器學習特徵工程——給任意屬性增加任意次方的全組合

機器學習特徵工程——給任意屬性增加任意次方的全組合

在機器學習中,我們時常會碰到需要給屬性增加欄位的情況。譬如有x、y兩個屬性,當結果傾向於線性時,我們可以很簡單的通過線性迴歸得到模型。但很多時候,線性(在數學上稱為多元一次方程),線性是擬合不了結果的。

往往,我們就需要在給定的幾個屬性上,通過增加屬性來嘗試能否擬合。那麼原本只有兩列,x、y,我們增加2次方的屬性後,就會變成x、y、x^2、x*y、y^2,變成了5個屬性,根據以往經驗,我們知道通過這5個屬性是能擬合出曲線。

2次方時,我們還能很簡單的寫出來所有的組合形式,但是當5次方時,原本有4列時,我們該增加多少列,增加的列該怎麼計算呢。這就有點麻煩了,譬如(x+y+z)^3展開後就是x^3+y^3+z^3+3xy^2+3xz^2+3x^2y

+3yz^2+3x^2z+3y^2z+6xyz. 去掉係數後,就是我們需要追加的所有列了。我們這篇就是做一個程式,來通過給定的m列,n次方,來給出所有的組合形式。

譬如m為2,n也為2,那麼我們給出結果組合:[{0,2}, {1,1}, {2,0}],代表追加3列,第一列是x^0 * y^2,第二列是x^1 * y^1,第三列是x^2 * y^0.

通過觀察我們發現,我們需要做的是求這樣的方程的所有解:X1+X2+X3+……+Xm = N。其中0<=X<=n。

那麼解法就是,我們可以定義一個int[m],該陣列共有m個元素,每個元素的取值範圍在0到n之間,並且該陣列的所有元素的和等於n即可。

直接看程式:

/**
 * @author wuweifeng wrote on 2018/6/4.
 */
public class LineAdder {
    private static int lines = 3;
    private static int power = 5;

    private static int[] resultArray;

    public static void main(String[] args) {
        resultArray = new int[lines];
        deal(0);
    }

    public static void deal(int m) {
        for (int i = 0; i <= power; i++) {
            resultArray[m] = i;
            if (m == lines - 1) {
                //如果找到一個解
                if (check()) {
                    print();
                    return;
                }
            } else {
                deal(m + 1);
            }
        }
    }

    /**
     * 判斷是否符合結果
     *
     * @return 是否符合
     */
    private static boolean check() {
        int total = 0;
        for (int one : resultArray) {
            total += one;
        }
        return power == total;
    }

    private static void print() {
        for (int one : resultArray) {
            System.out.print(one);
        }
        System.out.print("\n");
    }
}    

結果是:

005
014
023
032
041
050
104
113
122
131
140
203
212
221
230
302
311
320
401
410
500
這就是有3列,並且希望求出5次方時的所有組合的答案。

下面我們將它優化一下,讓他能處理文字,能處理一行一行的資料,直接把列追加在文字上。

直接上程式碼:

package ploy;

import java.util.ArrayList;
import java.util.List;

/**
 * @author wuweifeng wrote on 2018/6/4.
 */
public class LineAdder {
    private int lines = 3;
    private int power = 5;

    private List<int[]> resultList = new ArrayList<>();

    private int[] resultArray;

    public List<int[]> lineAdd(int lines, int power) {
        resultArray = new int[lines];
        this.lines = lines;
        this.power = power;
        deal(0);
        return resultList;
    }

    private void deal(int m) {
        for (int i = 0; i <= power; i++) {
            resultArray[m] = i;
            if (m == lines - 1) {
                //如果找到一個解
                if (check()) {
                    print();
                    return;
                }
            } else {
                deal(m + 1);
            }
        }
    }

    /**
     * 判斷是否符合結果
     *
     * @return 是否符合
     */
    private boolean check() {
        int total = 0;
        for (int one : resultArray) {
            total += one;
        }
        return power == total;
    }

    private void print() {
        for (int one : resultArray) {
            System.out.print(one);

        }
        System.out.print("\n");
        int[] temp = new int[resultArray.length];
        System.arraycopy(resultArray, 0, temp, 0, resultArray.length);
        resultList.add(temp);
    }
}
package ploy;

import java.io.*;
import java.util.List;

/**
 * @author wuweifeng wrote on 2018/6/5.
 */
public class TextDeal {
    public static void main(String[] args) throws IOException {
        new TextDeal().linePower("/Users/wuwf/Downloads/ml_data/1邏輯迴歸入門/train_test_deal.csv",
                "/Users/wuwf/Downloads/ml_data/1邏輯迴歸入門/train_test_deal-3.csv", 3, 1, 2, 3, 6);
    }

    /**
     * @param filePath
     *         檔案的路徑
     * @param outputPath
     *         輸出檔案的路徑
     * @param power
     *         要做幾次方
     * @param lineNums
     *         都有哪幾列,需要power,不填預設所有列。從第0列開始
     */
    public void linePower(String filePath, String outputPath, Integer power, Integer... lineNums) throws IOException {
        BufferedReader reader = buildReader(filePath);
        BufferedWriter writer = buildWriter(outputPath);

        addCSVHeader(reader, writer, power, lineNums);

    }

    private Integer[] getLineNums(String[] lines, Integer... lineNums) {
        //為null,則是所有列
        if (lineNums == null) {
            lineNums = new Integer[lines.length];
            for (int i = 0; i < lines.length; i++) {
                lineNums[i] = i;
            }
        }
        return lineNums;
    }

    private List<int[]> getAddList(int power, Integer... lineNums) {
        LineAdder lineAdder = new LineAdder();
        //計算共需增加多少列
        return lineAdder.lineAdd(lineNums.length, power);
    }

    /**
     * 給header裡增加相應的列名,都在第一行
     */
    private void addCSVHeader(BufferedReader reader, BufferedWriter writer, Integer power, Integer... lineNums)
            throws IOException {
        //讀取第一行
        String header = reader.readLine();
        //所有的列名
        String[] lines = header.split(",");
        lineNums = getLineNums(lines, lineNums);

        //計算共需增加多少列
        List<int[]> list = getAddList(power, lineNums);

        String[] addLines = new String[list.size()];

        String[] needLines = new String[lineNums.length];
        for (int i = 0; i < lineNums.length; i++) {
            needLines[i] = lines[lineNums[i]];
        }
        //設定每一列的名字
        for (int i = 0; i < list.size(); i++) {
            int[] array = list.get(i);
            String s = "";
            for (int j = 0; j < array.length; j++) {
                s += needLines[j] + array[j];
            }
            addLines[i] = s;
        }

        for (String addLine : addLines) {
            header += "," + addLine;
        }
        //將新增的列,寫入header檔案
        writer.write(header);
        writer.newLine();
        writer.flush();

        String oneLine;

        while ((oneLine = reader.readLine()) != null) {
            addLines = new String[list.size()];
            lines = oneLine.split(",");

            needLines = new String[lineNums.length];
            for (int i = 0; i < lineNums.length; i++) {
                needLines[i] = lines[lineNums[i]];
            }

            //設定每一列的值
            for (int i = 0; i < list.size(); i++) {
                int[] array = list.get(i);
                double s = 1;
                try {
                    for (int j = 0; j < array.length; j++) {
                        //譬如a,b,對應02時,該列就是a的0次方乘以b的2次方
                        s *= Math.pow(Double.valueOf(needLines[j]), array[j]);
                    }
                    addLines[i] = s + "";
                } catch (Exception e) {
                    addLines[i] = "?";
                }

            }
            for (String addLine : addLines) {
                oneLine += "," + addLine;
            }
            writer.write(oneLine);

            //寫入相關檔案
            writer.newLine();
        }

        //將新增的列,寫入header檔案
        writer.flush();
        //關閉流
        reader.close();
        writer.close();
    }

    private BufferedReader buildReader(String filePath) {
        try {
            return new BufferedReader(new FileReader(new File(filePath)));
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            return null;
        }
    }

    private BufferedWriter buildWriter(String outputPath) {
        //寫入相應的檔案
        try {
            return new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputPath), "utf-8"));
        } catch (UnsupportedEncodingException | FileNotFoundException e) {
            e.printStackTrace();
            return null;
        }
    }

}

假如csv檔案是這樣的

a,b
1,2
2,3

4,5

執行後,結果是

a,b,a0b2,a1b1,a2b0
1,2,4.0,2.0,1.0
2,3,9.0,6.0,4.0
4,5,25.0,20.0,16.0

可以看到已經完成了做2次方的展開。

這個類,可以完成任意次方的模擬及計算。