Siki-77/imdb_roberta_large
Text Classification
This model is a fine-tuned version of roberta-large on the IMDB movie-review dataset (loaded in the training script below via load_dataset('imdb')). It achieves the following results on the evaluation set: Loss: 0.1728, Accuracy: 0.9627.
How to use
# Imports: datasets for IMDB, transformers for the model and Trainer, evaluate for the accuracy metric
import numpy as np
import torch
import evaluate
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

# Load the IMDB movie-review dataset (25,000 train / 25,000 test examples)
imdb = load_dataset('imdb')
model_name = 'roberta-large'
# Label mappings for binary sentiment classification
id2label = {0: 'NEGATIVE', 1: 'POSITIVE'}
label2id = {'NEGATIVE': 0, 'POSITIVE': 1}
def compute_metrics(eval_pred):
    # Convert logits to class predictions and score them with the accuracy metric
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def preprocess_function(examples):
    # Truncate each review to the model's maximum sequence length
    return tokenizer(examples['text'], truncation=True)

tokenized_imdb = imdb.map(preprocess_function, batched=True)

# Pad each batch dynamically instead of padding every example to a fixed max length
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
accuracy = evaluate.load('accuracy')
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=2, id2label=id2label, label2id=label2id
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Effective batch size = 8 * 2 = 16 (per-device batch size x gradient accumulation steps)
bts = 8
accumulated_step = 2
training_args = TrainingArguments(
    output_dir=f"5imdb_{model_name.replace('-', '_')}",  # nested single quotes in the f-string are a SyntaxError on Python < 3.12
    learning_rate=2e-5,
    per_device_train_batch_size=bts,
    per_device_eval_batch_size=bts,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy='epoch',  # renamed to eval_strategy in newer transformers releases
    save_strategy='epoch',
    load_best_model_at_end=True,
    push_to_hub=True,
    gradient_accumulation_steps=accumulated_step,
)
# Stop early if the evaluation loss fails to improve for 3 consecutive evaluations
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_imdb['train'],
    eval_dataset=tokenized_imdb['test'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)
trainer.train()
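Once training finishes, the evaluation metrics reported at the top of this card can be reproduced with trainer.evaluate(). A minimal sketch continuing the script above; exact values will vary with hardware and random seed:

# Score the model on the IMDB test split. With load_best_model_at_end=True this
# evaluates the checkpoint with the lowest evaluation loss. The card reports
# eval_loss = 0.1728 and eval_accuracy = 0.9627; your numbers may differ slightly.
metrics = trainer.evaluate()
print(metrics)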
Features
- Text classification
- Fine-tuned from roberta-large
- Compatible with AutoTrain
- Compatible with Safetensors
- Compatible with Inference Endpoints
Use cases
- Sentiment classification of movie reviews
- Sentiment classification of general text
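For these use cases, the simplest way to run the fine-tuned model is the transformers pipeline API. A minimal inference sketch, assuming the checkpoint is published on the Hub as Siki-77/imdb_roberta_large (the repo id is inferred from this card; point to your local output_dir instead if you trained the model yourself):

from transformers import pipeline

# 'Siki-77/imdb_roberta_large' is the assumed Hub repo id for this model card
classifier = pipeline('text-classification', model='Siki-77/imdb_roberta_large')
print(classifier('This movie was an absolute delight from start to finish.'))
# e.g. [{'label': 'POSITIVE', 'score': 0.99}]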