InternVideo2-Chat-8B

OpenGVLab
Clasificación de video

Para enriquecer aún más la semántica incrustada en InternVideo2 y mejorar su usabilidad en las comunicaciones humanas, afinamos InternVideo2 incorporándolo en un VideoLLM con un LLM y un video BLIP. Empleamos el esquema de aprendizaje progresivo en VideoChat utilizando InternVideo2 como el codificador de video y entrenamos un video BLIP para comunicarse con LLM de código abierto. Durante el entrenamiento, el codificador de video se actualizará. Las recetas de entrenamiento detalladas están en VideoChat. El BaseLLM de este modelo es Mistral-7B.

Como usar

Cómo usar el modelo

import os
token = os.environ['HF_TOKEN']
import torch

tokenizer =  AutoTokenizer.from_pretrained('OpenGVLab/InternVideo2-Chat-8B', trust_remote_code=True, use_fast=False)

from transformers import AutoTokenizer, AutoModel
model = AutoModel.from_pretrained(
'OpenGVLab/InternVideo2-Chat-8B',
torch_dtype=torch.bfloat16,
trust_remote_code=True).cuda()

from decord import VideoReader, cpu
from PIL import Image
import numpy as np
import numpy as np
import decord
from decord import VideoReader, cpu
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms import PILToTensor
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
decord.bridge.set_bridge('torch')

def get_index(num_frames, num_segments):
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    offsets = np.array([
        start + int(np.round(seg_size * idx)) for idx in range(num_segments)
    ])
    return offsets

def load_video(video_path, num_segments=8, return_msg=False, resolution=224, hd_num=4, padding=False):
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    num_frames = len(vr)
    frame_indices = get_index(num_frames, num_segments)
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    transform = transforms.Compose([
        transforms.Lambda(lambda x: x.float().div(255.0)),
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize(mean, std)
    ])
    frames = vr.get_batch(frame_indices)
    frames = frames.permute(0, 3, 1, 2)
    frames = transform(frames)
    T_, C, H, W = frames.shape
    if return_msg:
        fps = float(vr.get_avg_fps())
        sec = ', '.join([str(round(f / fps, 1)) for f in frame_indices])
        msg = f'El video contiene {len(frame_indices)} cuadros muestreados en {sec} segundos.'
        return frames, msg
    else:
        return frames

video_path = 'yoga.mp4'
# muestra uniformemente 8 cuadros del video
video_tensor = load_video(video_path, num_segments=8, return_msg=False)
video_tensor = video_tensor.to(model.device)

chat_history= []
response, chat_history = model.chat(tokenizer, '', 'describe the action step by step.', media_type='video', media_tensor=video_tensor, chat_history= chat_history, return_history=True,generation_config={'do_sample':False})
print(response)

response, chat_history = model.chat(tokenizer, '', 'What is she wearing?', media_type='video', media_tensor=video_tensor, chat_history= chat_history, return_history=True,generation_config={'do_sample':False})
print(response)

Funcionalidades

Codificador de video
BLIP de video
Esquema de aprendizaje progresivo
Integración con LLM de código abierto
Transformers
Safetensors

Casos de uso

Clasificación de acciones en videos
Describir el contenido de video paso a paso
Identificar objetos y personas en videos
Proveer una descripción detallada de la vestimenta en videos