Stable Diffusion XS, with a sketch-conditioned ControlNet, can generate 2-4 images per second in the browser on a consumer-grade desktop. The total compressed size of all the pieces together (text encoder, controlnet, u-net, and autoencoder) is under 400MB.

As an in-browser demo, here is a clock that downloads all 900MB of ONNX models (about 400MB compressed) and renders the current time as a nice household clock: Diffusion Local Time, extra small

Four pieces: 1 text encoder, 1 controlnet, 1 u-net, 1 autoencoder, 0 schedulers

Stable Diffusion XS operates in one denoising step.

The stock text encoder from Stable Diffusion 1.5, CLIP ViT-L/14, takes the tokens of a prompt and embeds them into a 1x77x768 tensor.

The controlnet takes a control image at a given strength and attaches to the existing u-net to steer the denoising process.
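To make "attaches to the existing u-net" concrete, here is a minimal sketch assuming the 🤗Diffusers `ControlNetModel` and `UNet2DConditionModel` call conventions (the keyword names below are the Diffusers ones; the function name `controlled_unet_step` is hypothetical). The controlnet produces residuals for the u-net's skip connections and middle block, each multiplied by `conditioning_scale`, and the u-net adds them to its own activations.

```python
def controlled_unet_step(unet, controlnet, latents, timestep, text_embeddings,
                         control_image, strength=1.0):
    """One u-net call steered by a controlnet (Diffusers-style modules assumed).

    With return_dict=False, ControlNetModel returns
    (down_block_res_samples, mid_block_res_sample) and UNet2DConditionModel
    returns a 1-tuple holding its output.
    """
    # The controlnet sees the same latents, timestep, and text embedding as the
    # u-net, plus the control image (e.g. a sketch). conditioning_scale
    # multiplies every residual it emits: this is the "strength" knob.
    down_residuals, mid_residual = controlnet(
        latents,
        timestep,
        encoder_hidden_states=text_embeddings,
        controlnet_cond=control_image,
        conditioning_scale=strength,
        return_dict=False,
    )
    # The u-net adds the residuals to its skip connections and middle block.
    (unet_output,) = unet(
        latents,
        timestep,
        encoder_hidden_states=text_embeddings,
        down_block_additional_residuals=down_residuals,
        mid_block_additional_residual=mid_residual,
        return_dict=False,
    )
    return unet_output
```

Because the function only relies on these call signatures, it can be exercised with stub callables before any real weights are loaded.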

The u-net denoises pure, normally distributed noise into a latent-space image in a single step.

The autoencoder decodes from latent space to pixel space.

The scheduler can be fiddly, especially to serialize into existing formats. Because the model runs a single step, and the latents it starts from are always pure noise, the scheduler's role of providing a time embedding for the latents at a given timestep can be compiled away.
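A minimal sketch of what compiling the scheduler away can look like before export: wrap the u-net so that its single timestep is a constant buffer instead of a `forward()` input. The wrapper name `FixedTimestepUNet`, the timestep value 999, the 512x512 output size, and the `ToyUNet` stand-in are all assumptions for illustration; the call follows Diffusers' `UNet2DConditionModel` convention, where `return_dict=False` returns a tuple.

```python
import torch
from torch import nn


class FixedTimestepUNet(nn.Module):
    """Wrap a Diffusers-style u-net so its one timestep is baked into the module."""

    def __init__(self, unet: nn.Module, timestep: int = 999):
        super().__init__()
        self.unet = unet
        # Hypothetical default: the last step of a 1000-step training schedule.
        # As a buffer it travels with the module (state_dict, export) instead of
        # being something the caller has to pass in.
        self.register_buffer("timestep", torch.tensor([timestep], dtype=torch.long))

    def forward(self, latents: torch.Tensor, text_embeddings: torch.Tensor) -> torch.Tensor:
        return self.unet(
            latents,
            self.timestep,
            encoder_hidden_states=text_embeddings,
            return_dict=False,
        )[0]


class ToyUNet(nn.Module):
    """Stand-in with the same call signature; not a real denoiser."""

    def forward(self, sample, timestep, encoder_hidden_states, return_dict=True):
        # Always returns a tuple, mimicking return_dict=False.
        out = sample * 0.0 + timestep.float() / 1000.0 + encoder_hidden_states.mean()
        return (out,)


# Pure normal noise in SD 1.5's 4-channel latent layout for an assumed 512x512 image.
latents = torch.randn(1, 4, 64, 64)
text_embeddings = torch.zeros(1, 77, 768)
denoised = FixedTimestepUNet(ToyUNet())(latents, text_embeddings)
```

The exported graph then only needs latents and text embeddings as inputs; any fixed arithmetic the scheduler would apply at that one timestep could be folded into `forward()` in the same way.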

Hotpatching 🤗Diffusers

PyTorch continues to innovate in serialization and export formats. This project uses the existing TorchScript-based ONNX exporter, and whenever the exporter trips up on the mostly unmodified Diffusers pipeline, the fix has generally been a minimal hotpatch. These hotpatches are collected at lsb/diffusers@sdxs-hotpatches.

5.7-bit floating point

There are lots of ways to represent numbers! As of 2024, ONNX in the browser can accelerate 32-bit and 16-bit floats via WebGPU.

Different modules in the Stable Diffusion pipeline are differently resilient to the error introduced by quantizing stock 32-bit floats to a lower bit depth.

The text encoder is highly compressible, and the controlnet contains a u-net of similar size to the stock Stable Diffusion XS u-net. Both can be compressed heavily without much loss of image quality. PyTorch has two 8-bit floating point formats: float8_e4m3fn, with three bits of mantissa and a smallest positive (subnormal) value of 2^-9, and float8_e5m2, with two bits of mantissa and a smallest positive value of 2^-16. By round-tripping our weights through these formats and back to 16-bit floats, we keep the existing fast kernels that operate on 16-bit floats. Since each round-tripped tensor holds at most 256 distinct values, general-purpose compression still delivers the space savings at load time and in storage.

ONNX currently requires static weights for most of its layers, which makes it impractical to shrink weights with something like palettization (a small lookup table of shared values, indexed per weight), as is possible in CoreML.

The u-net is more sensitive to reductions in bit depth, but we can comfortably reduce the weights to 8-bit floats without much loss in quality.

The autoencoder is more sensitive still and has only 4 million parameters, so we only reduce its weights to fp16.

Comparing the size of the compressed fp16 model (435MB) with the full-size model (1160MB), the entropy suggests that this is effectively a 5.7-bit model.
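One way to put a number on "effective bits per weight" is to divide compressed bits by parameter count. The sketch below measures that with zlib (a stand-in; which compressor produced the sizes above isn't stated) alongside the empirical Shannon entropy of the weight values, i.e. the average bits per value an ideal coder treating each value independently would need. Function names and the example data are hypothetical.

```python
import zlib

import numpy as np


def shannon_entropy_bits(values: np.ndarray) -> float:
    """Empirical entropy of the distinct values in an array, in bits per value."""
    _, counts = np.unique(values, return_counts=True)
    p = counts / counts.sum()
    return float(-(p * np.log2(p)).sum())


def compressed_bits_per_weight(weights: np.ndarray, level: int = 9) -> float:
    """Effective storage cost: 8 * compressed bytes / number of weights."""
    compressed = zlib.compress(weights.tobytes(), level)
    return 8 * len(compressed) / weights.size


# Example: fp16 weights snapped to a coarse grid have far fewer distinct values,
# so their entropy is lower, and a general-purpose compressor typically does
# much better on them too.
rng = np.random.default_rng(0)
weights = (rng.standard_normal(100_000) * 0.05).astype(np.float16)
coarse = (np.round(weights / 2**-8) * 2**-8).astype(np.float16)
print(shannon_entropy_bits(weights), compressed_bits_per_weight(weights))
print(shannon_entropy_bits(coarse), compressed_bits_per_weight(coarse))
```

Summed over every tensor in a model, `8 * total_compressed_bytes / total_parameters` gives a single effective bit depth.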

We can compress the linear layers of the fp16 model further without decreasing image quality, using off-the-shelf linear quantization to reduce the bit depth of linear weights from fp16 to uint8. Empirically, this is slower than full fp16 on WebGPU on some consumer desktops, but WebGPU is still rolling out, and WASM, which runs most of the ONNX spec at high speed, is highly effective as well. The size reported at the beginning is for this mixed 8-bit/16-bit model.
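ONNX Runtime's quantization tooling (for example `onnxruntime.quantization.quantize_dynamic`) applies this kind of mapping to MatMul weights, but the exact tool and settings used here aren't stated. The numpy sketch below only shows the arithmetic of per-tensor uint8 linear quantization, following the formulas of ONNX's `QuantizeLinear`/`DequantizeLinear` operators; the function names and the min/max range choice are assumptions.

```python
import numpy as np


def quantize_linear_uint8(w: np.ndarray):
    """Per-tensor asymmetric uint8 quantization: q = clip(round(w / scale) + zero_point)."""
    w = w.astype(np.float32)
    # Include 0 in the range so that 0.0 maps exactly onto the zero point.
    rmin = min(float(w.min()), 0.0)
    rmax = max(float(w.max()), 0.0)
    scale = (rmax - rmin) / 255.0
    if scale == 0.0:  # all-zero tensor: any positive scale works
        scale = 1.0
    zero_point = int(np.clip(np.round(-rmin / scale), 0, 255))
    # np.round rounds half to even, matching QuantizeLinear's rounding mode.
    q = np.clip(np.round(w / scale) + zero_point, 0, 255).astype(np.uint8)
    return q, np.float32(scale), np.uint8(zero_point)


def dequantize_linear(q: np.ndarray, scale, zero_point) -> np.ndarray:
    """Inverse mapping, as in DequantizeLinear: (q - zero_point) * scale."""
    return (q.astype(np.float32) - np.float32(zero_point)) * scale


rng = np.random.default_rng(0)
w = (rng.standard_normal((256, 256)) * 0.05).astype(np.float16)
q, scale, zp = quantize_linear_uint8(w)
w_hat = dequantize_linear(q, scale, zp)
```

Each weight is reconstructed to within half a quantization step, and the stored tensor is half the size of fp16 before any compression.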

Text encoder optional

Part of what makes the controlnet interesting is that it controls image synthesis through visual rather than textual means. For narrowly targeted use cases, this might mean generating the same prompt with different control images, as a form of style transfer.

By pre-generating the prompt embeddings for a well-defined set of prompts, it is possible to ship to production without any text encoder. This shrinks the footprint and narrows the range of possible generations, which improves safety and debuggability.
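A minimal sketch of the precompute-then-lookup idea in numpy. `encode` is a hypothetical stand-in for a call into the CLIP text encoder that returns a 1x77x768 array; storing the table as fp16 in a single `.npz` file is likewise an assumption. At run time, any prompt outside the precomputed set is rejected rather than encoded.

```python
from typing import Callable, Dict, Iterable

import numpy as np

# Shape of SD 1.5's CLIP ViT-L/14 prompt embedding.
EMBED_SHAPE = (1, 77, 768)


def precompute_embeddings(
    prompts: Iterable[str], encode: Callable[[str], np.ndarray]
) -> Dict[str, np.ndarray]:
    """Offline step: run the (hypothetical) text encoder once per allowed prompt."""
    table = {}
    for prompt in prompts:
        emb = np.asarray(encode(prompt), dtype=np.float16)  # fp16 storage assumed
        if emb.shape != EMBED_SHAPE:
            raise ValueError(f"unexpected embedding shape {emb.shape} for {prompt!r}")
        table[prompt] = emb
    return table


def save_table(table: Dict[str, np.ndarray], path: str) -> None:
    """Store prompts and embeddings as parallel arrays in one .npz file."""
    prompts = list(table)
    np.savez_compressed(
        path,
        prompts=np.array(prompts),
        embeddings=np.stack([table[p] for p in prompts]),
    )


def load_table(path: str) -> Dict[str, np.ndarray]:
    with np.load(path) as data:
        return {str(p): e for p, e in zip(data["prompts"], data["embeddings"])}


def lookup(table: Dict[str, np.ndarray], prompt: str) -> np.ndarray:
    """Runtime step: only prompts from the precomputed set are accepted."""
    if prompt not in table:
        raise KeyError(f"prompt not in the precomputed set: {prompt!r}")
    return table[prompt]
```

Each entry costs about 118KB at fp16 (77 × 768 × 2 bytes), so even a few hundred prompts stay far smaller than a text encoder.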

Without the text encoder, the controlnet, u-net, and VAE together are 743MB uncompressed and only 337MB compressed, with mixed 16-bit convolutional and 8-bit linear layers. A PyTorch model with manually applied post-training palettization achieved similar quality at under 250MB compressed.
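The post doesn't say how that palettization was done. As one common approach, the numpy sketch below uses 1-D k-means (Lloyd's algorithm) to cluster each tensor into a lookup table of `2**bits` shared values plus a per-weight index. The bit width, iteration count, and quantile initialization are assumptions; indices are kept as uint8 for simplicity, so realizing the savings for widths below 8 bits would also require bit-packing.

```python
import numpy as np


def _nearest(flat: np.ndarray, sorted_lut: np.ndarray) -> np.ndarray:
    """Index of the nearest LUT entry for each value (LUT must be sorted)."""
    boundaries = (sorted_lut[1:] + sorted_lut[:-1]) / 2
    return np.searchsorted(boundaries, flat)


def palettize(weights: np.ndarray, bits: int = 6, iters: int = 20):
    """Cluster weights into 2**bits shared values; weights ≈ lut[indices]."""
    if not 1 <= bits <= 8:
        raise ValueError("bits must be between 1 and 8 to fit indices in uint8")
    flat = weights.astype(np.float32).ravel()
    k = 2 ** bits
    # Deterministic start: centroids at evenly spaced quantiles of the data.
    lut = np.quantile(flat, np.linspace(0.0, 1.0, k)).astype(np.float32)
    for _ in range(iters):
        idx = _nearest(flat, lut)
        # Move each non-empty cluster's centroid to the mean of its members.
        sums = np.bincount(idx, weights=flat, minlength=k)
        counts = np.bincount(idx, minlength=k)
        filled = counts > 0
        lut[filled] = sums[filled] / counts[filled]
        lut.sort()  # keep the LUT sorted for the boundary-based assignment
    indices = _nearest(flat, lut).astype(np.uint8).reshape(weights.shape)
    return lut.astype(np.float16), indices


rng = np.random.default_rng(0)
w = (rng.standard_normal((128, 128)) * 0.05).astype(np.float16)
lut, idx = palettize(w, bits=4)
reconstructed = lut[idx]
```

As noted above, ONNX can't consume a lookup-table representation like this directly, which is why this route was only taken on the PyTorch side.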

Let me know what you think!

The source is available on GitHub, the models are available on GitHub and Hugging Face, and adding this functionality to Transformers.js is a work in progress! Give me a shout.