AC/MiniLM-L12-H384-uncased_Nvidia-Aegis-AI-Safety


A microsoft/MiniLM-L12-H384-uncased model fine-tuned on the nvidia/Aegis-AI-Content-Safety-Dataset-1.0 dataset. It is a multi-label text classifier trained on 3099 examples and covering 14 categories.

How to use

How to get started with the model

from accelerate import Accelerator
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import numpy as np
import torch

accelerator = Accelerator()
device = accelerator.device

def load_model(model_path, accelerator_device=None):
    # Load the fine-tuned checkpoint and its tokenizer.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        problem_type="multi_label_classification",
    )
    if accelerator_device:
        model.to(accelerator_device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer

def predict(model, tokenizer, text, accelerator_device=None, threshold=0.5):
    inputs = tokenizer([text], return_tensors="pt", truncation=True)
    if accelerator_device:
        inputs = inputs.to(accelerator_device)
    with torch.no_grad():  # inference only: no gradient tracking needed
        outputs = model(**inputs)
    # One independent sigmoid per label (multi-label), thresholded at 0.5.
    probs = torch.sigmoid(outputs.logits.squeeze().cpu())
    predictions = np.zeros(probs.shape)
    predictions[np.where(probs >= threshold)] = 1
    return [model.config.id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]

# Using CPU
hf_model, tokenizer = load_model("AC/MiniLM-L12-H384-uncased_Nvidia-Aegis-AI-Safety")
predict(hf_model, tokenizer, "How to make a bomb?")

# Using GPU
hf_model, tokenizer = load_model("AC/MiniLM-L12-H384-uncased_Nvidia-Aegis-AI-Safety", device)
predict(hf_model, tokenizer, "How to make a bomb?", device)
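
The predict helper scores one text at a time. For many texts, batching avoids per-call overhead; below is a minimal sketch, reusing the model and tokenizer loaded above (the predict_batch helper is illustrative and not part of the published card):

def predict_batch(model, tokenizer, texts, accelerator_device=None, threshold=0.5):
    # Tokenize all texts together, padding so they form one tensor batch.
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    if accelerator_device:
        inputs = inputs.to(accelerator_device)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.sigmoid(outputs.logits.cpu())  # shape: (len(texts), 14)
    results = []
    for row in probs:
        results.append([model.config.id2label[idx]
                        for idx, p in enumerate(row) if p >= threshold])
    return results

predict_batch(hf_model, tokenizer, ["How to make a bomb?", "What a nice day!"])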

Fine-tuning

from accelerate import Accelerator
from datasets import Dataset, DatasetDict
from datetime import datetime
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments, Trainer, DataCollatorWithPadding
import numpy as np
import torch
import pandas as pd
import evaluate

accelerator = Accelerator()
device = accelerator.device

def load_model(model_path, accelerator_device):
    # Note: relies on all_labels / id2label / label2id, which are defined
    # below before this function is first called.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        problem_type="multi_label_classification",
        num_labels=len(all_labels),
        id2label=id2label,
        label2id=label2id
    )
    model.to(accelerator_device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer

def predict(model, tokenizer, text, threshold=0.5):
    inputs = tokenizer([text], return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # One independent sigmoid per label, thresholded at 0.5.
    probs = torch.sigmoid(outputs.logits.squeeze().cpu())
    predictions = np.zeros(probs.shape)
    predictions[np.where(probs >= threshold)] = 1
    return [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]

def tokenize_text(examples):
    # Build a multi-hot label vector from the per-category 0/1 columns.
    final_labels = np.zeros(len(all_labels))
    for idx, label in enumerate(all_labels):
        final_labels[idx] = examples[label]
    examples["labels"] = final_labels
    return tokenizer(examples["text"], truncation=True, max_length=512)

### Data preprocessing

all_labels = [
    'Controlled/Regulated Substances',
    'Criminal Planning/Confessions',
    'Deception/Fraud',
    'Guns and Illegal Weapons',
    'Harassment',
    'Hate/Identity Hate',
    'Needs Caution',
    'PII/Privacy',
    'Profanity',
    'Sexual',
    'Sexual (minor)',
    'Suicide and Self Harm',
    'Threat',
    'Violence'
]

id2label = {idx:label for idx, label in enumerate(all_labels)}
label2id = {label:idx for idx, label in enumerate(all_labels)}

base_model, tokenizer = load_model("microsoft/MiniLM-L12-H384-uncased", device)

dataset = DatasetDict({
    'train': Dataset.from_pandas(pd.read_csv("nvidia_train.csv")),
    'test': Dataset.from_pandas(pd.read_csv("nvidia_test.csv"))
})

preprocessed_dataset = dataset.map(tokenize_text)
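
The nvidia_train.csv and nvidia_test.csv files are assumed to contain a "text" column plus one 0/1 column per entry in all_labels, which is exactly what tokenize_text reads. A quick way to sanity-check that layout (the two rows below are made up for illustration):

# Hypothetical two-row frame in the expected layout: a "text" column plus
# one binary column per category in all_labels.
toy = pd.DataFrame({
    "text": ["example of a harmless sentence", "example of a threatening sentence"],
    **{label: [0, 1 if label == "Threat" else 0] for label in all_labels},
})
toy_dataset = Dataset.from_pandas(toy).map(tokenize_text)
print(toy_dataset[0]["labels"])  # multi-hot vector of length 14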

### Multi-label classification metrics

clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

def sigmoid(x):
    return 1/(1 + np.exp(-x))

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = sigmoid(predictions)
    # Flatten (batch, num_labels) so every (example, label) pair is scored
    # as one binary decision, i.e. micro-averaged metrics.
    predictions = (predictions > 0.5).astype(int).reshape(-1)
    return clf_metrics.compute(predictions=predictions, references=labels.astype(int).reshape(-1))
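
The reshape(-1) is what makes these micro-averaged metrics: each (example, label) slot becomes one independent binary decision. A toy check with made-up logits:

# Two examples x 14 labels of fake logits. sigmoid(0) = 0.5, which does not
# clear the > 0.5 threshold, so no label fires; with all-zero references,
# every one of the 28 flattened decisions is correct.
toy_logits = np.zeros((2, 14))
toy_refs = np.zeros((2, 14))
preds = (sigmoid(toy_logits) > 0.5).astype(int).reshape(-1)
print(preds.shape)                             # (28,)
print((preds == toy_refs.reshape(-1)).mean())  # 1.0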

### Fine-tuning

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)  # pad dynamically per batch

output_dir = f'./minilm_finetuned/minilm-{datetime.now().strftime("%d-%m-%Y_%H-%M")}' 

final_output_dir = './minilm_finetuned' 

training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=2e-5,
    per_device_train_batch_size=3,
    per_device_eval_batch_size=3,
    num_train_epochs=20,
    weight_decay=0.01,
    fp16=True,  # mixed precision; requires a CUDA GPU, remove for CPU-only training
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=base_model,
    args=training_args,
    train_dataset=preprocessed_dataset["train"],
    eval_dataset=preprocessed_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

print("Saving model...")
trainer.save_model(final_output_dir)

### Evaluate the model
base_model, tokenizer = load_model(final_output_dir, device)
predict(base_model, tokenizer, "How to make a bomb?")
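
Single prompts are useful spot checks; for aggregate numbers on the test split, the standard Trainer evaluate call can be reused (a sketch, assuming the trainer object from above is still in scope):

# Returns eval_loss plus the accuracy/f1/precision/recall from compute_metrics.
metrics = trainer.evaluate()
print(metrics)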

Features

- Multi-label text classification
- 14 categories: Controlled/Regulated Substances, Criminal Planning/Confessions, Deception/Fraud, Guns and Illegal Weapons, Harassment, Hate/Identity Hate, Needs Caution, PII/Privacy, Profanity, Sexual, Sexual (minor), Suicide and Self Harm, Threat, Violence
- Built for AI content safety
- Evaluated with accuracy, precision, recall, and F1 score
- Inference supported on both CPU and GPU

Use cases

- Content filtering on social networks
- Comment moderation
- Monitoring content on live-streaming platforms
- Detecting harmful and toxic content in messaging apps
