TextCNN の pytorch 実装 (IMDb 感情分析)

いきなりタイトルと話が違うが、DistilBERT で Sentiment Analysis を実装してみた。

これは accuracy=0.9344, f1 score=0.9347 くらいの性能を叩き出す(初期値などの具合で実行するたびに変わる。気になる人はシード指定して)。
きっとすごいんだけど、 これだけ見てるとどのくらいすごいかわからない。
そこでロジスティック回帰・ランダムフォレスト・ナイーブベイズSVM などの深層学習以前の分類モデルに加えて、 TextCNN [Kim2014] でも実装してみた。

TextCNN は畳み込みと max-pooling を組み合わせたテキスト分類器。発表時点では state-of-the-art だったんじゃあないかな?
モデルの詳細については解説ブログが結構あるので略。

枯れたモデルは sklearn で誰でもサクッと書けるが、TextCNN の今ちゃんと動く実装は意外とあんまり無い(tensorflow の古いバージョン用とかならある)。
今どきは BERT 系のモデルが圧倒的な性能を叩き出しており、その実装は、pretrained model の読み込みまで含めても Hugging Face のおかげでアホみたいに簡単なので、今更 TextCNN の出番は求められてないという話もある。
でもまあ、BERT との比較目的のように baseline としての役割ならまだ十分ある気がするので(ホント?)、ちゃんと動く TextCNN の実装を転がしておくのは有意義、ということにして以下に公開しておく。

TextCNN [Kim2014] の提案手法について、感情分析(2値分類)を以下のように実装している。*1

  • IMDb の映画レビューを元にした Large Movie Review Dataset を torchtext 経由で用いている。訓練/テストともに 25000件。
    • DistilBERT による実装では、IMDb のテキストから '<br/>' を除外しているが(accuracyが1ポイント向上)、torchtext だとその前処理をうまく入れられなくて、条件を同じにできていない……
  • term の符号化に static と non-static(符号化を初期値に使い勾配で更新する) を併用する Multi Channel。
    • ただし [Kim2014] は符号化に Word2Vec を提案しているが、この実装では GloVe を用いる(torchtext がサポートしてて楽なため)
  • 畳み込みのフィルタは各ウィンドウサイズ(3,4,5) ごとに 100個と論文では提案されているが、128個でも実験している
    • 初期値による揺れはあるが、たいてい 128個のほうが性能が高い。256個まで増やすと逆に性能は落ちる。
  • optimizer は AdamW を使用。5エポック学習しているが、最初の 1,2 エポック目でテストのスコアは最大となり、後は下がるだけ。
  • tokenizer に spacy の en_core_web_sm を利用。
    • torchtext のデフォルト tokeninzer である string.split (ただの空白区切り)と比べて、語彙数が 1/4(27万→7万) に減り、accuracy が 2ポイント向上する。

前処理で性能を上げるのはちょっとずるいかもしれないが、モデルのポテンシャルを見るという意味で汎用的かつ常識的な範囲内での前処理ということで許してもらおう(ダメならコードちょちょいと直してね)。

ここまでで記事の本題は終わっているが、せっかくなのでモデルの性能を比較しておこう。
比較するモデルは DistilBERT, TextCNN, Logistic Regression, SVM, Random Forests, Naive Bayes の6個。SVM はそのまま食わせたら終わらないので、 先に SVD で 1000次元に縮約している。その他パラメータは適当に良さげなものを選んでいる。
BERT-large などではなく DistilBERT を使ったのは、誰でも使えるリソース(Google Colab の無料ランタイム)で動かせる範囲に収めたかったから。

DistilBERT TextCNN LR SVM RF NB
accuracy 0.9289 0.8965 0.8846 0.8822 0.8501 0.8319
f1 score 0.9284 0.8940 0.8843 0.8793 0.8495 0.8197
precision 0.9351 0.9156 0.8867 0.9016 0.8530 0.8836
recall 0.9218 0.8734 0.8819 0.8581 0.8460 0.7645

当然というか、DistilBERT が頭一つ抜けている。
Browse State-of-the-Art の Sentiment Analysis on IMDb によれば、現在の最高性能(accuracy)は 0.974 とのことで、当然そのレベルには及ばないものの、わずかなコードを書くだけでこのくらいの性能を叩き出すことができるなんて、本当いい時代になったもんだ。

TextCNN は LR たちをちょっと上回るスコアであり、DistilBERT とは3ポイント以上の差がついている。上のページを見ると BERT や LSTM が成す一軍の最後尾にギリギリ滑り込めてはいるというレベル。
baseline としては、LR たちと BERT 系の間を埋めるものが欲しいところだが、そのためにはもうちょっと頑張って欲しいところ。TextCNN に attention を入れる提案をしているっぽい論文もあるので、そういうのを取り込むといいのかな。

*1:tqdm で訓練のプログレスバーを表示しているのだが、torchtext の前処理が tqdm のインスタンスを close せずに break かなにかしているようで、訓練時に tqdm のゴミが表示されるという問題の対処が一番大変だった。しかもまだ完全には抑えられていない……。