mdeberta-v3-base-triplet-critic-xnli
Este es el modelo Triplit Critic presentado en el artículo RED^{FM}: un Conjunto de Datos de Extracción de Relaciones Filtradas y Multilingües, presentado en la conferencia ACL 2023. El Triplit Critic está basado en mdeberta-v3-base y fue entrenado como un sistema de multitareas para filtrar tríos así como en el conjunto de datos XNLI. Los pesos del modelo contienen las dos cabezas de clasificación, sin embargo, al cargarlo usando la biblioteca huggingface solo se cargarán los de la clasificación de tríos (es decir, una cabeza de clasificación binaria). Para usarlo en XNLI se necesita un script personalizado. Aunque está definido y entrenado como un sistema de clasificación, usamos la puntuación positiva (es decir, Label_1) como la puntuación de confianza para un trío. Para SREDFM el umbral de la puntuación de confianza se estableció en 0.75.
Como usar
Para cargar el modelo multitarea:
from transformers import DebertaV2PreTrainedModel, DebertaV2Model
from torch import nn
from transformers.models.deberta_v2.modeling_deberta_v2 import *
from transformers.file_utils import ModelOutput
@dataclass
class TXNLIClassifierOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
logits_xnli: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class DebertaV2ForTripletClassification(DebertaV2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
num_labels = getattr(config, "num_labels", 2)
self.num_labels = num_labels
self.deberta = DebertaV2Model(config)
self.pooler = ContextPooler(config)
output_dim = self.pooler.output_dim
self.classifier = nn.Linear(output_dim, num_labels)
drop_out = getattr(config, "cls_dropout", None)
drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
self.dropout = StableDropout(drop_out)
self.classifier_xnli = nn.Linear(output_dim, 3)
self.post_init()
def get_input_embeddings(self):
return self.deberta.get_input_embeddings()
def set_input_embeddings(self, new_embeddings):
self.deberta.set_input_embeddings(new_embeddings)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.deberta(
input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
encoder_layer = outputs[0]
pooled_output = self.pooler(encoder_layer)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
logits_xnli = self.classifier_xnli(pooled_output)
loss = None
if labels is not None:
if labels.dtype != torch.bool:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
else:
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits_xnli.view(-1, 3), labels.view(-1).long())
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TXNLIClassifierOutput(
loss=loss, logits=logits, logits_xnli=logits_xnli, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
Funcionalidades
- Basado en mdeberta-v3-base
- Sistema multitarea para filtrar tríos y XNLI
- Contiene dos cabezas de clasificación
- Clasificación binaria para filtrar tríos
- Puntuación de confianza establecida en 0.75 para SREDFM
Casos de uso
- Filtrado de tríos en tareas de procesamiento de lenguaje natural
- Clasificación de texto en múltiples idiomas
- Aplicaciones que requieren puntuaciones de confianza para la extracción de relaciones