mdeberta-v3-base-triplet-critic-xnli

Babelscape
Clasificación de texto

Este es el modelo Triplit Critic presentado en el artículo RED^{FM}: un Conjunto de Datos de Extracción de Relaciones Filtradas y Multilingües, presentado en la conferencia ACL 2023. El Triplit Critic está basado en mdeberta-v3-base y fue entrenado como un sistema de multitareas para filtrar tríos así como en el conjunto de datos XNLI. Los pesos del modelo contienen las dos cabezas de clasificación, sin embargo, al cargarlo usando la biblioteca huggingface solo se cargarán los de la clasificación de tríos (es decir, una cabeza de clasificación binaria). Para usarlo en XNLI se necesita un script personalizado. Aunque está definido y entrenado como un sistema de clasificación, usamos la puntuación positiva (es decir, Label_1) como la puntuación de confianza para un trío. Para SREDFM el umbral de la puntuación de confianza se estableció en 0.75.

Como usar

Para cargar el modelo multitarea:

from transformers import DebertaV2PreTrainedModel, DebertaV2Model
from torch import nn
from transformers.models.deberta_v2.modeling_deberta_v2 import *
from transformers.file_utils import ModelOutput

@dataclass
class TXNLIClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    logits_xnli: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

class DebertaV2ForTripletClassification(DebertaV2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        num_labels = getattr(config, "num_labels", 2)
        self.num_labels = num_labels
        self.deberta = DebertaV2Model(config)
        self.pooler = ContextPooler(config)
        output_dim = self.pooler.output_dim
        self.classifier = nn.Linear(output_dim, num_labels)
        drop_out = getattr(config, "cls_dropout", None)
        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
        self.dropout = StableDropout(drop_out)
        self.classifier_xnli = nn.Linear(output_dim, 3)

        self.post_init()

    def get_input_embeddings(self):
        return self.deberta.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        self.deberta.set_input_embeddings(new_embeddings)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.deberta(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        encoder_layer = outputs[0]
        pooled_output = self.pooler(encoder_layer)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        logits_xnli = self.classifier_xnli(pooled_output)
        loss = None
        if labels is not None:
            if labels.dtype != torch.bool:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            else:
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits_xnli.view(-1, 3), labels.view(-1).long())
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return TXNLIClassifierOutput(
            loss=loss, logits=logits, logits_xnli=logits_xnli, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )

Funcionalidades

Basado en mdeberta-v3-base
Sistema multitarea para filtrar tríos y XNLI
Contiene dos cabezas de clasificación
Clasificación binaria para filtrar tríos
Puntuación de confianza establecida en 0.75 para SREDFM

Casos de uso

Filtrado de tríos en tareas de procesamiento de lenguaje natural
Clasificación de texto en múltiples idiomas
Aplicaciones que requieren puntuaciones de confianza para la extracción de relaciones