python实现决策树ID3算法的示例代码
在周志华的西瓜书和李航的统计机器学习中对决策树ID3算法都有很详细的解释,如何实现呢?核心点有如下几个步骤
step1:计算香农熵
frommathimportlog
importoperator
#计算香农熵
defcalculate_entropy(data):
label_counts={}
forfeature_dataindata:
laber=feature_data[-1]#最后一行是laber
iflabernotinlabel_counts.keys():
label_counts[laber]=0
label_counts[laber]+=1
count=len(data)
entropy=0.0
forkeyinlabel_counts:
prob=float(label_counts[key])/count
entropy-=prob*log(prob,2)
returnentropy
step2.计算某个feature的信息增益的方法
#计算某个feature的信息增益 #index:要计算信息增益的feature对应的在data的第几列 #data的香农熵 defcalculate_relative_entropy(data,index,entropy): feat_list=[number[index]fornumberindata]#得到某个特征下所有值(某列) uniqual_vals=set(feat_list) new_entropy=0 forvalueinuniqual_vals: sub_data=split_data(data,index,value) prob=len(sub_data)/float(len(data)) new_entropy+=prob*calculate_entropy(sub_data)#对各子集香农熵求和 relative_entropy=entropy-new_entropy#计算信息增益 returnrelative_entropy
step3.选择最大信息增益的feature
#选择最大信息增益的feature defchoose_max_relative_entropy(data): num_feature=len(data[0])-1 base_entropy=calculate_entropy(data)#香农熵 best_infor_gain=0 best_feature=-1 foriinrange(num_feature): info_gain=calculate_relative_entropy(data,i,base_entropy) #最大信息增益 if(info_gain>best_infor_gain): best_infor_gain=info_gain best_feature=i returnbest_feature
step4.构建决策树
defcreate_decision_tree(data,labels):
class_list=[example[-1]forexampleindata]
#类别相同,停止划分
ifclass_list.count(class_list[-1])==len(class_list):
returnclass_list[-1]
#判断是否遍历完所有的特征时返回个数最多的类别
iflen(data[0])==1:
returnmost_class(class_list)
#按照信息增益最高选取分类特征属性
best_feat=choose_max_relative_entropy(data)
best_feat_lable=labels[best_feat]#该特征的label
decision_tree={best_feat_lable:{}}#构建树的字典
del(labels[best_feat])#从labels的list中删除该label
feat_values=[example[best_feat]forexampleindata]
unique_values=set(feat_values)
forvalueinunique_values:
sub_lables=labels[:]
#构建数据的子集合,并进行递归
decision_tree[best_feat_lable][value]=create_decision_tree(split_data(data,best_feat,value),sub_lables)
returndecision_tree
在构建决策树的过程中会用到两个工具方法:
#当遍历完所有的特征时返回个数最多的类别
defmost_class(classList):
class_count={}
forvoteinclassList:
ifvotenotinclass_count.keys():class_count[vote]=0
class_count[vote]+=1
sorted_class_count=sorted(class_count.items,key=operator.itemgetter(1),reversed=True)
returnsorted_class_count[0][0]
#工具函数输入三个变量(待划分的数据集,特征,分类值)返回不含划分特征的子集
defsplit_data(data,axis,value):
ret_data=[]
forfeat_vecindata:
iffeat_vec[axis]==value:
reduce_feat_vec=feat_vec[:axis]
reduce_feat_vec.extend(feat_vec[axis+1:])
ret_data.append(reduce_feat_vec)
returnret_data
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。