InternVideo2-Chat-8B
OpenGVLab
Clasificación de video
Para enriquecer aún más la semántica incrustada en InternVideo2 y mejorar su usabilidad en las comunicaciones humanas, afinamos InternVideo2 incorporándolo en un VideoLLM con un LLM y un video BLIP. Empleamos el esquema de aprendizaje progresivo en VideoChat utilizando InternVideo2 como el codificador de video y entrenamos un video BLIP para comunicarse con LLM de código abierto. Durante el entrenamiento, el codificador de video se actualizará. Las recetas de entrenamiento detalladas están en VideoChat. El BaseLLM de este modelo es Mistral-7B.
Como usar
Cómo usar el modelo
import os
token = os.environ['HF_TOKEN']
import torch
tokenizer = AutoTokenizer.from_pretrained('OpenGVLab/InternVideo2-Chat-8B', trust_remote_code=True, use_fast=False)
from transformers import AutoTokenizer, AutoModel
model = AutoModel.from_pretrained(
'OpenGVLab/InternVideo2-Chat-8B',
torch_dtype=torch.bfloat16,
trust_remote_code=True).cuda()
from decord import VideoReader, cpu
from PIL import Image
import numpy as np
import numpy as np
import decord
from decord import VideoReader, cpu
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms import PILToTensor
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
decord.bridge.set_bridge('torch')
def get_index(num_frames, num_segments):
seg_size = float(num_frames - 1) / num_segments
start = int(seg_size / 2)
offsets = np.array([
start + int(np.round(seg_size * idx)) for idx in range(num_segments)
])
return offsets
def load_video(video_path, num_segments=8, return_msg=False, resolution=224, hd_num=4, padding=False):
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
num_frames = len(vr)
frame_indices = get_index(num_frames, num_segments)
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
transform = transforms.Compose([
transforms.Lambda(lambda x: x.float().div(255.0)),
transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.Normalize(mean, std)
])
frames = vr.get_batch(frame_indices)
frames = frames.permute(0, 3, 1, 2)
frames = transform(frames)
T_, C, H, W = frames.shape
if return_msg:
fps = float(vr.get_avg_fps())
sec = ', '.join([str(round(f / fps, 1)) for f in frame_indices])
msg = f'El video contiene {len(frame_indices)} cuadros muestreados en {sec} segundos.'
return frames, msg
else:
return frames
video_path = 'yoga.mp4'
# muestra uniformemente 8 cuadros del video
video_tensor = load_video(video_path, num_segments=8, return_msg=False)
video_tensor = video_tensor.to(model.device)
chat_history= []
response, chat_history = model.chat(tokenizer, '', 'describe the action step by step.', media_type='video', media_tensor=video_tensor, chat_history= chat_history, return_history=True,generation_config={'do_sample':False})
print(response)
response, chat_history = model.chat(tokenizer, '', 'What is she wearing?', media_type='video', media_tensor=video_tensor, chat_history= chat_history, return_history=True,generation_config={'do_sample':False})
print(response)
Funcionalidades
- Codificador de video
- BLIP de video
- Esquema de aprendizaje progresivo
- Integración con LLM de código abierto
- Transformers
- Safetensors
Casos de uso
- Clasificación de acciones en videos
- Describir el contenido de video paso a paso
- Identificar objetos y personas en videos
- Proveer una descripción detallada de la vestimenta en videos