Make-A-Video SD JAX
Un modelo de difusión latente para la síntesis de texto a video. Pruébalo con una demostración interactiva en los espacios de HuggingFace. El código de entrenamiento, la implementación en PyTorch y FLAX están disponibles aquí: https://github.com/lopho/makeavid-sd-tpu. Este modelo amplía un modelo de generación de imágenes por difusión latente (Stable Diffusion v1.5 Inpaint) con convolución temporal y autoatención temporal portados de Make-A-Video PyTorch. Luego ha sido afinado durante ~150k pasos en un conjunto de datos de 10,000 videos temáticos sobre baile. Luego, durante ~50k pasos adicionales con datos extra de videos genéricos mezclados en el conjunto original. Este modelo utilizó pesos preentrenados por lxj616 en 286 clips de video en time-lapse para la inicialización.
Como usar
Procedimiento de entrenamiento:
1. De cada muestra de video se selecciona un rango aleatorio de 24 cuadros.
2. Cada latente de video se codifica en representaciones latentes con forma 4 x 24 x H/8 x W/8.
3. El latente del primer cuadro de cada video se repite a lo largo de la dimensión de los cuadros como guía adicional (referido como imagen de pista).
4. Los latentes de pista y de video se apilan para producir una forma de 8 x 24 x H/8 x W/8.
5. El último canal de entrada se conserva para propósitos de enmascaramiento (no se usa durante el entrenamiento, se establece en cero).
6. Los indicaciones de texto se codifican mediante el codificador de texto CLIP.
7. Los latentes de video con ruido añadido y los textos codificados por CLIP se introducen en el UNet para predecir el ruido añadido.
8. La pérdida es el objetivo de reconstrucción entre el ruido añadido y el ruido predicho mediante el error cuadrático medio (mse/l2).
Hiperparámetros:
- Tamaño del lote: 1 x 4
- Tamaño de la imagen: 512 x 512
- Conteo de cuadros: 24
- Optimizador: AdamW (beta_1 = 0.9, beta_2 = 0.999, decaimiento de peso = 0.02)
- Plan de LR:
- 2 x 10 épocas: LR warmup durante 1 época luego constante en 5e-5 (10,000 muestras por época)
- 2 x 20 épocas: LR warmup durante 1 época luego constante en 5e-5 (10,000 muestras por época)
- 1 x 9 épocas: LR warmup durante 1 época a 5e-5 luego enfriamiento cosenoidal a 1e-8
- Datos adicionales mezclados, ver Datos de Entrenamiento
- 1 x 5 épocas: LR warmup durante 0.5 épocas a 2.5e-5 luego constante (17,000 muestras por época)
- 1 x 5 épocas: LR warmup durante 0.5 épocas a 5e-6 luego enfriamiento cosenoidal a 2.5e-6 (17,000 muestras por época)
- Algunos reinicios fueron necesarios debido a la aparición de NaNs en el gradiente.
Hardware:
- TPUv4-8 proporcionado por Google Cloud para el evento HuggingFace JAX/Diffusers Sprint.
Funcionalidades
- Implementación en PyTorch y FLAX
- Entrenado con 10,000 videos temáticos sobre baile
- Afinado adicionalmente con 7,000 videos genéricos
- Codificación de texto mediante el codificador CLIP
- Predicción del ruido añadido mediante UNet
- Objetivo de reconstrucción mediante error cuadrático medio (mse/l2)
- IPU TPUv4-8 proporcionado por Google Cloud
Casos de uso
- Entendimiento de las limitaciones y sesgos de los modelos generativos de video
- Desarrollo de herramientas educativas o creativas
- Uso artístico
- Cualquier propósito que desees