timm/regnety_040.ra3_in1k

timm
Clasificación de imagen

Un modelo de clasificación de imágenes RegNetY-4GF. Entrenado en ImageNet-1k por Ross Wightman en timm. La implementación de timm RegNet incluye una serie de mejoras que no están presentes en otras implementaciones, incluyendo: profundidad estocástica, punto de control de gradiente, decaimiento LR por capa, salida configurable (dilation), capas de activación y normalización configurables, opción para un bloque de cuello de botella pre-activación utilizado en la variante RegNetV, y las únicas definiciones de modelos RegNetZ conocidas con pesos preentrenados.

Como usar

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model('regnety_040.ra3_in1k', pretrained=True)
model = model.eval()

# obtain the specific transforms for the model (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # unsqueeze single image into batch of 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)

Extracción de Mapas de Características

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model(
'regnety_040.ra3_in1k',
pretrained=True,
features_only=True,
)
model = model.eval()

# obtain the specific transforms for the model (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # unsqueeze single image into batch of 1

for o in output:
    print(o.shape)

Embeddings de Imágenes

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))
model = timm.create_model(
'regnety_040.ra3_in1k',
pretrained=True,
num_classes=0, # remove classifier nn.Linear
)
model = model.eval()

# obtain the specific transforms for the model (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0)) # output is (batch_size, num_features) shaped tensor

# or equivalently (without needing to set num_classes=0)

output = model.forward_features(transforms(img).unsqueeze(0))
# output is unpooled, a (1, 1088, 7, 7) shaped tensor

output = model.forward_head(output, pre_logits=True)
# output is a (1, num_features) shaped tensor

Funcionalidades

Clasificación de imágenes
Extracción de mapas de características
Extracción de embeddings de imágenes

Casos de uso

Clasificación de imágenes
Extracción de mapas de características para análisis posterior
Generación de embeddings de imágenes para tareas de similitud o búsqueda de imágenes