Pyramid Vision Transformer (modelo de tamaño pequeño)

Zetatech
Clasificación de imagen

El Pyramid Vision Transformer (PVT) es un modelo de codificador transformador (similar a BERT) preentrenado en ImageNet-1k, un conjunto de datos que comprende 1 millón de imágenes y 1,000 clases, también a resolución 224x224. Se introdujo en el artículo 'Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions' por Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, Kaitao Song, Ding Liang, Tong Lu, Ping Luo y Ling Shao, y se lanzó por primera vez en este repositorio. Las imágenes se presentan al modelo como una secuencia de parches de tamaño variable, que se incrustan linealmente. A diferencia de los modelos ViT, PVT utiliza una pirámide de reducción progresiva para reducir los cálculos de grandes mapas de características en cada etapa. Además, se agrega un token [CLS] al comienzo de una secuencia para usarlo en tareas de clasificación. También se agregan incrustaciones de posición absoluta antes de alimentar la secuencia a las capas del codificador del transformador. Al preentrenar el modelo, aprende una representación interna de las imágenes que luego se puede usar para extraer características útiles para tareas posteriores. Por ejemplo, si tienes un conjunto de datos de imágenes etiquetadas, puedes entrenar un clasificador estándar colocando una capa lineal sobre el codificador preentrenado.

Como usar

Aquí se muestra cómo usar este modelo para clasificar una imagen del conjunto de datos COCO 2017 en una de las 1,000 clases de ImageNet:

from transformers import PvtImageProcessor, PvtForImageClassification
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

processor = PvtImageProcessor.from_pretrained('Zetatech/pvt-tiny-224')
model = PvtForImageClassification.from_pretrained('Zetatech/pvt-tiny-224')

inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
# el modelo predice una de las 1000 clases de ImageNet
predicted_class_idx = logits.argmax(-1).item()
print("Clase predicha:", model.config.id2label[predicted_class_idx])

Funcionalidades

Modelo preentrenado en ImageNet-1k
Reducción progresiva en pirámide
Token [CLS] para tareas de clasificación
Incrustaciones de posición absoluta
Capa lineal para clasificación

Casos de uso

Clasificación de imágenes
Extracción de características de imágenes