Keras で文書分類器を作ってみました
Tensorflow 2.x、Keras 勉強の一環で文書分類器を作ってみました。
今回は文書の読み込み・分割、ネットワークの設計、学習、モデルの保存・復元、予測を一通り実装することを目標にしています。文書分類のロジックは次回以降に探求していく予定です。
なお勉強を兼ねているので、コードで最適になっていない箇所、説明・コードに冗長な箇所がある点はご了承ください。
参照リンク
本記事は下記のTutorial, データセットを参照しています。
- Keras: Text classification from scratch
- Large Movie Review Dataset
動作環境
- Tensorflow 2.2
サンプルコード
本記事のサンプルコードはこちらになります。
本記事の構成
- 文書の読み込み・分割
- ネットワークの設計
- Multiclass classification
- Multi-label classification
- 学習
- モデルの保存・復元
- 予測
文書の読み込み・分割
1文書ずつリストに読み込み、tf.data.Dataset
へと変換します。
ラベルも同様に読み込みます。Tutorial と異なる点として、データセットのラベルはneg(negative), pos(positive)の2値ですが、ここでは3クラス以上の文書分類器への展開を想定してラベルを1次元の0, 1 ベクトルで表現するのでは無く、ラベルごとに次元を割り当てています。(今回のデータセットの場合、2次元のone-hot ベクトルになります)
text_ds = tf.data.Dataset.from_tensor_slices(texts)
label_ds = tf.data.Dataset.from_tensor_slices(labels)
if is_one_hot:
label_ds = label_ds.map(lambda x: tf.one_hot(x, LABEL_NUM))
データセットのtrain フォルダ以下のデータをさらにtraining とvalidation へ分割します。
(Tensorflow API のみで何とかできないかを模索したため冗長になっています)
def split(self, all_dataset, ratio=0.8):
DATASET_SIZE = 0
for _ in all_dataset:
DATASET_SIZE += 1
ds1_size = int(ratio * DATASET_SIZE)
ds2_size = DATASET_SIZE - ds1_size
print(("ds1_size:%d, ds2_size:%d") % (ds1_size, ds2_size));
ds1 = all_dataset.take(ds1_size)
ds2 = all_dataset.skip(ds1_size)
return ds1, ds2
ネットワーク設計
Multiclass classification とMulti-label classification の2つを取り上げます。
前者は1文書1ラベルという制約で尤もらしいラベルを割り当てる手法(ラベルのスコアの総和が1になる)で、後者はそうした制約無しに文書に対してラベルごとの尤もらしさを計算する手法(ラベルのスコアの総和が必ずしも1にならない)になります。文書のジャンル分けを例にすると、1文書1ジャンルという制約を課す場合は前者を用い、文書が複数のジャンルに属することを許容する場合は後者の手法を用います。
【参照リンク】
- Multiclass classification
- Multi-label classification
Multiclass classification
以前は、単語から単語ID への変換は自分で管理する必要がありましたが、tf.keras.layers.experimental.preprocessing.TextVectorization
を使うことで、ネットワークの中に単語から単語ID への変換を埋め込めるようになったようです。
ここではTutorial のネットワーク構成を簡略化したものを用いることにします。
最初にテキストを空白文字で分割し、分割した各文字列を単語と見なして単語を小文字にした後、単語から単語IDへ変換します。次に単語IDから単語のembedding (単語のベクトル表現)へ変換し、dropout を挟みます。最後に単語のembedding のうち各次元の最大値を取り、ラベル数の次元へとsoftmax を取ります。
# from text to word sequence by white space at default
# from word to word id
vectorize_layer = TextVectorization(
standardize=custom_standardization, # to lowercase
max_tokens=max_features,
output_mode="int",
output_sequence_length=sequence_length,
)
vectorize_layer.adapt(text_ds)
inputs = tf.keras.Input(shape=(1,), dtype="string")
indices = vectorize_layer(inputs)
# from word id to word embedding.
# max_features + 1 for OOV
x = layers.Embedding(max_features + 1, embedding_dim)(indices)
x = layers.Dropout(0.5)(x)
# global max pooling
x = layers.GlobalMaxPooling1D()(x)
predictions = layers.Dense(label_num, activation="softmax", name="predictions")(x)
model = tf.keras.Model(inputs, predictions)
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
Multi-label classification
tf.keras.layers.experimental.preprocessing.TextVectorization
を経てテキストから単語embedding へ変換、単語embedding のうち各次元の最大値を取るところまではmulticlass classification と同じです。
Multi-label classification の場合、それぞれのクラスごとに0 or 1に近づけるように学習したいので、activation
にsoftmax
では無くsigmoid
を使用します。また、loss
にbinary_crossentropy
を使用します。
predictions = layers.Dense(label_num, activation="sigmoid", name="predictions")(x)
model = tf.keras.Model(inputs, predictions)
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
学習
以降はMulticlass classification の場合を例に進めていきます。
fit
で学習します。
学習後にget_weights
するとモデルの中身を一部見ることができます。
数値だけでなく、文字列が含まれている点が良い感じです。
model.fit(train_ds, validation_data=val_ds, epochs=epochs)
model.evaluate(test_ds)
print(model.get_weights())
【出力】
...
782/782 [==============================] - 2s 3ms/step - loss: 0.3570 - accuracy: 0.8544
[array([b'there).', b'this).', b'thought.<br', ..., b'actors',
b'similarity', b'"b"'], dtype=object), array([20000, 19998, 19997, ..., 184, 8740, 7005]), array([[-0.00491479, -0.04982085, -0.02650738, ..., -0.00917612,
-0.04212022, -0.01764694],
...
モデルの保存・復元
ここは少しはまりました。
具体的には復元したモデルでpredict
しても精度が出ないという現象に遭遇しました。最初はtf.keras.layers.experimental.preprocessing.TextVectorization
のpackage がexperimental なので保存できないのかな…と思い、いろいろ調べると復元に失敗しているようでした。custom_objects
を指定しないと、TextVectorization
層が復元されないようです(モデルを復元した後に、get_weights()
することでモデルの中身を簡易に確認できます)
【参照リンク】
- Incompatible shapes when loading model with TextVectorization and Embedding + Conv1D #38250
単語と単語ID の対応情報はモデルに保存されるので自分で管理する必要は無くなりましたが、
ラベルとラベルIDの対応情報は自分で管理する必要があるので、モデルと同様に保存・復元します(詳細略)。
def save(self, model, model_path, label_map):
model.save(model_path, save_format="tf")
# save label
...
def load(self, model_path):
model = tf.keras.models.load_model(model_path,
custom_objects={'TextVectorization':TextVectorization, 'custom_standardization':custom_standardization})
# load label
...
return model, label_map
予測
predict
にtf.data.Dataset
を渡すことで、予測結果のラベルID がone-hot ベクトルで返ってきます。
念の為、ここで返ってきた結果を正解データと比較して正答率を計算してみると、evaluate
で返される正答率と同じであることが確認できました。
test_ds_wo_label_ds = test_ds.map(lambda x, y: x)
result = model.predict(test_ds_wo_label_ds)
result = result.argmax(axis=1)
ret = [label_map_inv[elem] for elem in result]
exp_ds = test_ds.map(lambda x, y: y)
exp_np = exp_ds.as_numpy_iterator()
corr = 0
for act, exp in zip(ret, exp_np):
if act == label_map_inv[exp[0]]:
corr += 1
print("accuracy:%f" % (corr * 100 / len(result)))
【出力】
accuracy:85.436000
データセットのテストデータ以外に適当なサンプルデータを与えてみたところ、それっぽく分類されました。
【出力】
The story is a good comedy.
neg:0.340374
pos:0.659626
It is a silly story.
neg:0.610098
pos:0.389902
This isn't a very exciting film, but it's warm.
neg:0.181591
pos:0.818409
おわりに
今回は文書分類器の作成するために、文書の読み込み・分割、ネットワーク設計、学習、モデルの保存・復元、予測、を一通り実装しました。
ネットワーク設計や学習パラメータの調整などは改善の余地があります。今後、BERT やTransformer など最新の手法を試してみたいと思います。
もう一つ改善の余地として、文書分類の際の前処理(テキストを分割する単位、各単位を正規化する処理など)があります。今回は英語の文書ということもあり、Tutorial の手法をそのまま用いて、空白文字で区切る、小文字にする、といった比較的単純な処理を使用しました。一方、日本語文書になると空白文字区切りはできないのと、仮に英語文書の場合でも、空白文字で区切るだけでは十分では無い場合があります。今後、前処理をいろいろ変えて試してみようと思います。
Author And Source
この問題について(Keras で文書分類器を作ってみました), 我々は、より多くの情報をここで見つけました https://qiita.com/theken/items/46d913d55b93c25be453著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .