Modelo DETR Condicional con ResNet-101 (etapa dilatada C5)
El recientemente desarrollado enfoque DETR aplica la arquitectura del codificador y decodificador transformer a la detección de objetos y logra un rendimiento prometedor. En este trabajo, abordamos el problema crítico de la lenta convergencia del entrenamiento y presentamos un mecanismo de atención cruzada condicional para un entrenamiento rápido de DETR. Nuestro enfoque, denominado DETR condicional, aprende una consulta espacial condicional a partir de la incrustación del decodificador para la atención cruzada de múltiples cabezales del decodificador. El beneficio es que a través de la consulta espacial condicional, cada cabeza de atención cruzada puede atender a una banda que contiene una región distinta, por ejemplo, una extremidad de un objeto o una región dentro del cuadro del objeto. Esto reduce el rango espacial para localizar las regiones distintas para la clasificación de objetos y la regresión de cuadros, relajando así la dependencia de las incrustaciones de contenido y facilitando el entrenamiento. Los resultados empíricos muestran que DETR condicional converge 6.7 veces más rápido para las espinas R50 y R101 y 10 veces más rápido para espinas más fuertes DC5-R50 y DC5-R101.
Como usar
Aquí se explica cómo usar este modelo:
from transformers import AutoImageProcessor, ConditionalDetrForObjectDetection
import torch
from PIL import Image
import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
processor = AutoImageProcessor.from_pretrained("Omnifact/conditional-detr-resnet-101-dc5")
model = ConditionalDetrForObjectDetection.from_pretrained("Omnifact/conditional-detr-resnet-101-dc5")
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
# convertir salidas (cuadros delimitadores y logits de clase) a API de COCO
# solo mantendremos las detecciones con puntuación > 0.7
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.7)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"Detectado {model.config.id2label[label.item()]} con confianza "
f"{round(score.item(), 3)} en la ubicación {box}")
Esto debería generar:
Detectado gato con confianza 0.865 en la ubicación [13.95, 64.98, 327.14, 478.82]
Detectado control remoto con confianza 0.849 en la ubicación [39.37, 83.18, 187.67, 125.02]
Detectado gato con confianza 0.743 en la ubicación [327.22, 35.17, 637.54, 377.04]
Detectado control remoto con confianza 0.737 en la ubicación [329.36, 89.47, 376.42, 197.53]
Funcionalidades
- Detección de objetos con Transformers
- Entrenamiento rápido gracias a la atención cruzada condicional
- Pesos disponibles en PyTorch y Safetensors
Casos de uso
- Detección de objetos en imágenes
- Clasificación de objetos y regresión de cuadros
- Aplicaciones de visión por computadora que requieren rápida convergencia de entrenamiento