vit_wee_patch16_reg1_gap_256.sbb_in1k

timm
Clasificación de imagen

Un modelo de clasificación de imágenes Vision Transformer (ViT). Esta es una variación específica de timm de la arquitectura con registros y agrupación promedio global. Entrenado en ImageNet-1k en timm utilizando una receta que incluye influencias de Swin/DeiT/DeiT-III con mayor decaimiento de peso y una alta augmentación. Utiliza capa de decaimiento para ajuste fino y en algunos casos optimizadores BCE y/o NAdamW en lugar de AdamW.

Como usar

Clasificación de imágenes

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model('vit_wee_patch16_reg1_gap_256.sbb_in1k', pretrained=True)
model = model.eval()

# obtener las transformaciones específicas del modelo (normalización, cambio de tamaño)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # expander imagen única en batch de 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)

Extracción de Mapa de Características

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model(
'vit_wee_patch16_reg1_gap_256.sbb_in1k',
pretrained=True,
features_only=True,
)
model = model.eval()

# obtener las transformaciones específicas del modelo (normalización, cambio de tamaño)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # expander imagen única en batch de 1

for o in output:
  print(o.shape)

Embeddings de Imagen

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model(
'vit_wee_patch16_reg1_gap_256.sbb_in1k',
pretrained=True,
num_classes=0, # eliminar clasificador nn.Linear
)
model = model.eval()

# obtener las transformaciones específicas del modelo (normalización, cambio de tamaño)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # output es un tensor de forma (batch_size, num_features)

# o equivalente (sin necesidad de configurar num_classes=0)

output = model.forward_features(transforms(img).unsqueeze(0))
# output es no agrupado, un tensor de forma (1, 257, 256)

output = model.forward_head(output, pre_logits=True)
# output es un tensor de forma (1, num_features)

Funcionalidades

Clasificación de imágenes
Extracción de mapa de características
Embeddings de imagen
Comparación de modelos

Casos de uso

Clasificación de imágenes
Extracción de mapas de características
Generación de embeddings de imágenes
Comparación de diferentes arquitecturas de modelos