deit-small-distilled-patch16-224
Transformador de imágenes eficiente en datos destilados (modelo de tamaño pequeño) preentrenado y afinado en ImageNet-1k (1 millón de imágenes, 1,000 clases) a una resolución de 224x224. Fue introducido por primera vez en el artículo 'Entrenamiento de transformadores de imágenes eficientes en datos & destilación mediante atención' por Touvron et al. y primero lanzado en este repositorio. Sin embargo, los pesos fueron convertidos del repositorio timm por Ross Wightman. Este modelo utiliza un token de destilación, además del token de clase, para aprender efectivamente de un maestro (CNN) tanto durante el preentrenamiento como la afinación. El token de destilación se aprende a través de retropropagación, interactuando con el token de clase ([CLS]) y los tokens de parche a través de las capas de autoatención. Las imágenes se presentan al modelo como una secuencia de parches de tamaño fijo (resolución 16x16), que se incrustan linealmente.
Como usar
Dado que este modelo es un ViT destilado, puedes conectarlo a DeiTModel, DeiTForImageClassification o DeiTForImageClassificationWithTeacher. Nota que el modelo espera que los datos se preparen usando DeiTFeatureExtractor. Aquí usamos AutoFeatureExtractor, que automáticamente usará el extractor de características apropiado dado el nombre del modelo. Aquí está cómo usar este modelo para clasificar una imagen del dataset COCO 2017 en una de las 1,000 clases de ImageNet:
from transformers import AutoFeatureExtractor, DeiTForImageClassificationWithTeacher
from PIL import Image
import requests
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/deit-small-distilled-patch16-224')
model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-small-distilled-patch16-224')
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
# el modelo predice una de las 1000 clases de ImageNet
predicted_class_idx = logits.argmax(-1).item()
print("Clase predicha:", model.config.id2label[predicted_class_idx])
Actualmente, tanto el extractor de características como el modelo soportan PyTorch. Tensorflow y JAX/FLAX llegarán pronto.
Funcionalidades
- Modelo basado en Transformers de visión
- Token de destilación para aprendizaje efectivo
- Entrenado y afinado con ImageNet-1k
- Resuelve tareas de clasificación de imágenes
Casos de uso
- Clasificación de imágenes