Transformador de Imagen Eficiente en Datos (DeiT) con Distilación de Facebook - Base Distilled Patch16 224
Transformador de Imagen Eficiente en Datos Distilado (DeiT) modelo preentrenado y afinado en ImageNet-1k (1 millón de imágenes, 1,000 clases) a una resolución de 224x224. Fue introducido por primera vez en el documento de investigación 'Training data-efficient image transformers & distillation through attention' por Touvron et al. y lanzado inicialmente en este repositorio. Sin embargo, los pesos fueron convertidos del repositorio timm por Ross Wightman. Este modelo utiliza un token de distilación, además del token de clase, para aprender efectivamente de un profesor (CNN) durante el preentrenamiento y afinamiento. El token de distilación se aprende a través de retropropagación, interactuando con el token de clase ([CLS]) y los tokens de parche a través de las capas de auto-atención. Las imágenes se presentan al modelo como una secuencia de parches de tamaño fijo (resolución 16x16), los cuales son incrustados linealmente.
Como usar
Dado que este modelo es un ViT destilado, puedes integrarlo en DeiTModel, DeiTForImageClassification o DeiTForImageClassificationWithTeacher. Ten en cuenta que el modelo espera que los datos sean preparados usando DeiTFeatureExtractor. Aquí usamos AutoFeatureExtractor, que usará automáticamente el extractor de características apropiado dado el nombre del modelo.
from transformers import AutoFeatureExtractor, DeiTForImageClassificationWithTeacher
from PIL import Image
import requests
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-224')
model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-base-distilled-patch16-224')
inputs = feature_extractor(images=image, return_tensors="pt")
# forward pass
outputs = model(**inputs)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
Actualmente, tanto el extractor de características como el modelo soportan PyTorch. TensorFlow y JAX/FLAX vendrán pronto.
Funcionalidades
- Transformador de Visión Distilado
- Preentrenado y afinado en ImageNet-1k
- Resuelve imágenes de 224x224
- Utiliza Token de Distilación y Token de Clase
- Soporte actual para PyTorch, soporte para TensorFlow y JAX/FLAX próximamente
Casos de uso
- Clasificación de imágenes