基於Kubernetes的機器學習微服務系統設計系列——(六)特徵選擇微服務
阿新 • • 發佈:2018-11-10
內容提要
特徵選擇微服務主要實現如下特徵選擇演算法:Document Frequency(DF)、Information Gain(IG)、Chi-Square Test(χ² 檢驗,CHI)、Mutual Information(MI)、Matrix Projection(MP)。
特徵選擇類圖
特徵選擇類圖如圖所示:
部分實現程式碼
特徵選擇Action類
package com.robin. feature.action;
import com.robin.feature.corpus.CorpusManager;
import com.robin.feature.AbstractFeature;
import com.robin.feature.FeatureFactory;
import com.robin.feature.FeatureFactory.FeatureMethod;
import com.robin.loader.MircoServiceAction;
import com.robin.log.RobinLogger;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.codehaus.jettison.json.JSONArray;
import org.codehaus.jettison.json.JSONException;
import org.codehaus.jettison.json.JSONObject;
/**
* <DT><B>描述:</B></DT>
* <DD>特徵選擇Action類</DD>
*
* 適配Jersey伺服器資源呼叫
*
* @version Version1.0
* @author Robin
* @version <I> Date:2018-04-01</I>
* @author <I> E-mail: [email protected]</I>
*/
public class FeatureSelectAction implements MircoServiceAction {
private static final Logger LOGGER = RobinLogger.getLogger();
/**
* Action狀態碼
*/
public enum StatusCode {
OK,
JSON_ERR,
KIND_ERR,
VERSION_ERR,
TRAIN_SCALE_ERR,
METHOD_ERR,
TEXTS_NULL,
}
/**
* Action狀態內部類
*/
private class ActionStatus {
StatusCode statusCode;
String msg;
}
/**
* 獲取返回錯誤狀態JSONObject
*
* @param actionStatus
* @return JSONObject
*/
private JSONObject getErrorJson(ActionStatus actionStatus) {
JSONObject errJson = new JSONObject();
try {
errJson.put("status", actionStatus.statusCode.toString());
errJson.put("msg", actionStatus.msg);
} catch (JSONException ex) {
LOGGER.log(Level.SEVERE, ex.getMessage());
}
return errJson;
}
/**
* 檢查JSON輸入物件具體項
*
* @param jsonObj
* @param key
* @param valueSet
* @param errStatusCode
* @return ActionStatus
*/
private ActionStatus checkJSONObjectTerm(JSONObject jsonObj,
String key,
HashSet<String> valueSet,
StatusCode errStatusCode) {
ActionStatus actionStatus = new ActionStatus();
try {
if (!jsonObj.isNull(key)) {
String value = jsonObj.getString(key);
if (!valueSet.contains(value)) {
actionStatus.msg = "The value [" + value + "] of " + key + " is error.";
actionStatus.statusCode = errStatusCode;
return actionStatus;
}
} else {
actionStatus.msg = "The input parameter is missing " + key + ".";
actionStatus.statusCode = errStatusCode;
return actionStatus;
}
} catch (JSONException ex) {
LOGGER.log(Level.SEVERE, ex.getMessage());
}
actionStatus.statusCode = StatusCode.OK;
return actionStatus;
}
/**
* 檢查JSON輸入物件
*
* @param jsonObj
* @return ActionStatus
*/
private ActionStatus checkInputJSONObject(JSONObject jsonObj) {
ActionStatus actionStatus = new ActionStatus();
ActionStatus retActionStatus;
HashSet<String> valueSet = new HashSet();
valueSet.add("feature");
retActionStatus = checkJSONObjectTerm(jsonObj, "kind", valueSet, StatusCode.KIND_ERR);
if (!retActionStatus.statusCode.equals(StatusCode.OK)) {
return retActionStatus;
}
valueSet.clear();
valueSet.add("v1");
retActionStatus = checkJSONObjectTerm(jsonObj, "version", valueSet, StatusCode.VERSION_ERR);
if (!retActionStatus.statusCode.equals(StatusCode.OK)) {
return retActionStatus;
}
try {
double trainScale = jsonObj.getJSONObject("metadata").getJSONObject("feature").getDouble("trainScale");
if ((trainScale >= 1.0) || (trainScale <= 0)) {
actionStatus.statusCode = StatusCode.TRAIN_SCALE_ERR;
actionStatus.msg = "The input train_scale [" + trainScale + "] is error.";
return actionStatus;
}
valueSet.clear();
valueSet.add("DF");
valueSet.add("CHI");
valueSet.add("MP");
valueSet.add("IG");
valueSet.add("MI");
JSONArray methods = jsonObj.getJSONObject("metadata").getJSONObject("feature").getJSONArray("method");
for (int i = 0; i < methods.length(); i++) {
String method = methods.getString(i);
if (!valueSet.contains(method)) {
actionStatus.statusCode = StatusCode.METHOD_ERR;
actionStatus.msg = "The input method [" + method + "] is error.";
return actionStatus;
}
}
} catch (JSONException ex) {
LOGGER.log(Level.SEVERE, ex.getMessage());
}
actionStatus.statusCode = StatusCode.OK;
return actionStatus;
}
/**
* 覆蓋抽象類中的具體action方法<BR>
* 實現特徵選擇具體處理事物
*
* @param obj
* @return Object
*/
@Override
public Object action(Object obj) {
ActionStatus actionStatus = new ActionStatus();
ActionStatus retActionStatus;
if (!(obj instanceof JSONObject)) {
actionStatus.msg = "The action arguments is not JSONObject.";
LOGGER.log(Level.SEVERE, actionStatus.msg);
actionStatus.statusCode = StatusCode.JSON_ERR;
return this.getErrorJson(actionStatus);
}
JSONObject corpusJson = (JSONObject) obj;
retActionStatus = this.checkInputJSONObject(corpusJson);
if (!retActionStatus.statusCode.equals(StatusCode.OK)) {
LOGGER.log(Level.SEVERE, retActionStatus.msg);
return this.getErrorJson(retActionStatus);
}
try {
long beginTime = System.currentTimeMillis();
JSONObject texts = corpusJson.getJSONObject("texts");
if (null == texts) {
actionStatus.statusCode = StatusCode.TEXTS_NULL;
actionStatus.msg = "The input texts is null.";
LOGGER.log(Level.SEVERE, actionStatus.msg);
return this.getErrorJson(actionStatus);
}
//生成訓練集和測試集
CorpusManager.divide(corpusJson);
JSONObject testSetJson = (JSONObject) corpusJson.remove("testSet");
JSONObject trainSetJson = (JSONObject) corpusJson.remove("trainSet");
JSONObject metadataFeatureJson = corpusJson.getJSONObject("metadata").getJSONObject("feature");
Boolean globalFeature = metadataFeatureJson.getBoolean("globalFeature");
int globalDimension = metadataFeatureJson.getInt("globalDimension");
Boolean localFeature = metadataFeatureJson.getBoolean("localFeature");
int localDimension = metadataFeatureJson.getInt("localDimension");
JSONObject featureSelectJson = new JSONObject();
JSONObject globalFeatureJson = new JSONObject();
JSONObject localFeatureJson = new JSONObject();
//特徵選擇
JSONArray methodArr = metadataFeatureJson.getJSONArray("method");
for (int i = 0; i < methodArr.length(); i++) {
String selectMethod = methodArr.getString(i);
AbstractFeature selecter = FeatureFactory.creatInstance(trainSetJson, FeatureMethod.valueOf(selectMethod));
if (true == globalFeature) {
List<Map.Entry<Integer, Double>> featureList = selecter.selectGlobalFeature(globalDimension);
JSONArray featureArr = new JSONArray();
featureList.forEach((entry) -> {
featureArr.put(entry.getKey());
});
globalFeatureJson.put(selectMethod, featureArr);
}
if (true == localFeature) {
Map<String, List<Map.Entry<Integer, Double>>> labelsMap = selecter.selectLocalFeature(localDimension);
JSONObject labelFeatureJson = new JSONObject();
Iterator<String> labelsIt = labelsMap.keySet().iterator();
while (labelsIt.hasNext()) {
String label = labelsIt.next();
JSONArray labelFeatureArr = new JSONArray();
List<Map.Entry<Integer, Double>> localFeatureList = labelsMap.get(label);
localFeatureList.forEach((entry) -> {
labelFeatureArr.put(entry.getKey());
});
labelFeatureJson.put(label, labelFeatureArr);
}
localFeatureJson.put(selectMethod, labelFeatureJson);
}
}
featureSelectJson.put("globalFeature", globalFeatureJson);
featureSelectJson.put("localFeature", localFeatureJson);
corpusJson.put("featureSelect", featureSelectJson);
corpusJson.put("trainSet", trainSetJson);
corpusJson.put("testSet", testSetJson);
JSONObject preMetadataJson = corpusJson.getJSONObject("metadata").getJSONObject("feature");
long endTime = System.currentTimeMillis();
int spendTime = (int) (endTime - beginTime);
preMetadataJson.put("spendTime", spendTime);
} catch (JSONException ex) {
LOGGER.log(Level.SEVERE, ex.getMessage());
}
JSONObject rsp = new JSONObject();
try {
rsp.put("status", "OK");
rsp.put("result", corpusJson);
} catch (JSONException ex) {
LOGGER.log(Level.SEVERE, ex.getMessage());
}
return rsp;
}
}
特徵選擇抽象類
package com.robin.feature;
import com.robin.container.MapSort;
import com.robin.feature.corpus.CorpusManager;
import com.robin.log.RobinLogger;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.codehaus.jettison.json.JSONObject;
/**
* <DT><B>描述:</B></DT>
* <DD>特徵選擇抽象類</DD>
*
* @version Version1.0
* @author Robin
* @version <I> Date:2018-04-05</I>
* @author <I> E-mail:[email protected]</I>
*/
public abstract class AbstractFeature {
/**
 * Logger for this class and its subclasses.
 */
protected static final Logger LOGGER = RobinLogger.getLogger();
/**
 * Set of all distinct terms in the training set (may be null; getAllTermTotal guards for it).
 */
protected Set<Integer> globalTermsSet;
/**
 * Term document frequency of every training class:
 * {@code <class label, <term code, document frequency>>}.
 */
protected HashMap<String, HashMap<Integer, Integer>> everyClassDFMap;
// Training-set JSON object passed to the constructor.
protected JSONObject trainSetJson;
// Global features -> feature score; null until selectGlobalFeature computes it lazily.
protected HashMap<Integer, Double> globalFeatureValueMap;
// Local (per-class) features: class label -> (feature -> feature score); created empty in the constructor.
protected HashMap<String, HashMap<Integer, Double>> allLocalFeatureValueMap;
/**
 * Creates a feature selector over the given training set.
 *
 * Prepares an empty per-label local-feature map, stores the training set and
 * finally calls initEveryClassDFMap(), which runs after both fields are set.
 * NOTE(review): initEveryClassDFMap is defined outside this view; it presumably
 * fills everyClassDFMap from trainSetJson. If it is overridable, subclasses see
 * it called before their own constructors run. Confirm.
 *
 * @param trainSetJson training-set JSON object
 */
public AbstractFeature(JSONObject trainSetJson) {
    allLocalFeatureValueMap = new HashMap<String, HashMap<Integer, Double>>();
    this.trainSetJson = trainSetJson;
    this.initEveryClassDFMap();
}
/**
 * Returns the number of distinct terms in the training set.
 *
 * @return size of globalTermsSet, or 0 when that set has not been assigned
 */
public int getAllTermTotal() {
    return (globalTermsSet == null) ? 0 : globalTermsSet.size();
}
/**
 * Returns the number of scored global features.
 *
 * globalFeatureValueMap is only filled lazily by selectGlobalFeature, so this
 * returns 0 until global features have been computed.
 *
 * @return size of globalFeatureValueMap, or 0 when it has not been computed yet
 */
public int getGlobalFeatureSize() {
    // Only dereference the map once it exists; before that there are no features.
    if (globalFeatureValueMap != null) {
        return globalFeatureValueMap.size();
    }
    return 0;
}
/**
 * Computes the global feature score of every candidate feature.
 *
 * The result is cached by selectGlobalFeature and sorted there in descending
 * order of score. Keys are presumably term codes; confirm in the subclasses.
 *
 * @return HashMap from feature to its global score
 */
protected abstract HashMap<Integer, Double> computeGlobalFeatureValue();
/**
 * Computes the local (per-class) feature scores for one class.
 *
 * Keys are presumably term codes; confirm in the subclasses.
 *
 * @param label class label
 * @return HashMap from feature to its score for {@code label}
 */
protected abstract HashMap<Integer, Double> computeLocalFeatureValue(String label);
/**
* 全域性選取 dimension 維特徵
*
* @param dimension
* @return List
*/
public List<Map.Entry<Integer, Double>> selectGlobalFeature(int dimension) {
if (null == globalFeatureValueMap) {
// 計算全域性特徵的量化值
globalFeatureValueMap = this.computeGlobalFeatureValue();
}
List<Map.Entry<Integer, Double>> featureList = new MapSort<Integer, Double>().descendSortByValue(globalFeatureValueMap);
for (int i = featureList.size() - 1; dimension