うわっ…私の言語モデル、古すぎ…?

こんにちは、AIチームの戸田です
今回は去年Google検索に導入されたことでも話題になったBERTを使った比較実験の記事を書かせていただきます

というのも昨年発表報告を書かせていただいた第10回対話シンポジウム、参加して特に印象に残ったことの一つとして、文章をベクトルに変換するモデルとして BERT^1 を使用するのが当たり前になっていたことがあります
私が遅れているだけなのかもしれませんが、とりあえず文章をベクトル化するときはBERTという雰囲気で、Word2Vecで得られた単語ベクトルをコネコネ…とやっているのは(おそらく)今回の会議では私達だけだったと思います

BERTはファインチューニングにより自然言語処理の多くのタスクでState of the artを達成しましたが、単純な文書ベクトル抽出器としての能力はどうなんでしょうか?
私は手軽に文章の分散表現を得る方法としてWord2Vecから得られた単語ベクトルの平均やmax poolingをとる SWEM^2をよく使うのですが、語順が入れ替わった文章やノイズのある文章などでは、なかなか思うようなベクトルが得られないことが多々あります

本記事ではSWEMで得られたベクトルとBERTで得られたベクトルを比較し、SWEMでの課題をBERTが解決してくれるかを検証したいと思います

Word2VecやSWEM、BERTについての説明は本記事では扱いませんのでご容赦下さい

SWEMとBERTのベクトル比較

比較する文章はAI Shiftが提供しているチャットボットプロダクトAI Messengerを導入しているとあるサイトのユーザー質問から、現在SWEMを使ったときに苦戦している語順が入れ替わった文章ノイズが含まれる文章を抽出して使用します(記事に載せる都合上、一部を変更しています)

SWEMのベクトル化手法はconcathier(window size=2) を利用し、各文章ベクトルをPCAで二次元平面上に射影して比較します

語順が入れ替わったもの

SWEMでは基本的に語順が入れ替わった文章を区別できません
下記文章で比較してみいましょう

ユーザー質問
0羽田から那覇に行きたい
1那覇から羽田に行きたい
2羽田から福岡に行きたい
3福岡から羽田に行きたい
4福岡から那覇に行きたい
5那覇から福岡に行きたい
語順: SWEM-concat.png (15.3 kB)

ただ平均やmax poolingをとるだけでは、出発地と目的地が入れ替わっていたとしても同じベクトルになってしまいます

語順: SWEM-hider_2.png (19.2 kB)

hierはn-gramのように窓をとり、平均をとった結果に対してmax poolingする方法なので、ある程度は語順を考慮できるのですが、目的地や出発地でまとまらず、あまりうまく行きません

語順: BERT.png (15.1 kB)

「〇〇に行きたい」という目的地ごとに近いベクトルが抽出されています
SWEMのベクトルを作るword2vecのベクトルの成分は-1.0〜+1.0に正規化されているのに対して、BERTはされていないので縮尺が違いますが、互いの距離関係を見ると、かなりうまく分かれていると思います(自分でも試していて驚きました)

ノイズのある文章

「〇〇を☓☓したい」のような最小限の文章ならば問題ないのですが、チャットボットには「〇〇を☓☓したいんだけどどうすればいいの?」といったノイズのある文章が度々入力されます
こういったノイズは文意を捉えづらくしてしまいます

下記文章で比較します

ユーザー質問
6チケットを郵送してもらいたい
7チケットは郵送できませんか?
8行かなくなったから返金してもらいたい
9返金はできないんですか?
10キャンセルしたいんだけどどうすればいいの?
11キャンセルさせてください
ノイズ: SWEM-concat.png (19.9 kB)
ノイズ: SWEM-hider_2.png (21.6 kB)

比較文章は、チケットの郵送返金キャンセル、といった3つのジャンルがあるのですが、SWEMはconcatもhierどちらのベクトルもまとまっていません

ノイズ: BERT.png (19.5 kB)

BERTもぱっと見すべての文章が離れてしまっていたので、うまくいかなかったのかと思ったのですが、よく見ると

図1.png (78.9 kB)

各ジャンルが近い高さにあり「〜〜ですか?」という疑問形の文章と「〜〜してほしい」という願望系の文章の関係が同じ方向を向いていることがわかります

スクリーンショット 2019-12-18 16.33.47.png (145.7 kB)

Word2Vecの生みの親、Tomas Mikolov氏の論文^3 に出てくる国と首都の単語ベクトルの関係に似ていて面白いです

まとめ

本記事では語順が入れ替わった文章ノイズが含まれる文章をSWEMとBERTでベクトル変換し、PCAで二次元平面上に射影して、それぞれのベクトルを比較しました
結果、BERTは現在私が使っているSWEMより良いベクトルを抽出できるように見え、私の最初の疑問である、文章ベクトル抽出器としてのBERTは、非常に優れたものだと考えられます
とはいえ、明らかに比較している文章が少ないですし、そもそも文書分類などのタスクに、(fine-tuningするのではなく)文章ベクトル抽出器として応用したときにどうなるのか、といったことを今後検証したいと思います

自然言語処理エンジニアとしてまだまだ知識不足なので、なにか間違いがございましたらtwitter等で指摘していただけると嬉しいです

最後までご覧いただきありがとうございました

実験で使ったコード

import numpy as np
import pandas as pd

import torch
from transformers import (
    BertModel, 
    BertConfig, 
    BertTokenizer, 
    BertForPreTraining, 
    BertConfig
)
from chainer import functions as F

from pyknp import Jumanpp
from gensim import models
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt

jm = Jumanpp()

WORD2VEC_ROOT = # Word2Vecのモデルパス
BERT_ROOT = # BERTのモデルパス
TEXT_PATH = # 評価データ

def wakati_jm(text):
    result = jm.analysis(text)
    tokenized_text =[mrph.midasi for mrph in result.mrph_list()]
    return tokenized_text

class SWEM():

    def __init__(self, vec_dict):
        self.vec_dict = vec_dict

    def __call__(self, text, mode):
        vecs = self.w2v(text)
        if mode == "aver":
            return vecs.mean(axis=0)
        elif mode == "max":
            return vecs.max(axis=0)
        elif mode == "concat":
            return np.hstack([vecs.mean(axis=0), vecs.max(axis=0)])
        elif mode == "hier_2":
            return self.hier(vecs, 2)
        elif mode == "hier_3":
            return self.hier(vecs, 3)

    def w2v(self, text):
        sep_text = wakati_jm(text)
        v = []
        for w in sep_text:
            try:
                v.append(self.vec_dict[w])
            except KeyError:
                v.append(np.zeros(250))
        return np.array(v)

    def hier(self, vecs, window):
        h, w = vecs.shape
        if h < window:
            return vecs.max(axis=0)
        v = F.average_pooling_1d(vecs.reshape(1, w, h), ksize=window).data
        return v.max(axis=2)[0]

class BertVectorizer():

    def __init__(self, model, tokenizer):
        self.model = model.eval()
        self.tokenizer = tokenizer

    def __call__(self, text):
        tokenized_text = wakati_jm(text)

        ids = tokenizer.convert_tokens_to_ids(tokenized_text)
        ids = torch.tensor(ids).reshape(1,-1)

        with torch.no_grad():
            vec = model.bert(ids)[0][0].max(0)[0]

        return vec.numpy()

def plot_pca(vecs, label, title):
    X_reduced = PCA(n_components=2, random_state=0).fit_transform(vecs)
    plt.scatter(X_reduced[:, 0], X_reduced[:, 1])
    plt.grid()
    plt.title(title)

    for label, x, y in zip(label, X_reduced[:, 0], X_reduced[:, 1]):
        plt.annotate(label, xy=(x, y), xytext=(0, 0), textcoords='offset points')
    plt.savefig(f'{title}.png')

if __name__ == "__main__":

    w2v = models.Word2Vec.load(WORD2VEC_ROOT)
    swem = SWEM(w2v.wv)

    config = BertConfig.from_json_file(BERT_ROOT + '/bert_config.json')
    model = BertForPreTraining(config=config)
    model.load_state_dict(torch.load(BERT_ROOT+"/pytorch_model.bin"))
    tokenizer = BertTokenizer(BERT_ROOT+"/vocab.txt")
    bert_vec = BertVectorizer(model, tokenizer)

    text_df = pd.read_csv(TEXT_PATH)
    label = text_df["ユーザー質問"].values

    bert_vecs = np.array(text_df["ユーザー質問"].map(bert_vec).tolist())
    swem_concat_vecs = np.array(text_df["ユーザー質問"].map(lambda x: swem(x, "concat")).tolist())
    swem_hier_2_vecs = np.array(text_df["ユーザー質問"].map(lambda x: swem(x, "hier_2")).tolist())

    plot_pca(bert_vecs[:6], label[:6], "語順: BERT")
    plot_pca(bert_vecs[6:], label[6:], "ノイズ: BERT")

    plot_pca(swem_concat_vecs[:6], label[:6], "語順: SWEM-concat")
    plot_pca(swem_concat_vecs[6:], label[6:], "ノイズ: SWEM-concat")

    plot_pca(swem_hier_2_vecs[:6], label[:6], "語順: SWEM-hider_2")
    plot_pca(swem_hier_2_vecs[6:], label[6:], "ノイズ: SWEM-hider_2")
  • Word2Vecのモデルは日本語Wikipediaを学習したものを使いました
  • BERTは京都大学の黒橋・河原研究室が公開しているモデル^5を使わせていただきまし
  • BERTの文章ベクトル抽出方法ですが、huggingfaceのdocs^6をみたところ、[CLS]トークンの出力より、文章全体の出力の平均かpoolingのほうが文章の特徴を表しているようだったので、文章全体の出力をpoolingをしています