機械学習実戦——練習(素朴ベイズ——迷惑メールフィルタリング)

10674 ワード

シンプルベイズ——スパムフィルタリング


説明:50通のメールがあり、そのうちスパム25通、有用メール25通、スパムカテゴリ1、有用メールカテゴリ0、メールをフィルタする必要があります.実験の考え方:
  • まず、メールの形式を処理し、テキストはベクトル
  • に処理する.
  • その後、メールをランダムにトレーニングセットとテストセットに分け、割合4:1
  • 訓練セットを取り出して訓練し,2つのカテゴリの確率分布
  • を得た.
  • 試験セットは試験を取り出し、確率は計算カテゴリに持ち込み、誤り率
  • を計算する.
    #  , 
    def textParse(bigString):  # input is big string, output is word list
        import re
        listOfTokens = re.split(r'\W*', bigString)
        return [tok.lower() for tok in listOfTokens if len(tok) > 2]
    
    
    #  40 ,10 
    def spamTest():
        docList = []
        classList = []
        fullText = []
        for i in range(1, 26):
            wordList = textParse(open('email/spam/%d.txt' % i).read())
            docList.append(wordList)
            fullText.extend(wordList)
            classList.append(1)
            wordList = textParse(open('email/ham/%d.txt' % i).read())
            docList.append(wordList)
            fullText.extend(wordList)
            classList.append(0)
        vocabList = createVocabList(docList)  # create vocabulary
        trainingSet = list(range(50))
        # print(trainingSet)
        testSet = []  # create test set
        #  10 ,40 
        for i in range(10):
            randIndex = int(random.uniform(0, len(trainingSet)))
            testSet.append(trainingSet[randIndex])
            del (trainingSet[randIndex])
        # print(testSet)
        # print(trainingSet)
        #  40 , 
        trainMat = []
        trainClasses = []
        for docIndex in trainingSet:       # train the classifier (get probs) trainNB0
            trainMat.append(bagOfWords2VecMN(vocabList, docList[docIndex]))
            trainClasses.append(classList[docIndex])
        # print(trainMat)
        # print(trainClasses)
        p0V, p1V, pSpam = trainNB0(array(trainMat), array(trainClasses))
        errorCount = 0
        #  , 
        for docIndex in testSet:            # classify the remaining items
            wordVector = bagOfWords2VecMN(vocabList, docList[docIndex])
            if classifyNB(array(wordVector), p0V, p1V, pSpam) != classList[docIndex]:
                errorCount += 1
                print("classification error", docList[docIndex])
        print('the error rate is: ', float(errorCount) / len(testSet))
        # return vocabList,fullText