HMM Viterbiアルゴリズム
5548 ワード
久しぶりに更新しました。Viterbiアルゴリズムを書きました。
# coding=utf-8
"""
HMM-based character-level word segmentation.

Every character is labelled with one of four hidden states:
    B -- first character of a multi-character word
    M -- interior character of a multi-character word
    E -- last character of a multi-character word
    S -- a word consisting of a single character

train_hmm() estimates the model from a whitespace-segmented corpus and
writes its three parameter tables to disk:
    pi -- initial-state distribution     -> PROB_INIT
    A  -- state-transition probabilities -> PROB_TRANS
    B  -- emission probabilities         -> PROB_EMIT
viterbi_seg() reads those tables back and decodes the most likely state
sequence with the Viterbi algorithm, inserting a space after every word.
"""
# Segmented training corpus (UTF-8, words separated by whitespace);
# presumably the intended input to train_hmm().
TRAIN_CORPUS = 'trainCorpus.txt_utf8'
# Model files written by train_hmm() and read by viterbi_seg().
PROB_INIT = 'prob_init.txt'
PROB_EMIT = 'prob_emit.txt'
PROB_TRANS = 'prob_trans.txt'
def train_hmm(input_data, init_path=None, emit_path=None, trans_path=None):
    """Estimate the segmentation HMM from a corpus and write it to disk.

    Every whitespace-separated word is tagged character by character: a
    one-character word is ``S``; a longer word is ``B``, then ``M`` for
    each interior character, then ``E``. The tag counts give

    * the initial-state distribution (pi), from the first tag of each line;
    * the emission probabilities (B), P(character | tag);
    * the transition probabilities (A), P(next tag | tag) within a line.

    Args:
        input_data: path of a UTF-8 text file, one sentence per line, words
            separated by whitespace. Blank lines are ignored.
        init_path, emit_path, trans_path: output files; ``None`` selects
            PROB_INIT, PROB_EMIT and PROB_TRANS respectively.

    Output files are UTF-8, tab-separated, one line per state:
        init:  ``STATE<TAB>PROB``
        emit:  ``STATE<TAB>CHAR:PROB<TAB>CHAR:PROB...``
        trans: ``STATE<TAB>E:PROB<TAB>S:PROB<TAB>B:PROB<TAB>M:PROB``
    A row with no observations gets probability 0.0 everywhere.
    """
    import io

    if init_path is None:
        init_path = PROB_INIT
    if emit_path is None:
        emit_path = PROB_EMIT
    if trans_path is None:
        trans_path = PROB_TRANS

    states = ('S', 'B', 'M', 'E')
    init_dict = dict.fromkeys(states, 0)
    emit_dict = {s: {} for s in states}
    trans_dict = {s: dict.fromkeys(states, 0) for s in states}

    def get_sign(words):
        """Return the BMES tag of every character of *words*, in order."""
        tags = []
        for word in words:
            if len(word) == 1:
                tags.append('S')
            else:
                tags.append('B')
                tags.extend(['M'] * (len(word) - 2))
                tags.append('E')
        return tags

    def ratio(count, total):
        """Return count / total, or 0.0 for a row with no observations."""
        return count / total if total else 0.0

    # io.open yields decoded text on both Python 2 and Python 3.
    with io.open(input_data, 'r', encoding='utf-8') as fi:
        for line in fi:
            words = line.split()
            if not words:
                continue  # blank line: nothing to count
            sign = get_sign(words)
            # Take the characters from the same split that produced the tags
            # so both sequences stay aligned whatever whitespace separates
            # the words.
            chars = [c for word in words for c in word]
            # Every non-empty line starts a word, so this counts S or B;
            # single-character lines are included.
            init_dict[sign[0]] += 1
            for tag, char in zip(sign, chars):
                emit_dict[tag][char] = emit_dict[tag].get(char, 0) + 1
            for prev, cur in zip(sign, sign[1:]):
                trans_dict[prev][cur] += 1

    # Each table gets its own `with` so the file is flushed and closed
    # before anything (e.g. viterbi_seg) reads it back.
    init_total = float(sum(init_dict.values()))
    with io.open(init_path, 'w', encoding='utf-8') as init_wr:
        for state in init_dict:
            prob = ratio(init_dict[state], init_total)
            init_wr.write(u'%s\t%s\n' % (state, str(prob)))

    with io.open(emit_path, 'w', encoding='utf-8') as emit_wr:
        for state in emit_dict:
            counts = emit_dict[state]
            emit_total = float(sum(counts.values()))
            fields = [state] + [u'%s:%s' % (char, str(ratio(n, emit_total)))
                                for char, n in counts.items()]
            emit_wr.write(u'\t'.join(fields) + u'\n')

    with io.open(trans_path, 'w', encoding='utf-8') as trans_wr:
        for state in trans_dict:
            row = trans_dict[state]
            trans_total = float(sum(row.values()))
            fields = [state] + [u'%s:%s' % (nxt, str(ratio(row[nxt], trans_total)))
                                for nxt in ('E', 'S', 'B', 'M')]
            trans_wr.write(u'\t'.join(fields) + u'\n')
def viterbi_seg(sentence, init_path=None, emit_path=None, trans_path=None):
    """Split *sentence* into words with the HMM that train_hmm() wrote.

    The most likely B/M/E/S tag sequence is found with the Viterbi
    algorithm (in log space). A space is emitted after every character
    tagged S or E, i.e. after every word. The last tag is restricted to
    E or S, so a non-empty result always ends with a space.

    Args:
        sentence: text to segment. UTF-8 encoded bytes (such as Python 2's
            raw_input() result) are decoded first; text is used as-is.
        init_path, emit_path, trans_path: model files in the format written
            by train_hmm(); ``None`` selects PROB_INIT, PROB_EMIT and
            PROB_TRANS respectively.

    Returns:
        The segmented text, or '' for an empty sentence.
    """
    import io
    import math

    # Finite stand-in for log(0): it loses to every real path, yet sums of
    # it stay finite, so comparisons never involve -inf.
    min_log = -3.14e100
    states = ('B', 'M', 'E', 'S')

    if init_path is None:
        init_path = PROB_INIT
    if emit_path is None:
        emit_path = PROB_EMIT
    if trans_path is None:
        trans_path = PROB_TRANS

    def read_fields(path):
        """Yield the tab-separated fields of each non-blank line of *path*."""
        with io.open(path, 'r', encoding='utf-8') as fh:
            for raw in fh:
                fields = raw.strip().split('\t')
                if fields[0]:
                    yield fields

    def read_table(path):
        """Parse 'STATE<TAB>KEY:PROB<TAB>...' lines into {STATE: {KEY: PROB}}."""
        table = {}
        for fields in read_fields(path):
            row = table.setdefault(fields[0], {})
            for field in fields[1:]:
                # Split on the LAST colon: an emitted character may be ':'.
                key, _, prob = field.rpartition(':')
                row[key] = float(prob)
        return table

    start_p = dict((f[0], float(f[1])) for f in read_fields(init_path) if len(f) > 1)
    trans_p = read_table(trans_path)
    emit_p = read_table(emit_path)

    def log_p(p):
        return math.log(p) if p > 0 else min_log

    if isinstance(sentence, bytes):
        sentence = sentence.decode('utf-8')
    chars = list(sentence)
    if not chars:
        return ''

    # A character never seen in training carries no evidence about its tag.
    # It adds the same (zero) log-emission to every state, and the
    # transitions decide its tag.
    known = set()
    for row in emit_p.values():
        known.update(row)

    def emit_log(state, char):
        if char not in known:
            return 0.0
        return log_p(emit_p.get(state, {}).get(char, 0))

    # score[y]: best log-probability of a tag path for the prefix seen so
    # far that ends in state y.
    score = dict((y, log_p(start_p.get(y, 0)) + emit_log(y, chars[0]))
                 for y in states)
    backpointers = []  # backpointers[t - 1][y]: best predecessor of y at step t
    for char in chars[1:]:
        new_score = {}
        links = {}
        for y in states:
            best, prev = max(
                (score[y0] + log_p(trans_p.get(y0, {}).get(y, 0)), y0)
                for y0 in states)
            new_score[y] = best + emit_log(y, char)
            links[y] = prev
        score = new_score
        backpointers.append(links)

    # A well-formed tag sequence closes its last word, so it ends in E or S.
    state = max((score[y], y) for y in ('E', 'S'))[1]
    tags = [state]
    for links in reversed(backpointers):
        state = links[state]
        tags.append(state)
    tags.reverse()

    pieces = []
    for char, tag in zip(chars, tags):
        pieces.append(char)
        if tag in ('S', 'E'):
            pieces.append(' ')
    return ''.join(pieces)
if __name__ == '__main__':
    # Python 3 renamed Python 2's raw_input to input; Python 2's own input()
    # evaluates the typed line, so it must not be used there.
    try:
        read_line = raw_input
    except NameError:
        read_line = input
    # Interactive loop: segment every line typed; EOF (Ctrl-D) or Ctrl-C
    # at the prompt exits cleanly instead of printing a traceback.
    while True:
        try:
            a = read_line("input:")
        except (EOFError, KeyboardInterrupt):
            break
        print(viterbi_seg(a))
元記事のURL