Deformable DETR modelo entrenado en DocLayNet
Deformable DEtection TRansformer (DETR), entrenado en DocLayNet (incluyendo 80k páginas anotadas en 11 clases). El modelo DETR es un transformador de codificador-decodificador con una columna vertebral convolucional. Se agregan dos cabezas en la parte superior de las salidas del decodificador para realizar la detección de objetos: una capa lineal para las etiquetas de clase y un MLP (perceptrón multicapa) para las cajas delimitadoras. El modelo utiliza las llamadas consultas de objetos para detectar objetos en una imagen. Cada consulta de objeto busca un objeto particular en la imagen. Para COCO, el número de consultas de objetos se establece en 100. El modelo se entrena utilizando una 'pérdida de emparejamiento bipartito': se comparan las clases predichas y las cajas delimitadoras de cada una de las N = 100 consultas de objetos con las anotaciones de verdad del terreno, completadas hasta la misma longitud N (por lo que si una imagen solo contiene 4 objetos, 96 anotaciones solo tendrán una 'sin objeto' como clase y 'sin caja delimitadora' como caja). Se utiliza el algoritmo de emparejamiento húngaro para crear un mapeo óptimo uno a uno entre cada una de las N consultas y cada una de las N anotaciones. Luego, se utilizan la entropía cruzada estándar (para las clases) y una combinación lineal de la pérdida L1 y la pérdida IoU generalizada (para las cajas delimitadoras) para optimizar los parámetros del modelo.
Como usar
Aquí se explica cómo usar este modelo:
from transformers import AutoImageProcessor, DeformableDetrForObjectDetection
import torch
from PIL import Image
import requests
url = "https://huggingface.co/Aryn/deformable-detr-DocLayNet/resolve/main/examples/doclaynet_example_1.png"
image = Image.open(requests.get(url, stream=True).raw)
processor = AutoImageProcessor.from_pretrained("Aryn/deformable-detr-DocLayNet")
model = DeformableDetrForObjectDetection.from_pretrained("Aryn/deformable-detr-DocLayNet")
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
# convertir salidas (cajas delimitadoras y logits de clase) a API COCO
# vamos a mantener solo las detecciones con puntaje > 0.7
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.7)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"Detectado {model.config.id2label[label.item()]} con confianza "
f"{round(score.item(), 3)} en la ubicación {box}")
Funcionalidades
- Detección de objetos
- Algoritmo de emparejamiento húngaro
- Perceptrón multicapa (MLP)
- Modelo transformador de codificador-decodificador
- Consultas de objetos
Casos de uso
- Detección de objetos en imágenes
- Análisis de diseño de documentos