Transformers(12) - 質疑応答③学習したモデルを使って質疑応答

前回は、わかち書きとファインチューニングを行いました。

今回は、学習したモデルを使って質疑応答を行います。

学習済みモデルを使って質疑応答

学習したモデルを使って質疑応答を行ってみます。

8行目のfrom_pretrained関数で、学習済みモデルをロードしています。

11行目でコンテキストを設定し、12行目で質問を指定しています。

[Google Colaboratory]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from transformers import BertJapaneseTokenizer, AutoModelForQuestionAnswering
import MeCab
wakati = MeCab.Tagger("-Owakati")

# トークナイザーとモデルの準備
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
model = AutoModelForQuestionAnswering.from_pretrained('output/')

# コンテキストと質問
context = wakati.parse('土曜日に友達と表参道に遊びに行きました。').strip()
question = wakati.parse('どこに遊びに行ったの?').strip()
print(wakati.parse('土曜日に友達と表参道に遊びに行きました。'), wakati.parse('どこに遊びに行ったの?'))

# テキストをテンソルに変換
inputs = tokenizer.encode_plus(question, context, return_tensors='pt')

# 入力のトークンIDの配列の取得
input_ids = inputs['input_ids'].tolist()[0]

# 推論
model.eval()
with torch.no_grad():
output = model(**inputs)
answer_start = torch.argmax(output.start_logits)
answer_end = torch.argmax(output.end_logits) + 1
answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
print(answer)

回答結果は以下の通りです。

[実行結果]

1
2
3
4
土曜 日 に 友達 と 表 参道 に 遊び に 行き まし た 。 
どこ に 遊び に 行っ た の ?

表 参道

応答は「表 参道」と的確な回答になっています。


次に、コンテキストはそのままで質問の内容を変えてみます。

[Google Colaboratory]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from transformers import BertJapaneseTokenizer, AutoModelForQuestionAnswering
import MeCab
wakati = MeCab.Tagger("-Owakati")

# トークナイザーとモデルの準備
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
model = AutoModelForQuestionAnswering.from_pretrained('output/')

# コンテキストと質問
context = wakati.parse('土曜日に友達と表参道に遊びに行きました。').strip()
question = wakati.parse('いつ遊びに行ったの?').strip()
print(wakati.parse('土曜日に友達と表参道に遊びに行きました。'), wakati.parse('いつ遊びに行ったの?'))

# テキストをテンソルに変換
inputs = tokenizer.encode_plus(question, context, return_tensors='pt')

# 入力のトークンIDの配列の取得
input_ids = inputs['input_ids'].tolist()[0]

# 推論
model.eval()
with torch.no_grad():
output = model(**inputs)
answer_start = torch.argmax(output.start_logits)
answer_end = torch.argmax(output.end_logits) + 1
answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
print(answer)

回答結果は以下の通りです。

[実行結果]

1
2
3
4
土曜 日 に 友達 と 表 参道 に 遊び に 行き まし た 。 
いつ 遊び に 行っ た の ?

土曜 日

応答は「土曜 日」と、こちらも的確な回答になっています。

前回のファインチューニングは2時間以上とかなり時間がかかりましたが、一度学習を完了してしまえば質問に対して的確な応答を得られるようになるのでかなり実用的だと感じました。