1. 程式人生 > >ML.NET 示例:二元分類之信用卡欺詐檢測

ML.NET 示例:二元分類之信用卡欺詐檢測

寫在前面

準備近期將微軟的machinelearning-samples翻譯成中文,水平有限,如有錯漏,請大家多多指正。
如果有朋友對此感興趣,可以加入我:https://github.com/feiyun0112/machinelearning-samples.zh-cn

基於二元分類和PCA的信用卡欺詐檢測

ML.NET 版本 API 型別 狀態 應用程式型別 資料型別 場景 機器學習任務 演算法
v0.7 動態API 更新至0.7 兩個控制檯應用程式 .csv 檔案 欺詐檢測 二元分類 FastTree 二元分類

在這個介紹性示例中,您將看到如何使用ML.NET來預測信用卡欺詐。在機器學習領域中,這種型別的預測被稱為二元分類。

API版本:基於動態和評估器的API

請務必注意,此示例使用動態API和評估器。

問題

這個問題的核心是預測信用卡交易(及其相關資訊/變數)是否是欺詐。

交易的輸入資訊僅包含PCA轉換後的數值輸入變數。遺憾的是,基於隱私原因,原始特徵和附加的背景資訊無法得到,但您建立模型的方式不會改變。

特徵V1, V2, ... V28是用PCA獲得的主成分,未經PCA轉換的特徵是“Time”和“Amount”。

“Time”特徵包含每個交易和資料集中的第一個交易之間經過的秒數。“Amount”特徵是交易金額,該特徵可用於依賴於示例的代價敏感學習。特徵“Class”是響應變數,如果存在欺詐取值為1,否則為0。

資料集非常不平衡,正類(欺詐)資料佔所有交易的0.172%。

使用這些資料集,您可以建立一個模型,當預測該模型時,它將分析交易的輸入變數並預測欺詐值為false或true。

資料集

訓練和測試資料基於公共資料集dataset available at Kaggle,其最初來自於Worldline和ULB(Université Libre de Bruxelles)的機器學習小組(ttp://mlg.ulb.ac.be)在研究合作期間收集和分析的資料集。

這些資料集包含2013年9月由歐洲持卡人通過信用卡進行的交易。 這個資料集顯示了兩天內發生的交易,在284,807筆交易中有492個欺詐。

作者:Andrea Dal Pozzolo、Olivier Caelen、Reid A. Johnson和Gianluca Bontempi。基於欠取樣的不平衡分類概率。2015在計算智慧和資料探勘(CIDM)學術研討會上的發言

有關相關主題的當前和過去專案的更多詳細資訊,請訪問 http://mlg.ulb.ac.be/BruFencehttp://mlg.ulb.ac.be/ARTML

機器學習任務 - 二元分類

二元或二項式分類是根據分類規則將給定集合中的元素分成兩組(預測每個元素屬於哪個組)的任務。需要決定某項是否具有某種定性屬性、某些特定特徵的上下文

解決方案

要解決這個問題,首先需要建立一個機器學習模型。 然後,您可以在現有訓練資料上訓練模型,評估其準確性有多好,最後使用該模型(在另一個應用程式中部署建立的模型)來預測信用卡交易樣本是否存在欺詐。

Build -> Train -> Evaluate -> Consume

1. 建立模型

建立一個模型包括:

  • 定義對映到資料集的資料架構,以便使用DataReader讀取

  • 拆分訓練和測試資料

  • 建立一個評估器,並使用ConcatEstimator()轉換資料,並通過均值方差進行標準化。

  • 選擇一個訓練/學習演算法(FastTree)來訓練模型。

初始程式碼類似以下內容:


    // Create a common ML.NET context.
    // Seed set to any number so you have a deterministic environment for repeateable results
    MLContext mlContext = new MLContext(seed:1);

[...]
    TextLoader.Column[] columns = new[] {
           // A boolean column depicting the 'label'.
           new TextLoader.Column("Label", DataKind.BL, 30),
           // 29 Features V1..V28 + Amount
           new TextLoader.Column("V1", DataKind.R4, 1 ),
           new TextLoader.Column("V2", DataKind.R4, 2 ),
           new TextLoader.Column("V3", DataKind.R4, 3 ),
           new TextLoader.Column("V4", DataKind.R4, 4 ),
           new TextLoader.Column("V5", DataKind.R4, 5 ),
           new TextLoader.Column("V6", DataKind.R4, 6 ),
           new TextLoader.Column("V7", DataKind.R4, 7 ),
           new TextLoader.Column("V8", DataKind.R4, 8 ),
           new TextLoader.Column("V9", DataKind.R4, 9 ),
           new TextLoader.Column("V10", DataKind.R4, 10 ),
           new TextLoader.Column("V11", DataKind.R4, 11 ),
           new TextLoader.Column("V12", DataKind.R4, 12 ),
           new TextLoader.Column("V13", DataKind.R4, 13 ),
           new TextLoader.Column("V14", DataKind.R4, 14 ),
           new TextLoader.Column("V15", DataKind.R4, 15 ),
           new TextLoader.Column("V16", DataKind.R4, 16 ),
           new TextLoader.Column("V17", DataKind.R4, 17 ),
           new TextLoader.Column("V18", DataKind.R4, 18 ),
           new TextLoader.Column("V19", DataKind.R4, 19 ),
           new TextLoader.Column("V20", DataKind.R4, 20 ),
           new TextLoader.Column("V21", DataKind.R4, 21 ),
           new TextLoader.Column("V22", DataKind.R4, 22 ),
           new TextLoader.Column("V23", DataKind.R4, 23 ),
           new TextLoader.Column("V24", DataKind.R4, 24 ),
           new TextLoader.Column("V25", DataKind.R4, 25 ),
           new TextLoader.Column("V26", DataKind.R4, 26 ),
           new TextLoader.Column("V27", DataKind.R4, 27 ),
           new TextLoader.Column("V28", DataKind.R4, 28 ),
           new TextLoader.Column("Amount", DataKind.R4, 29 )
       };

   TextLoader.Arguments txtLoaderArgs = new TextLoader.Arguments
                                               {
                                                   Column = columns,
                                                   // First line of the file is a header, not a data row.
                                                   HasHeader = true,
                                                   Separator = ","
                                               };


[...]
    var classification = new BinaryClassificationContext(env);

    (trainData, testData) = classification.TrainTestSplit(data, testFraction: 0.2);

[...]

    //Get all the column names for the Features (All except the Label and the StratificationColumn)
    var featureColumnNames = _trainData.Schema.GetColumns()
        .Select(tuple => tuple.column.Name) // Get the column names
        .Where(name => name != "Label") // Do not include the Label column
        .Where(name => name != "StratificationColumn") //Do not include the StratificationColumn
        .ToArray();

    var pipeline = _mlContext.Transforms.Concatenate("Features", featureColumnNames)
                    .Append(_mlContext.Transforms.Normalize(inputName: "Features", outputName: "FeaturesNormalizedByMeanVar", mode: NormalizerMode.MeanVariance))                       
                    .Append(_mlContext.BinaryClassification.Trainers.FastTree(labelColumn: "Label", 
                                                                              featureColumn: "Features",
                                                                              numLeaves: 20,
                                                                              numTrees: 100,
                                                                              minDatapointsInLeaves: 10,
                                                                              learningRate: 0.2));

2. 訓練模型

訓練模型是在訓練資料(具有已知欺詐值)上執行所選演算法以調整模型引數的過程。它是在評估器物件的 Fit() 方法中實現。

為了執行訓練,您需要在DataView物件中提供了訓練資料集(trainData.csv)後呼叫 Fit() 方法。

    var model = pipeline.Fit(_trainData);

3. 評估模型

我們需要這一步驟來判定我們的模型對新資料的準確性。 為此,上一步中的模型再次針對另一個未在訓練中使用的資料集(testData.csv)執行。

Evaluate()比較測試資料集的預測值,並生成各種指標,例如準確性,您可以對其進行瀏覽。

    var metrics = _context.Evaluate(model.Transform(_testData), "Label");

4. 使用模型

訓練完模型後,您可以使用Predict()API來預測交易是否存在欺詐。

[...]

   ITransformer model;
   using (var file = File.OpenRead(_modelfile))
   {
       model = mlContext.Model.Load(file);
   }

   var predictionFunc = model.MakePredictionFunction<TransactionObservation, TransactionFraudPrediction>(mlContext);

[...]

    dataTest.AsEnumerable<TransactionObservation>(mlContext, reuseRowObject: false)
                        .Where(x => x.Label == true)
                        .Take(numberOfTransactions)
                        .Select(testData => testData)
                        .ToList()
                        .ForEach(testData => 
                                    {
                                        Console.WriteLine($"--- Transaction ---");
                                        testData.PrintToConsole();
                                        predictionFunc.Predict(testData).PrintToConsole();
                                        Console.WriteLine($"-------------------");
                                    });
[...]

    dataTest.AsEnumerable<TransactionObservation>(mlContext, reuseRowObject: false)
                        .Where(x => x.Label == false)
                        .Take(numberOfTransactions)
                        .ToList()
                        .ForEach(testData =>
                                    {
                                        Console.WriteLine($"--- Transaction ---");
                                        testData.PrintToConsole();
                                        predictionFunc.Predict(testData).PrintToConsole();
                                        Console.WriteLine($"-------------------");
                                    });