Pythonで決定木(デシジョンツリー)アルゴリズムを実装する方法を解説します。


データの説明
各データ項目(1件のレコード)は1つのリストとして保持し、その最後の要素に分類結果を格納します。
複数のデータ項目をまとめたリストがデータセットになります。

data=[[d1,d2,d3...dn,result],
   [d1,d2,d3...dn,result],
        .
        .
   [d1,d2,d3...dn,result]]
決定木のデータ構造

class DecisionNode:
  '''One node of a binary decision tree.

  An internal node carries the test (col, value) and its two children;
  a leaf carries only results.
  '''

  def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):
    '''Store the fields of the node exactly as given.

    args:
    col -- index of the column tested at this node (-1 by default)
    value -- value the column is compared against
    results -- for a leaf, a dict {result: count}; None otherwise
    tb,fb -- child nodes followed when the test is true / false
    '''
    self.col,self.value=col,value
    self.results=results
    self.tb,self.fb=tb,fb
決定木による分類の最終結果は、データ項目をいくつかのサブセットに分けたものであり、各サブセット内の結果は同じになります。そのため、ここでは {結果: 結果の出現回数} という辞書の形で各サブセットを表現します。

def pideset(rows,column,value):
  '''Split rows into two lists by testing one column against value.

  When value is an int or float the test is row[column]>=value; for any
  other value it is row[column]==value.
  Returns (rows passing the test, rows failing it), each in input order.
  '''
  # The kind of comparison depends only on value, so decide it once.
  numeric=isinstance(value,(int,float))

  def passes(row):
    if numeric:
      return row[column]>=value
    return row[column]==value

  matched=[row for row in rows if passes(row)]
  unmatched=[row for row in rows if not passes(row)]
  return (matched,unmatched)
 
def uniquecounts(rows):
  '''Count how often each result (the last element of a row) occurs.

  Returns a dict {result: count}, keyed in order of first appearance.
  '''
  tally={}
  for row in rows:
    label=row[len(row)-1]
    tally[label]=tally.get(label,0)+1
  return tally
 
def giniimpurity(rows):
  '''Return the Gini impurity of the results found in rows.

  This is the sum of p(a)*p(b) over every ordered pair of distinct
  results a, b, where p is the fraction of rows carrying that result.
  Rows with no results at all give 0.
  '''
  total=len(rows)
  # Turn the counts into probabilities once instead of inside the loops.
  probs=[(label,float(n)/total) for label,n in uniquecounts(rows).items()]
  imp=0
  for label_a,p_a in probs:
    for label_b,p_b in probs:
      if label_a!=label_b:
        imp+=p_a*p_b
  return imp
 
def entropy(rows):
  '''Return the entropy, in bits, of the results found in rows.

  Computed as the sum of -p*log2(p) over the distinct results, where p is
  the fraction of rows carrying that result. Empty rows give 0.0.
  '''
  from math import log
  ent=0.0
  for count in uniquecounts(rows).values():
    p=float(count)/len(rows)
    # log(p)/log(2) is the base-2 logarithm of p.
    ent-=p*(log(p)/log(2))
  return ent
 
def build_tree(rows,scoref=entropy):
  '''Recursively build a decision tree from rows.

  args:
  rows -- list of rows; the last element of each row is its result
  scoref -- impurity function applied to a list of rows (default entropy)

  Returns a DecisionNode. When no split lowers the score the node is a
  leaf holding uniquecounts(rows); otherwise it is an internal node whose
  tb / fb subtrees are built from the two halves of the best split.
  Empty rows give a bare DecisionNode().
  '''
  if len(rows)==0: return DecisionNode()
  current_score=scoref(rows)

  # Best split found so far: its gain, its (column, value) test and the
  # two row sets it produces.
  best_gain=0.0
  best_criteria=None
  best_sets=None

  # The last column is the result, so it is never used as a split column.
  column_count=len(rows[0])-1
  for col in range(column_count):
    # dict.fromkeys removes duplicate values while keeping first-seen
    # order, so ties between equally good splits are broken the same way
    # on every run.
    for value in dict.fromkeys(row[col] for row in rows):
      set1,set2=pideset(rows,col,value)
      # A split with an empty side separates nothing; reject it before
      # scoring so scoref is never handed an empty set.
      if not set1 or not set2: continue
      p=float(len(set1))/len(rows)
      # Information gain: score before the split minus the size-weighted
      # score of the two halves.
      gain=current_score-p*scoref(set1)-(1-p)*scoref(set2)
      if gain>best_gain:
        best_gain=gain
        best_criteria=(col,value)
        best_sets=(set1,set2)

  if best_gain>0:
    # Hand scoref down so the whole tree is grown with the caller's
    # metric rather than silently falling back to the default.
    trueBranch=build_tree(best_sets[0],scoref)
    falseBranch=build_tree(best_sets[1],scoref)
    return DecisionNode(col=best_criteria[0],value=best_criteria[1],
            tb=trueBranch,fb=falseBranch)
  # No split improves the score: this node becomes a leaf.
  return DecisionNode(results=uniquecounts(rows))
 
def print_tree(tree,indent=''):
  '''Print tree as indented text on standard output.

  A leaf prints its results dict. Any other node prints "col:value? "
  and then its true branch (prefixed "T->") and its false branch
  (prefixed "F->"), each with the indent one space deeper.
  '''
  if tree.results is not None:
    print(str(tree.results))
    return
  print(str(tree.col)+':'+str(tree.value)+'? ')
  for prefix,child in (('T->',tree.tb),('F->',tree.fb)):
    # end='' keeps the child's first line on the same line as the prefix.
    print(indent+prefix,end='')
    print_tree(child,indent+' ')
 
 
def getwidth(tree):
  '''Return the number of leaves under tree; a child-less node counts as 1.'''
  childless=tree.tb is None and tree.fb is None
  if childless:
    return 1
  return sum(getwidth(child) for child in (tree.tb,tree.fb))
 
def getdepth(tree):
  '''Return the length of the longest path from tree down to a child-less
  node; a child-less node itself has depth 0.
  '''
  childless=tree.tb is None and tree.fb is None
  if childless:
    return 0
  return 1+max(getdepth(tree.tb),getdepth(tree.fb))
 
 
def drawtree(tree,jpeg='tree.jpg'):
  '''Render tree as an image and save it as a JPEG file.

  args:
  tree -- root node of the tree to draw
  jpeg -- path of the output file

  The canvas is white, 100 units wide per getwidth(tree) and 100 units
  tall per getdepth(tree) plus 120. Uses the names Image and ImageDraw,
  which are not imported in this part of the file (presumably PIL --
  confirm the import).
  '''
  size=(getwidth(tree)*100,getdepth(tree)*100+120)
  canvas=Image.new('RGB',size,(255,255,255))
  # The root is drawn horizontally centred, 20 units below the top edge.
  drawnode(ImageDraw.Draw(canvas),tree,size[0]/2,20)
  canvas.save(jpeg,'JPEG')
 
def drawnode(draw,tree,x,y):
  '''Draw tree onto draw, with the node itself placed at (x,y).

  args:
  draw -- drawing surface providing text(xy,text,fill) and
          line(xy,fill=...)
  tree -- node to draw
  x,y -- position of this node

  An internal node is labelled "col:value" and joined by red lines to
  its children 100 units further down, with the false branch on the left
  and the true branch on the right; each child is given 100 units of
  width per leaf beneath it. A leaf is drawn as one "result:count" line
  per entry of its results.
  '''
  if tree.results is None:
    # Width needed by each branch
    w1=getwidth(tree.fb)*100
    w2=getwidth(tree.tb)*100

    # Horizontal extent occupied by this node and everything below it
    left=x-(w1+w2)/2
    right=x+(w1+w2)/2

    # Condition label
    draw.text((x-20,y-10),str(tree.col)+':'+str(tree.value),(0,0,0))

    # Links to the two branches
    draw.line((x,y,left+w1/2,y+100),fill=(255,0,0))
    draw.line((x,y,right-w2/2,y+100),fill=(255,0,0))

    # The branches themselves
    drawnode(draw,tree.fb,left+w1/2,y+100)
    drawnode(draw,tree.tb,right-w2/2,y+100)
  else:
    # One "result:count" entry per line; the separator literal and the
    # draw.text call were previously fused into an unparsable statement.
    txt=' \n'.join(['%s:%d'%v for v in tree.results.items()])
    draw.text((x-20,y),txt,(0,0,0))
欠損値を含むテストデータを分類する

def mdclassify(observation,tree):
  '''Classify an observation that may contain missing (None) values.

  args:
  observation -- list of column values; None marks a missing value
  tree -- node of the decision tree to classify with

  Returns a dict {result: weight}. At a leaf this is the leaf's own
  results dict. When the column tested by a node is missing, both
  branches are followed and their results are combined, each scaled by
  that branch's share of the combined count.
  '''
  # A leaf answers directly.
  if tree.results is not None:
    return tree.results

  v=observation[tree.col]

  if v is None:
    # Missing value: ask both subtrees.
    tr=mdclassify(observation,tree.tb)
    fr=mdclassify(observation,tree.fb)

    # Weight of each side = its share of all counted results.
    tcount=sum(tr.values())
    fcount=sum(fr.values())
    total=tcount+fcount
    tw=float(tcount)/total
    fw=float(fcount)/total

    combined={key:count*tw for key,count in tr.items()}
    for key,count in fr.items():
      # Results seen only on the false side start from 0.
      combined[key]=combined.get(key,0)+count*fw
    return combined

  # Value present: follow one branch, using >= for int/float values and
  # == for everything else.
  if isinstance(v,(int,float)):
    goes_true=v>=tree.value
  else:
    goes_true=v==tree.value
  return mdclassify(observation,tree.tb if goes_true else tree.fb)
 
# Demo: build a tree from my_data (not defined in this part of the file --
# it must be supplied elsewhere) and classify two observations in which
# None marks a missing column value.
tree=build_tree(my_data)
print(mdclassify(['google',None,'yes',None],tree))
print(mdclassify(['google','France',None,None],tree))
決定木の剪定

def _counts_entropy(counts):
  '''Entropy in bits of a {result: count} dict.

  Uses the same formula as entropy(), applied to the counts directly so
  no dummy rows have to be built. Entries whose count is not positive
  are ignored.
  '''
  from math import log
  positive=[c for c in counts.values() if c>0]
  total=sum(positive)
  ent=0.0
  for c in positive:
    p=float(c)/total
    ent=ent-p*(log(p)/log(2))
  return ent

def prune(tree,mingain):
  '''Merge pairs of sibling leaves whose split gains less than mingain.

  args:
  tree -- node to prune; it is modified in place
  mingain -- minimum entropy reduction a split must achieve to be kept

  Works bottom-up: subtrees are pruned first, then, if both children are
  leaves, the split is undone when it reduces entropy by less than
  mingain. Undoing a split turns this node into a leaf whose results are
  the summed counts of both children. Returns None.
  '''
  # A leaf, or a node lacking a child, has nothing to merge.
  if tree.results is not None or tree.tb is None or tree.fb is None:
    return

  # Prune below first so that newly formed leaves can be merged here.
  if tree.tb.results is None:
    prune(tree.tb,mingain)
  if tree.fb.results is None:
    prune(tree.fb,mingain)

  if tree.tb.results is not None and tree.fb.results is not None:
    # Summed counts of both leaves, true-branch results first.
    merged={}
    for child in (tree.tb,tree.fb):
      for v,c in child.results.items():
        if c>0:
          merged[v]=merged.get(v,0)+c

    # Entropy of the merged leaf minus the mean entropy of the two
    # leaves. The parentheses matter: without them only the false
    # branch's entropy was halved.
    mean_split=(_counts_entropy(tree.tb.results)
                +_counts_entropy(tree.fb.results))/2
    delta=_counts_entropy(merged)-mean_split

    if delta<mingain:
      tree.tb,tree.fb=None,None
      tree.results=merged