Big Transfer (BiT)

google
Clasificación de imagen

El modelo BiT fue propuesto en el trabajo 'Big Transfer (BiT): General Visual Representation Learning' por Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly y Neil Houlsby. BiT es una receta simple para escalar el pre-entrenamiento de arquitecturas del tipo ResNet (específicamente, ResNetv2). El método resulta en mejoras significativas para el aprendizaje por transferencia. La transferencia de representaciones pre-entrenadas mejora la eficiencia de muestra y simplifica la afinación de hiperparámetros al entrenar redes neuronales profundas para visión. Revisamos el paradigma de pre-entrenar en grandes conjuntos de datos supervisados y afinar el modelo en una tarea objetivo. Escalamos el pre-entrenamiento y proponemos una receta simple que llamamos Big Transfer (BiT). Combinando algunos componentes cuidadosamente seleccionados y transfiriendo usando una heurística simple, logramos un rendimiento fuerte en más de 20 conjuntos de datos. BiT se desempeña bien en una sorprendentemente amplia gama de regímenes de datos, desde 1 ejemplo por clase hasta 1 millón de ejemplos en total. BiT alcanza un 87.5% de precisión top-1 en ILSVRC-2012, 99.4% en CIFAR-10 y 76.3% en el Benchmark de Adaptación de Tareas Visuales (VTAB) en 19 tareas. En conjuntos de datos pequeños, BiT alcanza un 76.8% en ILSVRC-2012 con 10 ejemplos por clase y 97.0% en CIFAR-10 con 10 ejemplos por clase. Realizamos un análisis detallado de los componentes principales que conducen a un alto rendimiento de transferencia.

Como usar

Aquí está cómo usar este modelo para clasificar una imagen del conjunto de datos COCO 2017 en una de las 1,000 clases de ImageNet:

from transformers import BitImageProcessor, BitForImageClassification
import torch
from datasets import load_dataset

dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

feature_extractor = BitImageProcessor.from_pretrained("google/bit-50")
model = BitForImageClassification.from_pretrained("google/bit-50")

inputs = feature_extractor(image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# el modelo predice una de las 1000 clases de ImageNet
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])

Funcionalidades

Transferencia de representaciones pre-entrenadas
Mejora la eficiencia de muestra
Simplifica la afinación de hiperparámetros
Rendimiento fuerte en más de 20 conjuntos de datos
Buen desempeño en una amplia gama de regímenes de datos
Alto rendimiento en conjuntos de datos pequeños

Casos de uso

Clasificación de imágenes
Transfiera el aprendizaje en conjuntos de datos de visión
Aplicaciones que requieren alta eficiencia de muestra
Tareas de visión por computadora en diferentes regímenes de datos