Transformador eficiente en datos de imágenes (modelo pequeño)
El Transformador eficiente en datos de imágenes (DeiT) es un modelo Vision Transformer (ViT) pre-entrenado y afinado en ImageNet-1k (1 millón de imágenes, 1,000 clases) a una resolución de 224x224. Introducido en el artículo 'Entrenar transformadores de imagen eficientes en datos y destilación a través de la atención' por Touvron et al. y liberado en este repositorio. Los pesos fueron convertidos del repositorio timm por Ross Wightman. Este modelo es un Transformer Vision (ViT) más eficiente en el entrenamiento, pre-entrenado y afinado en una gran colección de imágenes de manera supervisada, concretamente ImageNet-1k, a una resolución de 224x224 píxeles. Las imágenes se presentan al modelo como una secuencia de parches de tamaño fijo (resolución 16x16), que se incrustan linealmente. También se agrega un token [CLS] al principio de una secuencia para usarlo en tareas de clasificación. Antes de alimentar la secuencia a las capas del codificador Transformer, se agregan incrustaciones de posición absoluta.
Como usar
Dado que este modelo está entrenado de manera más eficiente, puedes conectarlo a ViTModel o ViTForImageClassification. Toma nota de que el modelo espera que los datos se preparen utilizando DeiTFeatureExtractor. Aquí usamos AutoFeatureExtractor, que seleccionará automáticamente el extractor de características apropiado dado el nombre del modelo. Así es como se usa este modelo para clasificar una imagen del conjunto de datos COCO 2017 en una de las 1,000 clases de ImageNet:
from transformers import AutoFeatureExtractor, ViTForImageClassification
from PIL import Image
import requests
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/deit-small-patch16-224')
model = ViTForImageClassification.from_pretrained('facebook/deit-small-patch16-224')
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
# El modelo predice una de las 1000 clases de ImageNet
predicted_class_idx = logits.argmax(-1).item()
print("Clase predicha:", model.config.id2label[predicted_class_idx])
Funcionalidades
- Transformador encoder (similar a BERT)
- Presentación de imágenes como secuencia de parches de tamaño fijo (16x16)
- [CLS] token para tareas de clasificación
- Incrustaciones de posición absoluta
- Capacidad de extraer características útiles para tareas posteriores
Casos de uso
- Clasificación de imágenes
- Entrenamiento de un clasificador estándar usando una capa lineal sobre el codificador pre-entrenado