BERTでTOEICの問題を解いてみる①

こんにちわ
AIチームの戸田です

今回の記事でもBERT^1を扱わせていただきます

BERTの事前学習タスクであるMasked LM、こちらは、入力文のトークンをランダムに[MASK]シンボルに置き換え出力でその単語を予測する、という学習ですが、このタスクどこか既視感があると思いませんか?

Input: the man went to the [MASK1] . he bought a [MASK2] of milk.
Labels: [MASK1] = store; [MASK2] = gallon

そう、センター試験やTOEICで出てくる単語穴埋めの問題です

ということで、今回は事前学習済みのBERTでfine-tuningせずにTOEICのPart 5の単語穴埋め問題を解けるか試してみたいと思います

問題はIIBC公式のサンプル問題^2を使用させていただきました

BERT学習済みモデルの読み込み

huggingfaceのtransformers^3 を利用します

import torch
from transformers import BertTokenizer, BertForPreTraining

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')

問題定義

textに問題の文章、candidateに穴埋め候補単語をリストで定義します
問題文章の空欄部分は*とします

text = "Customer reviews indicate that many modern mobile devices are often unnecessarily * ."
candidate = ["complication", "complicates", "complicate", "complicated"]

こちら、品詞を問う問題で、正解は4番目の”complicated”なのですが、BERTは正解することができるのでしょうか?

トークナイズ

BertTokenizerをつかってトークン分割します

tokens = tokenizer.tokenize(text)
# -> ['customer', 'reviews', 'indicate', 'that', 'many', 'modern', 'mobile', 'devices', 'are', 'often', 'un', '##ne', '##ces', '##sari', '##ly', '*', '.']

トークン分割ができたら、次は元の問題で空欄だった部分を[MASK]トークンに置き換えます
また、事前学習時と同様に文頭と文末にSpecial Tokenの[CLS][SEP]を入れます

masked_index = tokens.index("*")  # 空欄部分のトークンのインデックスを取得
tokens[masked_index] = "[MASK]"
tokens = ["[CLS]"] + tokens + ["[SEP]"]
# -> ['[CLS]', 'customer', 'reviews', 'indicate', 'that', 'many', 'modern', 'mobile', 'devices', 'are', 'often', 'un', '##ne', '##ces', '##sari', '##ly', '[MASK]', '.', '[SEP]']

BERTで予測

トークナイズされた文章をIDに変換して事前学習済みのBERTに通します
今回は事前学習と解きたい問題が同じなので、fine-tuningは行わずにそのまま予測します

ids = tokenizer.convert_tokens_to_ids(tokens)
ids = torch.tensor(ids).reshape(1,-1)  # バッチサイズ1の形に整形

with torch.no_grad():
    outputs1, outputs2 = model(ids)
predictions = outputs1[0]

outputs1Masked LMの予測結果が入っています
outputs2はBERTのもう一つの事前学習のNext Sentence Predictionの予測結果が入っていますが今回は使わないので無視します

予測上位のトークンを取得

[MASK]に入ると予測される単語の上位1000件を取得します

_, predicted_indexes = torch.topk(predictions[masked_index+1], k=1000)
predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_indexes.tolist())
# -> ['expensive', 'small', 'priced', 'used', ...

1位はexpensiveで、日本語にすると「カスタマーレビューによると、最近のモバイルデバイスは無駄に高価です」といったところでしょうか
意味に違和感はないと思います
予測単語を順々に見ていき、候補の単語が出てきたところで止めます

for i, v in enumerate(predicted_tokens):
        if v in candidate:
            print(i, v)
            break
# -> 74 complicated

75番目(index=74)でcomplicatedがヒットしました
日本語訳は「カスタマーレビューによると、最近のモバイルデバイスは無駄に複雑です」ですかね

見事正解できました

関数化

ここまでの処理を1つの関数にまとめます

def part5_slover(text, candidate):
    tokens = tokenizer.tokenize(text)
    masked_index = tokens.index("*")
    tokens[masked_index] = "[MASK]"
    tokens = ["[CLS]"] + tokens + ["[SEP]"]

    ids = tokenizer.convert_tokens_to_ids(tokens)
    ids = torch.tensor(ids).reshape(1,-1)
    with torch.no_grad():
        outputs, _ = model(ids)
    predictions = outputs[0]

    _, predicted_indexes = torch.topk(predictions[masked_index+1], k=1000)
    predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_indexes.tolist())


    for i, v in enumerate(predicted_tokens):
        if v in candidate:
            return (i, v)
    return "don't know"

こちらの関数を使って残りの問題も解いてみます

スクリーンショット 2019-12-20 17.34.49.png (260.8 kB)

なんと全問正解です!

おわりに

本記事では事前学習済みのBERTのモデルを使って、TOEICのPart 5の問題を解いてみました
解きたい問題がほぼ同じなので、fine-tuningすることなく良い結果が得られました

せっかくなのでこれで終わりにせず、リーディング問題をすべてBERTで解いてみようと思います(流石にPart 7はfine-tuningが必要でしょうか。。。)
次回はPart 6に挑戦します

最後まで読んでいただきありがとうございました