seresnext26ts.ch_in1k

timm
Clasificación de imagen

Un modelo de clasificación de imágenes SE-ResNeXt (ResNeXt con atención de canal 'Squeeze-and-Excitation'). Este modelo presenta un stem de 3 capas por niveles y activaciones SiLU. Entrenado en ImageNet-1k por Ross Wightman en timm. La arquitectura del modelo se implementa utilizando la flexible red BYOBNet (Bring-Your-Own-Blocks Network) de timm. BYOBNet permite la configuración de la disposición de bloques/etapas, disposición de stem, salida de stride (dilatación), capas de activación y normalización, capas de atención tanto de canal como espacial/auto-atención.

Como usar

from urllib.request import urlopen
from PIL import Image
import timm

# Clasificación de Imagen
img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))
model = timm.create_model('seresnext26ts.ch_in1k', pretrained=True)
model = model.eval()
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
output = model(transforms(img).unsqueeze(0)) # unsqueeze single image into batch of 1
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)

# Extracción de Mapa de Características
img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))
model = timm.create_model('seresnext26ts.ch_in1k', pretrained=True, features_only=True)
model = model.eval()
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
output = model(transforms(img).unsqueeze(0))
for o in output:
   print(o.shape)

# Embedding de Imagen
img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))
model = timm.create_model('seresnext26ts.ch_in1k', pretrained=True, num_classes=0)
model = model.eval()
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
output = model(transforms(img).unsqueeze(0))
output = model.forward_features(transforms(img).unsqueeze(0))
output = model.forward_head(output, pre_logits=True)

Funcionalidades

Profundidad estocástica
Punto de verificación de gradiente
Decaimiento de LR por capa
Extracción de características por etapa

Casos de uso

Clasificación de imágenes
Extracción de mapas de características
Obtención de embeddings de imágenes