diffusers-generation-text-box

gligen
Texto a imagen

Stable Diffusion es un modelo de difusión de texto a imagen latente capaz de generar imágenes fotorrealistas a partir de cualquier entrada de texto. Incluye una combinación de un texto codificador fijo pre-entrenado (CLIP ViT-L/14) y un modelo de difusión entrenado en el espacio latente del codificador. Este modelo se ha afinado en 225,000 pasos a una resolución de 512x512 en 'laion-aesthetics v2 5+' con un 10% de eliminación de la condicionante de texto para mejorar el muestreo de guía sin clasificador.

Como usar

Se recomienda usar la biblioteca Diffusers de 🤗 para ejecutar Stable Diffusion.

PyTorch

pip install --upgrade diffusers transformers scipy

import torch
from diffusers import StableDiffusionPipeline

model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to(device)

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]  
    
image.save("astronaut_rides_horse.png")

Si tienes limitada la memoria de la GPU y tienes menos de 4GB de RAM de GPU disponible, asegúrate de cargar el StableDiffusionPipeline en precisión de float16 en lugar de la precisión por defecto float32 como se hace arriba.

import torch

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to(device)
pipe.enable_attention_slicing()

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]  
    
image.save("astronaut_rides_horse.png")

Para cambiar el planificador de ruido, pásalo a from_pretrained:

from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

model_id = "CompVis/stable-diffusion-v1-4"

# Usa el planificador Euler aquí en su lugar
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]  
    
image.save("astronaut_rides_horse.png")

JAX/Flax

Para utilizar StableDiffusion en TPUs y GPUs para una inferencia más rápida puedes aprovechar JAX/Flax. Ejecutando el pipeline con el PNDMScheduler por defecto

import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="flax", dtype=jax.numpy.bfloat16
)

prompt = "a photo of an astronaut riding a horse on mars"

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

Si tienes limitada la memoria de la TPU, asegúrate de cargar el FlaxStableDiffusionPipeline en precisión bfloat16 en lugar de la precisión por defecto float32 como se hace arriba.

import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jax.numpy.bfloat16
)

prompt = "a photo of an astronaut riding a horse on mars"

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

Funcionalidades

Generación de imágenes fotorrealistas a partir de texto
Modelo basado en difusión latente
Usa el codificador de texto CLIP ViT-L/14
Implementado principalmente con la biblioteca Diffusers
Chequeador de seguridad para revisar conceptos NSFW
Múltiples checkpoints durante el entrenamiento

Casos de uso

Implementación segura de modelos que tienen el potencial de generar contenido dañino.
Prueba y comprensión de las limitaciones y sesgos de los modelos generativos.
Generación de obras de arte y uso en procesos de diseño y otras prácticas artísticas.
Aplicaciones en herramientas educativas o creativas.
Investigación sobre modelos generativos.