timm/coatnet_2_rw_224.sw_in12k_ft_in1k

timm
Clasificación de imagen

Un modelo específico de CoAtNet para la clasificación de imágenes desarrollado por timm. Está preentrenado en ImageNet-12k (un subconjunto de 11,821 clases del conjunto completo de ImageNet-22k) y ajustado en ImageNet-1k por Ross Wightman. El entrenamiento en ImageNet-12k se realizó en TPUs gracias al apoyo del programa TRC. El ajuste fino se llevó a cabo en instancias en la nube de Lambda Labs con 8 GPUs.

Como usar

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))

model = timm.create_model('coatnet_2_rw_224.sw_in12k_ft_in1k', pretrained=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # unsqueeze single image into batch of 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)

Extracción de Mapas de Características

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))

model = timm.create_model('coatnet_2_rw_224.sw_in12k_ft_in1k', pretrained=True, features_only=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # unsqueeze single image into batch of 1

for o in output:
    # print shape of each feature map in output
    print(o.shape)

Incrustaciones de Imágenes

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))

model = timm.create_model('coatnet_2_rw_224.sw_in12k_ft_in1k', pretrained=True, num_classes=0) # remove classifier nn.Linear
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # output is (batch_size, num_features) shaped tensor

# or equivalently (without needing to set num_classes=0)

output = model.forward_features(transforms(img).unsqueeze(0)) # output is unpooled, a (1, 1024, 7, 7) shaped tensor

output = model.forward_head(output, pre_logits=True) # output is a (1, num_features) shaped tensor

Funcionalidades

Modelo de clasificación de imágenes
Extracción de mapas de características
Incrustaciones de imágenes
Combinación de bloques de convolución MBConv y bloques de atención en las primeras etapas
Bloques uniformes en todas las etapas para MaxViT
Arquitectura específica de timm que utiliza bloques ConvNeXt en lugar de bloques MBConv
Variación de MaxxViT que elimina la atención de bloque de ventana, dejando solo bloques ConvNeXt y atención de rejilla

Casos de uso

Clasificación de imágenes
Extracción de mapas de características para revisar las activaciones de la red
Obtención de incrustaciones de imágenes para tareas downstream como la búsqueda de imágenes