timm/efficientnet_b5.sw_in12k_ft_in1k

timm
Clasificación de imagen

Un modelo de clasificación de imágenes EfficientNet. Preentrenado en ImageNet-12k y afinado en ImageNet-1k por Ross Wightman en timm utilizando la plantilla de receta descrita a continuación. Detalles de la receta: Basado en la receta de entrenamiento/preentrenamiento del Swin Transformer con modificaciones (relacionadas con las recetas de DeiT y ConvNeXt). Optimización AdamW, recorte de gradientes, EMA promedio de pesos. Programa de LR coseno con calentamiento.

Como usar

Clasificación de imágenes

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))

model = timm.create_model('efficientnet_b5.sw_in12k_ft_in1k', pretrained=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # unsqueeze single image into batch of 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)

Extracción de mapas de características

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))

model = timm.create_model('efficientnet_b5.sw_in12k_ft_in1k', pretrained=True, features_only=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # unsqueeze single image into batch of 1

for o in output:
   print(o.shape) # print shape of each feature map in output

Embeddings de imágenes

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))

model = timm.create_model('efficientnet_b5.sw_in12k_ft_in1k', pretrained=True, num_classes=0)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # output is (batch_size, num_features) shaped tensor

# or equivalently without needing to set num_classes=0
output = model.forward_features(transforms(img).unsqueeze(0)) # output is unpooled, a (1, 2048, 14, 14) shaped tensor

output = model.forward_head(output, pre_logits=True) # output is a (1, num_features) shaped tensor

Funcionalidades

Clasificación de imágenes
Extracción de mapas de características
Embedding de imágenes
Comparación de modelos

Casos de uso

Clasificación de imágenes
Extracción de características de imágenes
Generación de embeddings de imágenes