程式人生 > python實現 隨機森林(RF)的參數尋優

Python 實現隨機森林(RF)的參數尋優

# -*- coding: utf-8 -*-
#RandomForestClassifier
import math
import matplotlib as mpl
import warnings
import numpy as np
from sklearn import tree
from sklearn import ensemble
from sklearn import metrics
from sklearn.metrics import auc 
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from sklearn.cross_validation import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.externals import joblib
import matplotlib.pyplot as plt
plt.switch_backend('agg') 
warnings.filterwarnings('ignore')


# Data loading: read the labelled file and split it into training / validation sets.
def split_data(file_name, N):  # e.g. uniprot_10_1_1_lst.txt
    """Read a labelled comma-separated file and split it into train/validation sets.

    Each non-blank line holds integer feature values followed by a class label
    in the last column. The label 'disorder' is encoded as 1; every other label
    (e.g. 'non-disorder') is encoded as 0.

    NOTE(review): the file-level import ``sklearn.cross_validation`` was removed
    in scikit-learn 0.20; ``train_test_split`` now lives in
    ``sklearn.model_selection``.

    Args:
        file_name: Path of the data file to read.
        N: Passed to ``train_test_split`` as ``test_size`` (see scikit-learn's
            ``train_test_split`` for the accepted values).

    Returns:
        ``(x_train, x_test, y_train, y_test)`` from ``train_test_split`` with
        ``random_state=0``. The ``y`` arrays hold floats (0.0 / 1.0) because
        they start from ``np.zeros``.
    """
    data = []
    labels = []
    with open(file_name, 'r') as ifile:
        for line in ifile:
            line = line.strip()
            # Skip blank lines (e.g. a trailing newline at EOF): they would
            # otherwise add an empty feature row and make the array ragged.
            if not line:
                continue
            tokens = line.split(',')
            data.append([int(tk) for tk in tokens[:-1]])
            # Strip the label so "..., disorder" still matches below
            # (int() already tolerates surrounding spaces in the features).
            labels.append(tokens[-1].strip())

    x = np.array(data)
    labels = np.array(labels)
    y = np.zeros(labels.shape)
    y[labels == 'disorder'] = 1  # labels are 'non-disorder' and 'disorder'
    # Train/validation split; N is the test_size.
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=N, random_state=0)
    return x_train, x_test, y_train, y_test


# Hyper-parameter search over n_estimators and max_features.
def search_rf_params(x_train, y_train, out_path='10-1000-2-20-scores.csv',
                     n_values=range(10, 1000, 10), m_values=range(2, 20)):
    """Grid-search ``n_estimators`` x ``max_features`` for a random forest.

    For every (n, m) pair a ``RandomForestClassifier`` is scored with 5-fold
    cross-validated accuracy on the training data. The mean score is printed
    and written to *out_path* as one ``n,m,score`` line.

    Args:
        x_train: Training feature matrix.
        y_train: Training labels.
        out_path: CSV file that receives the results; it is overwritten.
        n_values: Candidate ``n_estimators`` values. The default ends at 990
            because ``range`` excludes its end point.
        m_values: Candidate ``max_features`` values (default 2..19).
            scikit-learn rejects values larger than the number of features.

    Returns:
        A list of ``(n, m, mean_accuracy)`` tuples in evaluation order.
    """
    results = []
    with open(out_path, 'w') as f_scores:
        for n in n_values:
            for m in m_values:
                clf = RandomForestClassifier(
                    n_estimators=n, max_features=m,
                    max_depth=None, min_samples_split=2, random_state=0)
                mean_acc = cross_val_score(
                    clf, x_train, y_train, cv=5, scoring='accuracy').mean()
                results.append((n, m, mean_acc))
                row = ','.join((str(n), str(m), str(mean_acc)))
                print(row)
                # Written row by row so partial results survive a long run.
                f_scores.write(row + '\n')
    return results


if __name__ == '__main__':
    # The original script used x_train / y_train without ever creating them
    # (NameError). The data is now loaded first, before the output file is
    # truncated.
    # NOTE(review): the file name is taken from the comment on split_data, and
    # the validation fraction was never specified (0.25 mirrors
    # scikit-learn's default) — confirm both.
    DATA_FILE = 'uniprot_10_1_1_lst.txt'
    TEST_SIZE = 0.25
    x_train, x_test, y_train, y_test = split_data(DATA_FILE, TEST_SIZE)
    search_rf_params(x_train, y_train)

1. 首先對原始資料進行資料集劃分(split_data),分別得到訓練集與驗證集(及其標籤)。

2. 建立 RF 模型,透過網格搜尋(grid search),以 5 折交叉驗證的準確率(ACC)為標準,對 n_estimators 與 max_features 進行參數尋優。

 

# for example #

資料格式:每行為以逗號分隔的整數特徵值,最後一欄為類別標籤(disorder 或 non-disorder),例如 `特徵1,特徵2,...,特徵k,disorder`。

執行結果如下(每行依序為 n_estimators、max_features 與 5 折交叉驗證平均準確率,同時寫入 10-1000-2-20-scores.csv):