FactCC
manueldeprada
Clasificación de texto
Modelo de predicción de factualidad FactCC. Este es una implementación moderna del modelo y código del repositorio original de GitHub. Este modelo está entrenado para predecir si un resumen es factual en relación con el texto original.
Como usar
Uso básico:
from transformers import BertForSequenceClassification, BertTokenizer
model_path = 'manueldeprada/FactCC'
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path)
text='''The US has "passed the peak" on new coronavirus cases, the White House reported. They predict that some states would reopen this month.
The US has over 637,000 confirmed Covid-19 cases and over 30,826 deaths, the highest for any country in the world.'''
wrong_summary = '''The pandemic has almost not affected the US'''
input_dict = tokenizer(text, wrong_summary, max_length=512, padding='max_length', truncation='only_first', return_tensors='pt')
logits = model(**input_dict).logits
pred = logits.argmax(dim=1)
model.config.id2label[pred.item()] # prints: INCORRECT
También puede usarse con un pipeline. Ten en cuenta que como los pipelines no están diseñados para ser usados con pares de frases, tienes que usar este truco de doble lista:
from transformers import pipeline
>>> pipe=pipeline(model="manueldeprada/FactCC")
>>> pipe([[[text1,summary1]],[[text2,summary2]]],truncation='only_first',padding='max_length')
# output [{'label': 'INCORRECT', 'score': 0.9979124665260315}, {'label': 'CORRECT', 'score': 0.879124665260315}]
Ejemplo de cómo realizar inferencia en lotes para reproducir los resultados de los autores en el conjunto de prueba:
def batched_FactCC(text_l, summary_l, max_length=512):
input_dict = tokenizer(text_l, summary_l, max_length=max_length, padding='max_length', truncation='only_first', return_tensors='pt')
with torch.no_grad():
logits = model(**input_dict).logits
preds = logits.argmax(dim=1)
return logits, preds
texts = []
claims = []
labels = []
with open('factCC/annotated_data/test/data-dev.jsonl', 'r') as file:
for line in file:
obj = json.loads(line) # Cargar los datos JSON de cada línea
texts.append(obj['text'])
claims.append(obj['claim'])
labels.append(model.config.label2id[o['label']])
preds = []
batch_size = 8
for i in tqdm(range(0, len(texts), batch_size)):
batch_texts = texts[i:i+batch_size]
batch_claims = claims[i:i+batch_size]
_, preds = fact_cc(batch_texts, batch_claims)
preds.extend(preds.tolist())
print(f"F1 micro: {f1_score(labels, preds, average='micro')}")
print(f"Balanced accuracy: {balanced_accuracy_score(labels, preds)}")
Funcionalidades
- Clasificación de texto
- Transformadores
- Compatibilidad con PyTorch
- Compatibilidad con TensorFlow y JAX
- Predicción de factualidad en resúmenes textuales
Casos de uso
- Verificación de hechos en resúmenes textuales
- Evaluación de la consistencia factual en textos generados automáticamente
- Moderación de contenido basado en la factualidad