layoutlm-funsd

Narsil
Detección de objetos

Este modelo es una versión ajustada del microsoft/layoutlm-base-uncased en el conjunto de datos funsd. Logra los siguientes resultados en el conjunto de evaluación: Pérdida: 1.0045. Evaluaciones detalladas por categoría incluyen una precisión global del 0.7599, un recall global del 0.8083, un F1 global del 0.7866, y una precisión global del 0.8106.

Como usar

Desplegar el modelo con puntos finales de inferencia

from typing import Dict, List, Any
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor
import torch
from subprocess import run

# instalar tesseract-ocr y pytesseract
run("apt install -y tesseract-ocr", shell=True, check=True)
run("pip install pytesseract", shell=True, check=True)

# función auxiliar para desnormalizar bboxes para dibujar en la imagen
def unnormalize_box(bbox, width, height):
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]

# configurar dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EndpointHandler:
    def __init__(self, path=""):
        # cargar modelo y procesador desde la ruta
        self.model = LayoutLMForTokenClassification.from_pretrained(path).to(device)
        self.processor = LayoutLMv2Processor.from_pretrained(path)

    def __call__(self, data: Dict[str, bytes]) -> Dict[str, List[Any]]:
        """
        Args:
            data (:obj:`Dic`):
                incluye el archivo de imagen deserializado como PIL.Image
        """
        # procesar entrada
        image = data.pop("inputs", data)

        # procesar imagen
        encoding = self.processor(image, return_tensors="pt")

        # ejecutar predicción
        with torch.inference_mode():
            outputs = self.model(
                input_ids=encoding.input_ids.to(device),
                bbox=encoding.bbox.to(device),
                attention_mask=encoding.attention_mask.to(device),
                token_type_ids=encoding.token_type_ids.to(device),
            )
            predictions = outputs.logits.softmax(-1)

        # procesar salida
        result = []
        for item, inp_ids, bbox in zip(
                predictions.squeeze(0).cpu(), encoding.input_ids.squeeze(0).cpu(), encoding.bbox.squeeze(0).cpu()
        ):
            label = self.model.config.id2label[int(item.argmax().cpu())]
            if label == "O":
                continue
            score = item.max().item()
            text = self.processor.tokenizer.decode(inp_ids)
            bbox = unnormalize_box(bbox.tolist(), image.width, image.height)
            result.append({"label": label, "score": score, "text": text, "bbox": bbox})
        return {"predictions": result}

Enviar petición HTTP utilizando Python

import json
import requests as r
import mimetypes

ENDPOINT_URL=""  # url del punto final
HF_TOKEN=""  # token de la organización donde desplegó su punto final

def predict(path_to_image: str = None):
    with open(path_to_image, "rb") as i:
        b = i.read()
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": mimetypes.guess_type(path_to_image)[0]
    }
    response = r.post(ENDPOINT_URL, headers=headers, data=b)
    return response.json()

prediction = predict(path_to_image="path_to_your_image.png")

print(prediction)
# {'predictions': [{'label': 'I-ANSWER', 'score': 0.4823932945728302, 'text': '[CLS]', 'bbox': [0.0, 0.0, 0.0, 0.0]}, {'label': 'B-HEADER', 'score': 0.992474377155304, 'text': 'your', 'bbox': [1712.529, 181.203, 1859.949, 228.88799999999998]}

Dibujar el resultado en la imagen

from PIL import Image, ImageDraw, ImageFont

# dibujar resultados en la imagen
 def draw_result(path_to_image, result):
    image = Image.open(path_to_image)
    label2color = {
        "B-HEADER": "blue",
        "B-QUESTION": "red",
        "B-ANSWER": "green",
        "I-HEADER": "blue",
        "I-QUESTION": "red",
        "I-ANSWER": "green",
    }
    # dibujar predicciones sobre la imagen
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for res in result:
        draw.rectangle(res["bbox"], outline="black")
        draw.rectangle(res["bbox"], outline=label2color[res["label"]])
        draw.text((res["bbox"][0] + 10, res["bbox"][1] - 10), text=res["label"], fill=label2color[res["label"]], font=font)
    return image

draw_result("path_to_your_image.png", prediction["predictions"])

Funcionalidades

Detección de objetos
Transformers
PyTorch
TensorBoard
Clasificación de tokens

Casos de uso

Clasificación de tokens
Detección de preguntas y respuestas en documentos
Reconocimiento de encabezados en formularios