Transformador de Imagen Eficiente en Datos (DeiT) con Distilación de Facebook - Base Distilled Patch16 224

facebook
Clasificación de imagen

Transformador de Imagen Eficiente en Datos Distilado (DeiT) modelo preentrenado y afinado en ImageNet-1k (1 millón de imágenes, 1,000 clases) a una resolución de 224x224. Fue introducido por primera vez en el documento de investigación 'Training data-efficient image transformers & distillation through attention' por Touvron et al. y lanzado inicialmente en este repositorio. Sin embargo, los pesos fueron convertidos del repositorio timm por Ross Wightman. Este modelo utiliza un token de distilación, además del token de clase, para aprender efectivamente de un profesor (CNN) durante el preentrenamiento y afinamiento. El token de distilación se aprende a través de retropropagación, interactuando con el token de clase ([CLS]) y los tokens de parche a través de las capas de auto-atención. Las imágenes se presentan al modelo como una secuencia de parches de tamaño fijo (resolución 16x16), los cuales son incrustados linealmente.

Como usar

Dado que este modelo es un ViT destilado, puedes integrarlo en DeiTModel, DeiTForImageClassification o DeiTForImageClassificationWithTeacher. Ten en cuenta que el modelo espera que los datos sean preparados usando DeiTFeatureExtractor. Aquí usamos AutoFeatureExtractor, que usará automáticamente el extractor de características apropiado dado el nombre del modelo.

from transformers import AutoFeatureExtractor, DeiTForImageClassificationWithTeacher
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-224')
model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-base-distilled-patch16-224')

inputs = feature_extractor(images=image, return_tensors="pt")

# forward pass
outputs = model(**inputs)
logits = outputs.logits

# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

Actualmente, tanto el extractor de características como el modelo soportan PyTorch. TensorFlow y JAX/FLAX vendrán pronto.

Funcionalidades

Transformador de Visión Distilado
Preentrenado y afinado en ImageNet-1k
Resuelve imágenes de 224x224
Utiliza Token de Distilación y Token de Clase
Soporte actual para PyTorch, soporte para TensorFlow y JAX/FLAX próximamente

Casos de uso

Clasificación de imágenes