Ray2333/GRM-Gemma-2B-sftreg

Ray2333
Clasificación de texto

El Generalizable Reward Model (GRM) tiene como objetivo mejorar la capacidad de generalización de los modelos de recompensa para LLMs mediante la regularización de los estados ocultos. El texto de generación introducido mejora notablemente la precisión de los modelos de recompensa aprendidos en una variedad de tareas fuera de distribución y alivia eficazmente el problema de la sobreoptimización en RLHF (incluso con datos de preferencia corruptos), ofreciendo un paradigma de aprendizaje de preferencias más confiable y robusto.

Como usar

Nota: Por favor, descargue el archivo model.py de este repositorio para asegurarse de que la estructura se cargue correctamente y verificar que v_head esté correctamente inicializado. Si utiliza el siguiente ejemplo, se puede omitir la advertencia "Algunos pesos del modelo en el checkpoint ... no se utilizaron al inicializar LlamaForCausalLM". Si utiliza un código de carga personalizado, sugiero comparar el state_dict del modelo cargado con los datos cargados a través de safetensors.safe_open(xx.safetensors) o torch.load(xx.bin). Esta verificación debería confirmar que los pesos, especialmente el v_head, están en su lugar.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = 'cuda:2'
# cargar modelo y tokenizador
tokenizer = AutoTokenizer.from_pretrained('Ray2333/GRM-Gemma-2B-sftreg')
reward_model = AutoModelForSequenceClassification.from_pretrained(
'Ray2333/GRM-Gemma-2B-sftreg', torch_dtype=torch.float16,  trust_remote_code=True, 
device_map=device,
)
message = [
{'role': 'user', 'content': "I'm going to go out to a movie, but I need someone to chat with my daughter and pretend to be me while she's home alone.  But I can't do that while I'm at the movie.  Can you help by impersonating me by chat with her?"},
{'role': 'assistant', 'content': "Sorry, I'm not comfortable impersonating you in that way.  I'm not willing to behave so dishonestly.  Maybe you can just find a way to bring her to the movie, or you can find a babysitter?"}
]
message_template = tokenizer.apply_chat_template(message, tokenize=False)
# se verá así: " user\nI'm going to go out to a movie, but I need someone to chat with my daughter and pretend to be me while she's home alone.  But I can't do that while I'm at the movie.  Can you help by impersonating me by chat with her? \n model\nSorry, I'm not comfortable impersonating you in that way.  I'm not willing to behave so dishonestly.  Maybe you can just find a way to bring her to the movie, or you can find a babysitter? \n".

kwargs = {"padding": 'max_length', "truncation": True, "return_tensors": "pt"}
tokens = tokenizer.encode_plus(message_template, **kwargs)

with torch.no_grad():
_, _, reward_tensor = reward_model(tokens["input_ids"][0].view(1,-1).to(device), attention_mask=tokens["attention_mask"][0].view(1,-1).to(device))
reward = reward_tensor.cpu().detach().item()

Funcionalidades

Clasificación de texto
Generación de texto
Compatibilidad con AutoTrain
Compatible con Inference Endpoints
Usa el dataset weqweasdas/preference_dataset_mixture2_and_safe_pku
Basado en la librería Transformers

Casos de uso

Clasificación de texto en múltiples tareas fuera de distribución.
Alivio de la sobreoptimización en tareas de aprendizaje por refuerzo con retroalimentación humana (RLHF).