Big Transfer (BiT)
El modelo BiT fue propuesto en el trabajo 'Big Transfer (BiT): General Visual Representation Learning' por Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly y Neil Houlsby. BiT es una receta simple para escalar el pre-entrenamiento de arquitecturas del tipo ResNet (específicamente, ResNetv2). El método resulta en mejoras significativas para el aprendizaje por transferencia. La transferencia de representaciones pre-entrenadas mejora la eficiencia de muestra y simplifica la afinación de hiperparámetros al entrenar redes neuronales profundas para visión. Revisamos el paradigma de pre-entrenar en grandes conjuntos de datos supervisados y afinar el modelo en una tarea objetivo. Escalamos el pre-entrenamiento y proponemos una receta simple que llamamos Big Transfer (BiT). Combinando algunos componentes cuidadosamente seleccionados y transfiriendo usando una heurística simple, logramos un rendimiento fuerte en más de 20 conjuntos de datos. BiT se desempeña bien en una sorprendentemente amplia gama de regímenes de datos, desde 1 ejemplo por clase hasta 1 millón de ejemplos en total. BiT alcanza un 87.5% de precisión top-1 en ILSVRC-2012, 99.4% en CIFAR-10 y 76.3% en el Benchmark de Adaptación de Tareas Visuales (VTAB) en 19 tareas. En conjuntos de datos pequeños, BiT alcanza un 76.8% en ILSVRC-2012 con 10 ejemplos por clase y 97.0% en CIFAR-10 con 10 ejemplos por clase. Realizamos un análisis detallado de los componentes principales que conducen a un alto rendimiento de transferencia.
Como usar
Aquí está cómo usar este modelo para clasificar una imagen del conjunto de datos COCO 2017 en una de las 1,000 clases de ImageNet:
from transformers import BitImageProcessor, BitForImageClassification
import torch
from datasets import load_dataset
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
feature_extractor = BitImageProcessor.from_pretrained("google/bit-50")
model = BitForImageClassification.from_pretrained("google/bit-50")
inputs = feature_extractor(image, return_tensors="pt")
with torch.no_grad():
logits = model(**inputs).logits
# el modelo predice una de las 1000 clases de ImageNet
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])
Funcionalidades
- Transferencia de representaciones pre-entrenadas
- Mejora la eficiencia de muestra
- Simplifica la afinación de hiperparámetros
- Rendimiento fuerte en más de 20 conjuntos de datos
- Buen desempeño en una amplia gama de regímenes de datos
- Alto rendimiento en conjuntos de datos pequeños
Casos de uso
- Clasificación de imágenes
- Transfiera el aprendizaje en conjuntos de datos de visión
- Aplicaciones que requieren alta eficiencia de muestra
- Tareas de visión por computadora en diferentes regímenes de datos