DeiT - 画像分類

今回は、Facebook社が提供しているDeiTを使って画像分類を行います。

Huggingface Transformersのインストール

DeiTは、Huggingface Transformersをインストールすることで使用できるようになります。

[Google Colaboratory]

1
!pip install transformers==4.6.0

Huggingface Transformersでは、以下のビジョンタスクのモデルアーキテクチャを使用できます。

  • Vision Transformer(Google AI)
    前回の記事で使用したモデル
  • DeiT(Facebook)
    今回の記事で使用するモデル
  • CLIP(OpenAI)

画像分類

分類する画像として下記のものを使います。

この画像をGoogle Colaboratoryアップロードしておきます。


画像分類するためのソースコードは以下のようになります。

2~6行目が前回記事と異なっており、それ以外の箇所はすべて同じです。

[Google Colaboratory]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from transformers import AutoFeatureExtractor, DeiTForImageClassificationWithTeacher

# モデルと特徴抽出器の準備
model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-base-distilled-patch16-224')
feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-224')

from PIL import Image

# 画像の読み込み
image = Image.open('cat.png')

# 画像をテンソルに変換
inputs = feature_extractor(images=image, return_tensors='pt')

# 推論の実行
outputs = model(**inputs)
logits = outputs.logits # 1000クラスの精度の配列
predicted_class_idx = logits.argmax(-1).item() # 精度が最大のインデックス
print('class:', model.config.id2label[predicted_class_idx]) # インデックスをラベルに変換

実行結果は以下の通りです。

[実行結果]

1
class: tabby, tabby cat

判定結果はtabby cat(ぶち猫)となり、前回のVision Transformer(Google AI)のモデルを使った時と同じ結果となりました。

次回は、CLIPを使った画像分類を行います。