sklearn: RandomForest

随机森林RandomForestClassifier

利用多棵树对样本进行训练并预测的一种分类器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

# -*- coding: utf-8 -*-
from sklearn.tree import DecisionTreeClassifier
from matplotlib.pyplot import *
from sklearn.cross_validation import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.externals.joblib import Parallel, delayed
from sklearn.tree import export_graphviz

final = open('c:/test/final.dat' , 'r')
data = [line.strip().split('\t') for line in final]
feature = [[float(x) for x in row[3:]] for row in data]

#特征对应的目标值,feature,target
target = [int(row[0]) for row in data]

#拆分训练集和测试集
feature_train, feature_test, target_train, target_test = train_test_split(feature, target, test_size=0.1, random_state=42)

#初始随机森林对象的参数设置
clf = RandomForestClassifier(n_estimators = 8)

#训练集开始训练
s = clf.fit(feature_train , target_train)
# return self object.

#评估模型准确率
r = clf.score(feature_test , target_test)
# return a float score: Mean accuracy of self.predict(feature_test) wrt. target_test.


print '判定结果:%s' % clf.predict(feature_test[0])
#print clf.predict_proba(feature_test[0])

print '所有的树:%s' % clf.estimators_

print clf.classes_
print clf.n_classes_

print '各feature的重要性:%s' % clf.feature_importances_

print clf.n_outputs_

def _parallel_helper(obj, methodname, *args, **kwargs):
return getattr(obj, methodname)(*args, **kwargs)

all_proba = Parallel(n_jobs=10, verbose=clf.verbose, backend="threading")(
delayed(_parallel_helper)(e, 'predict_proba', feature_test[0]) for e in clf.estimators_)
print '所有树的判定结果:%s' % all_proba

proba = all_proba[0]
for j in range(1, len(all_proba)):
proba += all_proba[j]
proba /= len(clf.estimators_)
print '数的棵树:%s , 判不作弊的树比例:%s' % (clf.n_estimators , proba[0,0])
print '数的棵树:%s , 判作弊的树比例:%s' % (clf.n_estimators , proba[0,1])

#当判作弊的树多余不判作弊的树时,最终结果是判作弊
print '判断结果:%s' % clf.classes_.take(np.argmax(proba, axis=1), axis=0)

#把所有的树都保存到word
for i in xrange(len(clf.estimators_)):
export_graphviz(clf.estimators_[i] , '%d.dot'%i)

来源: http://blog.itpub.net/12199764/viewspace-1572056/