timm/vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k

timm
Clasificación de imagen

Un modelo de clasificación de imágenes basado en Vision Transformer (ViT). Esta es una variación específica de timm de la arquitectura con registros y pooling global promedio. Preentrenado en ImageNet-12k y afinado en ImageNet-1k por Ross Wightman en timm utilizando una plantilla de receta.

Como usar

Clasificación de imágenes

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))

model = timm.create_model('vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k', pretrained=True)
model = model.eval()

# obtener las transformaciones específicas del modelo (normalización, redimensionamiento)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # añadir una imagen en lote de 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)

Extracción de mapa de características

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))

model = timm.create_model('vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k', pretrained=True, features_only=True)
model = model.eval()

# obtener las transformaciones específicas del modelo (normalización, redimensionamiento)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # añadir una imagen en lote de 1

for o in output:
    print(o.shape)

Embeddings de imágenes

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))

model = timm.create_model('vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k', pretrained=True, num_classes=0)
model = model.eval()

# obtener las transformaciones específicas del modelo (normalización, redimensionamiento)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))

output = model.forward_features(transforms(img).unsqueeze(0))
output = model.forward_head(output, pre_logits=True)

Funcionalidades

Transformers de visión con registros
Clasificación de imágenes
Extracción de mapa de características
Embeddings de imágenes
Comparación de modelos

Casos de uso

Clasificación de imágenes
Extracción de características de imágenes
Generación de embeddings de imágenes