Vision Transformer - 画像分類

今回は、Google社が提供しているVision Transformerを使って画像分類を行います。

Huggingface Transformersのインストール

Vision Transformerは、Huggingface Transformersをインストールすることで使用できるようになります。

[Google Colaboratory]

1
!pip install transformers==4.6.0

Huggingface Transformersでは、以下のビジョンタスクのモデルアーキテクチャを使用できます。

  • Vision Transformer(Google AI)
  • DeiT(Facebook)
  • CLIP(OpenAI)

画像分類

分類する画像として下記のものを使います。

この画像をGoogle Colaboratoryアップロードしておきます。


画像分類するためのソースコードは以下のようになります。

11行目でアップロードした画像を読み込んでいます。

[Google Colaboratory]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor

# モデルと特徴抽出器の準備
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

from PIL import Image

# 画像の読み込み
image = Image.open('cat.png')

# 画像をテンソルに変換
inputs = feature_extractor(images=image, return_tensors='pt')

# 推論の実行
outputs = model(**inputs)
logits = outputs.logits # 1000クラスの精度の配列
predicted_class_idx = logits.argmax(-1).item() # 精度が最大のインデックス
print('class:', model.config.id2label[predicted_class_idx]) # インデックスをラベルに変換

実行結果は以下の通りです。

[実行結果]

1
class: tabby, tabby cat

tabby catは日本語でぶち猫という意味です。

猫の種類はよくわかりませんが、猫には違いないのでちゃんと画像分類できたということになると思います。

次回は、Facebook社が提供しているDeiTを使って画像分類を行います。