Distilado Transformer de Imágenes Eficiente en Datos (modelo tamaño base)

facebook
Clasificación de imagen

Modelo Distilado de Transformer de Imágenes Eficiente en Datos (DeiT) preentrenado a una resolución de 224x224 y ajustado a 384x384 en ImageNet-1k (1 millón de imágenes, 1,000 clases). Fue presentado por primera vez en el artículo 'Formación de transformadores de imágenes eficientes en datos y destilación a través de la atención' por Touvron et al. Los pesos fueron convertidos desde el repositorio timm por Ross Wightman. Este modelo es un Transformer de Visión distilado (ViT). Utiliza un token de destilación, además del token de clase, para aprender de manera efectiva de un profesor (CNN) tanto durante el preentrenamiento como durante el ajuste fino. El token de destilación se aprende a través de la retropropagación, interactuando con los tokens de clase ([CLS]) y de parches a través de las capas de autoatención.

Como usar

Dado que este modelo es un modelo ViT distilado, se puede conectar a DeiTModel, DeiTForImageClassification o DeiTForImageClassificationWithTeacher. El modelo espera que los datos estén preparados utilizando DeiTFeatureExtractor. Aquí usamos AutoFeatureExtractor, que utilizará automáticamente el extractor de características apropiado dado el nombre del modelo. A continuación, se explica cómo utilizar este modelo para clasificar una imagen del dataset COCO 2017 en una de las 1,000 clases de ImageNet:

from transformers import AutoFeatureExtractor, DeiTForImageClassificationWithTeacher
from PIL import Image
import requests
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-384')
model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-base-distilled-patch16-384')
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
# el modelo predice una de las 1000 clases de ImageNet
predicted_class_idx = logits.argmax(-1).item()
print("Clase predicha:", model.config.id2label[predicted_class_idx])

Actualmente, tanto el extractor de características como el modelo son compatibles con PyTorch. Próximamente se admitirán TensorFlow y JAX/FLAX.

Funcionalidades

Modelo Transformer de Visión distilado (ViT)
Entrenamiento con token de destilación y token de clase
Presentación de imágenes como una secuencia de parches de tamaño fijo (16x16)
Preentrenado a 224x224 y ajustado a 384x384 en ImageNet-1k
Aprendizaje efectivo a partir de un profesor (CNN)

Casos de uso

Clasificación de imágenes
Modelos ajustados para tareas específicas de clasificación de imágenes
Aprendizaje eficiente a partir de transformadores de visión en datos