diffusers-generation-text-box
Stable Diffusion es un modelo de difusión de texto a imagen latente capaz de generar imágenes fotorrealistas a partir de cualquier entrada de texto. Incluye una combinación de un texto codificador fijo pre-entrenado (CLIP ViT-L/14) y un modelo de difusión entrenado en el espacio latente del codificador. Este modelo se ha afinado en 225,000 pasos a una resolución de 512x512 en 'laion-aesthetics v2 5+' con un 10% de eliminación de la condicionante de texto para mejorar el muestreo de guía sin clasificador.
Como usar
Se recomienda usar la biblioteca Diffusers de 🤗 para ejecutar Stable Diffusion.
PyTorch
pip install --upgrade diffusers transformers scipy
import torch
from diffusers import StableDiffusionPipeline
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to(device)
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
Si tienes limitada la memoria de la GPU y tienes menos de 4GB de RAM de GPU disponible, asegúrate de cargar el StableDiffusionPipeline en precisión de float16 en lugar de la precisión por defecto float32 como se hace arriba.
import torch
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to(device)
pipe.enable_attention_slicing()
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
Para cambiar el planificador de ruido, pásalo a from_pretrained:
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
model_id = "CompVis/stable-diffusion-v1-4"
# Usa el planificador Euler aquí en su lugar
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
JAX/Flax
Para utilizar StableDiffusion en TPUs y GPUs para una inferencia más rápida puedes aprovechar JAX/Flax.
Ejecutando el pipeline con el PNDMScheduler por defecto
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="flax", dtype=jax.numpy.bfloat16
)
prompt = "a photo of an astronaut riding a horse on mars"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
Si tienes limitada la memoria de la TPU, asegúrate de cargar el FlaxStableDiffusionPipeline en precisión bfloat16 en lugar de la precisión por defecto float32 como se hace arriba.
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jax.numpy.bfloat16
)
prompt = "a photo of an astronaut riding a horse on mars"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
Funcionalidades
- Generación de imágenes fotorrealistas a partir de texto
- Modelo basado en difusión latente
- Usa el codificador de texto CLIP ViT-L/14
- Implementado principalmente con la biblioteca Diffusers
- Chequeador de seguridad para revisar conceptos NSFW
- Múltiples checkpoints durante el entrenamiento
Casos de uso
- Implementación segura de modelos que tienen el potencial de generar contenido dañino.
- Prueba y comprensión de las limitaciones y sesgos de los modelos generativos.
- Generación de obras de arte y uso en procesos de diseño y otras prácticas artísticas.
- Aplicaciones en herramientas educativas o creativas.
- Investigación sobre modelos generativos.