coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k

timm
Clasificación de imagen

Un modelo de clasificación de imágenes CoAtNet específico de timm (con un sesgo relativo de posición de coordenada logarítmica continua motivado por Swin-V2). Preentrenado en timm en ImageNet-12k (un subconjunto de 11821 clases del ImageNet-22k completo) y afinado en ImageNet-1k por Ross Wightman. El entrenamiento en ImageNet-12k se realizó en TPUs gracias al apoyo del programa TRC. El afinamiento se realizó en instancias en la nube de 8x GPU de Lambda Labs.

Como usar

Clasificación de Imágenes

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))

model = timm.create_model('coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k', pretrained=True)
model = model.eval()

# obtener transformaciones específicas del modelo (normalización, cambio de tamaño)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # unsqueeze para convertir imagen única en batch de 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)

Extracción de Mapa de Características

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))

model = timm.create_model('coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k', pretrained=True, features_only=True)
model = model.eval()

# obtener transformaciones específicas del modelo (normalización, cambio de tamaño)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # unsqueeze para convertir imagen única en batch de 1

for o in output:
    print(o.shape)

Embeddings de Imagen

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))

model = timm.create_model('coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k', pretrained=True, num_classes=0) # eliminar clasificador nn.Linear
model = model.eval()

# obtener transformaciones específicas del modelo (normalización, cambio de tamaño)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # la salida es un tensor de tamaño (batch_size, num_features)

# o de manera equivalente (sin necesidad de establecer num_classes=0)

output = model.forward_features(transforms(img).unsqueeze(0)) # salida sin agrupar, tensor de tamaño (1, 1024, 7, 7)
output = model.forward_head(output, pre_logits=True) # salida es un tensor de tamaño (1, num_features)

Funcionalidades

Clase del modelo: Clasificación de imágenes / columna vertebral de características
Parámetros (M): 73.9
GMACs: 15.2
Activaciones (M): 54.8
Tamaño de imagen: 224 x 224

Casos de uso

Clasificación de imágenes
Extracción de mapas de características
Embeddings de imágenes