HMM Viterbiアルゴリズム

5548 ワード

久しぶりの更新です。HMM（隠れマルコフモデル）による中国語の単語分割（分詞）のために、Viterbiアルゴリズムを実装しました。
# coding=utf-8
"""
Chinese word segmentation with a character-level hidden Markov model.

Every character is tagged with one of four hidden states:
    B - first character of a multi-character word
    M - inner character of a word with three or more characters
    E - last character of a multi-character word
    S - a single-character word

Model parameters:
    pi - initial state probabilities     (written to PROB_INIT)
    A  - state transition probabilities  (written to PROB_TRANS)
    B  - emission probabilities          (written to PROB_EMIT)

``train_hmm`` estimates the parameters from a whitespace-segmented corpus
and writes them to the files below; ``viterbi_seg`` reads them back and
uses the Viterbi algorithm to find the most likely tag sequence for a
sentence.

The code runs under Python 2.7 and Python 3: files are opened with
``io.open`` so all text is handled as unicode.
"""

import io
import math

TRAIN_CORPUS = 'trainCorpus.txt_utf8'
PROB_INIT = 'prob_init.txt'
PROB_EMIT = 'prob_emit.txt'
PROB_TRANS = 'prob_trans.txt'

# Hidden states of the model.
_STATES = ('B', 'M', 'E', 'S')

# States that close a word; a sentence may only end in one of these.
_FINAL_STATES = ('E', 'S')

# Log-score used for zero-probability events. It is finite (not -inf) and
# moderate (not e.g. -1e100) so that adding ordinary log probabilities to it
# still changes the sum; paths forced through an impossible event therefore
# remain comparable with each other instead of collapsing into exact ties.
_LOG_ZERO = -1e8


def _tags_for_words(words):
    """Return the B/M/E/S tag of every character in *words*, in order."""
    tags = []
    for word in words:
        if len(word) == 1:
            tags.append('S')
        else:
            tags.append('B')
            tags.extend(['M'] * (len(word) - 2))
            tags.append('E')
    return tags


def _ratio(count, total):
    """Return count / total as a float, or 0.0 when *total* is zero."""
    return count / float(total) if total else 0.0


def _write_table(path, counts):
    """Write {state: {key: count}} as "state<TAB>key:prob<TAB>..." lines."""
    with io.open(path, 'w', encoding='utf-8') as out:
        for state in _STATES:
            row = counts[state]
            total = sum(row.values())
            cells = [u'%s:%r' % (key, _ratio(n, total))
                     for key, n in sorted(row.items())]
            out.write(u'\t'.join([state] + cells) + u'\n')


def train_hmm(input_data):
    """Estimate the HMM parameters from a segmented corpus and save them.

    :param input_data: path of a UTF-8 text file (a leading BOM is
        tolerated) with one sentence per line and words separated by
        whitespace.

    Writes PROB_INIT ("state<TAB>prob" per state), PROB_EMIT
    ("state<TAB>char:prob<TAB>...") and PROB_TRANS
    ("state<TAB>next:prob<TAB>..."). A state with no observations gets
    probability 0.0 everywhere instead of a division by zero.
    """
    init_counts = dict.fromkeys(_STATES, 0)
    emit_counts = {state: {} for state in _STATES}
    trans_counts = {state: dict.fromkeys(_STATES, 0) for state in _STATES}

    with io.open(input_data, 'r', encoding='utf-8-sig') as corpus:
        for line in corpus:
            words = line.split()
            if not words:
                continue  # blank line: contributes nothing
            tags = _tags_for_words(words)
            # Characters come from the same split as the tags, so both
            # sequences line up whatever whitespace separates the words.
            chars = u''.join(words)

            # Every non-empty sentence has a first state, including a
            # one-character sentence.
            init_counts[tags[0]] += 1
            for char, tag in zip(chars, tags):
                row = emit_counts[tag]
                row[char] = row.get(char, 0) + 1
            for prev, cur in zip(tags, tags[1:]):
                trans_counts[prev][cur] += 1

    init_total = sum(init_counts.values())
    with io.open(PROB_INIT, 'w', encoding='utf-8') as out:
        for state in _STATES:
            out.write(u'%s\t%r\n'
                      % (state, _ratio(init_counts[state], init_total)))
    _write_table(PROB_EMIT, emit_counts)
    _write_table(PROB_TRANS, trans_counts)


def _read_table(path):
    """Parse a file written by _write_table back into {state: {key: prob}}."""
    table = {}
    with io.open(path, 'r', encoding='utf-8') as src:
        for line in src:
            fields = line.rstrip(u'\r\n').split(u'\t')
            if not fields[0]:
                continue  # blank line
            row = table[fields[0]] = {}
            for cell in fields[1:]:
                # rpartition: the key itself may be ':' (a corpus character).
                key, _, prob = cell.rpartition(u':')
                row[key] = float(prob)
    return table


def _load_model():
    """Return (start_p, trans_p, emit_p) read from the model files."""
    start_p = {}
    with io.open(PROB_INIT, 'r', encoding='utf-8') as src:
        for line in src:
            fields = line.split()
            if not fields:
                continue  # blank line
            state, prob = fields
            start_p[state] = float(prob)
    return start_p, _read_table(PROB_TRANS), _read_table(PROB_EMIT)


def _log(p):
    """Natural log of *p*, or _LOG_ZERO when *p* is not positive."""
    return math.log(p) if p > 0 else _LOG_ZERO


def _viterbi(obs, states, start_p, trans_p, emit_p):
    """Return (log_score, tags) of the most likely state sequence for *obs*.

    Scores are summed in log space so long inputs do not underflow. A
    character absent from every emission row gets the same neutral emission
    score in all states, so only the transitions decide its tag. The path
    must end in a word-closing state. *obs* must be non-empty. Ties are
    broken by the larger state letter.
    """
    known = set()
    for row in emit_p.values():
        known.update(row)

    def emit_score(state, char):
        if char not in known:
            return 0.0
        return _log(emit_p.get(state, {}).get(char, 0.0))

    scores = {y: _log(start_p.get(y, 0.0)) + emit_score(y, obs[0])
              for y in states}
    backpointers = []  # backpointers[t - 1][y]: best predecessor of y at t
    for char in obs[1:]:
        step_scores, step_back = {}, {}
        for y in states:
            best, prev = max(
                (scores[y0] + _log(trans_p.get(y0, {}).get(y, 0.0)), y0)
                for y0 in states)
            step_scores[y] = best + emit_score(y, char)
            step_back[y] = prev
        scores = step_scores
        backpointers.append(step_back)

    finals = [y for y in states if y in _FINAL_STATES] or list(states)
    best, last = max((scores[y], y) for y in finals)
    tags = [last]
    for step_back in reversed(backpointers):
        tags.append(step_back[tags[-1]])
    tags.reverse()
    return best, tags


def viterbi_seg(sentence):
    """Segment *sentence* into words and return them separated by spaces.

    :param sentence: text to segment; a UTF-8 encoded byte string (as
        returned by Python 2's raw_input) is decoded first.
    :returns: unicode text with a space after every word, including the
        last one; an empty input yields an empty string.

    The model is re-read from PROB_INIT, PROB_TRANS and PROB_EMIT on every
    call, so those files (written by train_hmm) must exist.
    """
    if isinstance(sentence, bytes):
        sentence = sentence.decode('utf-8')
    if not sentence:
        return u''
    start_p, trans_p, emit_p = _load_model()
    _, tags = _viterbi(sentence, _STATES, start_p, trans_p, emit_p)
    pieces = []
    for char, tag in zip(sentence, tags):
        pieces.append(char)
        if tag in _FINAL_STATES:  # this character closes a word
            pieces.append(u' ')
    return u''.join(pieces)


def _prompt(text):
    """Read one line from stdin (raw_input on Python 2, input on Python 3)."""
    try:
        read = raw_input  # noqa: F821 - exists only on Python 2
    except NameError:
        read = input
    return read(text)


if __name__ == '__main__':
    # Interactive loop; stop cleanly on end of input or Ctrl-C.
    while True:
        try:
            text = _prompt('input:')
        except (EOFError, KeyboardInterrupt):
            break
        print(viterbi_seg(text))
原文のURL