FLUX.2 Klein — How Inference Works


From text prompt to pixels, one component at a time.

Geronimo

You give it a text prompt. You get an image.

                          ┌─────────────────────┐
"a beautiful              │                     │
 mountain landscape" ──▶  │    FLUX.2 Klein     │ ──▶ [image]
                          │                     │
                          └─────────────────────┘


Here’s what actually happens:

prompt
→ Text Encoder → text tokens (B, 512, 7680)
→ Sample noise → noisy latent (B, H/16 × W/16, 128)
→ Denoise loop → clean latent (B, H/16 × W/16, 128)
→ AE Decode → pixels (B, 3, H, W)

I get the big picture, but why does the text encoder always return a 512x7680 tensor? How does the AutoEncoder magically compress an image? If you already know the answer, skip this blog post.

The official code is spread across ten files. We build it from scratch, one component at a time, and explain every input and output.

Text Encoder

The text encoder is Qwen3-4B, an LLM. But it’s not generating anything. We run the prompt through it and steal hidden states from the middle of the network.

Why not just take the final layer output? An LLM is tuned for predicting the next token. Middle layers carry richer semantic content. The final layer has specialized away from meaning and towards next-token prediction. BFL chose layers 9, 18, 27.

Which layers

OUTPUT_LAYERS_QWEN3 = [9, 18, 27]  # Qwen3-4B has 36 layers total

Hidden states at layers 9, 18, 27 — evenly spaced, every 9 layers. Each is 2560-dim. They get concatenated per token:

out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS_QWEN3], dim=1)
return rearrange(out, "b c l d -> b l (c d)")

Output: (batch, 512, 7680), i.e. 512 tokens, each 7680-dim (= 2560 × 3). This maps directly to context_in_dim=7680 in Klein4BParams.
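To make the shape arithmetic concrete, here is a minimal sketch in plain Python (no torch) of the per-token concatenation. The function name concat_layers and the tiny toy dimensions are assumptions for illustration; the real encoder does the same thing on (batch, 512, 2560) tensors with torch.stack and rearrange.

```python
OUTPUT_LAYERS = [9, 18, 27]

def concat_layers(hidden_states, layers):
    """Concatenate the vectors of the chosen layers for every token.

    hidden_states: dict mapping layer index -> list of per-token vectors
    returns: list of per-token vectors, each len(layers) times as long
    """
    num_tokens = len(hidden_states[layers[0]])
    return [
        # layer-major order, matching the "(c d)" grouping in the rearrange pattern
        [value for k in layers for value in hidden_states[k][tok]]
        for tok in range(num_tokens)
    ]

# toy example: 4 tokens, 2-dim hidden states (real model: 512 tokens, 2560-dim)
toy = {k: [[float(k), float(tok)] for tok in range(4)] for k in OUTPUT_LAYERS}
ctx = concat_layers(toy, OUTPUT_LAYERS)

print(len(ctx), len(ctx[0]))        # tokens, features per token
print(2560 * len(OUTPUT_LAYERS))    # feature size in the real model
```

Each token keeps its position; only the feature dimension grows by the number of tapped layers.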

The chat template gotcha

The prompt isn’t tokenized raw. It’s wrapped in a chat template first:

messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, enable_thinking=False)
# → "<|im_start|>user\na cat<|im_end|>\n<|im_start|>assistant\n"

So the first few tokens are always the same boilerplate, <|im_start|>user\n, regardless of what you typed. Running “hello” and “hello this is a text” through the encoder, tokens 0–3 are identical: tokens 0–2 are the template, and token 3 presumably matches because both prompts start with “hello” and a causal LLM only looks backwards. Token 4 is where the two runs diverge.

token 0: [-768.0, 6.09, -25.75, …] # <|im_start|>
token 1: [ 34.25, 2.39, -4.56, …] # "user"
token 2: [ 29.25, 2.98, -2.06, …] # "\n"
token 3: [ 21.0, 0.31, -1.05, …] # same in both runs
token 4: diverges # the two prompts differ from here on

Fixed length: always 512

tokenizer(text, padding="max_length", truncation=True, max_length=512)

Short prompts get padded. Long ones get truncated. The transformer downstream always sees (B, 512, 7680).
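What the tokenizer does with those arguments can be sketched in a few lines of plain Python. This is an illustration only: pad_or_truncate is a hypothetical helper, pad_id=0 is a placeholder rather than the real pad token id, and padding on the right is an assumption.

```python
def pad_or_truncate(token_ids, max_length=512, pad_id=0):
    """Return exactly max_length ids plus a mask (1 = real token, 0 = padding)."""
    kept = token_ids[:max_length]            # truncate long prompts
    num_pad = max_length - len(kept)         # 0 if the prompt was long enough
    ids = kept + [pad_id] * num_pad          # pad short prompts (right side assumed)
    mask = [1] * len(kept) + [0] * num_pad
    return ids, mask

short_ids, short_mask = pad_or_truncate(list(range(1, 8)))    # 7 tokens
long_ids, long_mask = pad_or_truncate(list(range(1, 601)))    # 600 tokens

print(len(short_ids), sum(short_mask))
print(len(long_ids), sum(long_mask))
```

Either way the sequence length is fixed, which is why the downstream shape never changes.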

That’s one input to the transformer, an encoded prompt. The other input is a noisy image. I say image, but what we really shovel into the transformer is a latent.

AutoEncoder

The transformer never sees your actual image. It denoises latents (not pixels), a compressed representation of an image.

You could run diffusion directly on pixels. A 1024×1024 image is 3 million numbers. That’s what the transformer would process at every denoising step. The AutoEncoder compresses it to 500k numbers first.

The encoder outputs [128, H/16, W/16] latents, a 16x spatial compression.

Three steps get you there:

moments = self.encoder(x)                 # (B, 64, H/8, W/8): CNN, 8× spatial
mean = torch.chunk(moments, 2, dim=1)[0]  # (B, 32, H/8, W/8): take first half, discard second
z = rearrange(mean, "... c (i pi) (j pj) -> ... (c pi pj) i j", pi=2, pj=2)
# (B, 128, H/16, W/16): pixel-unshuffle, 2× more spatial

Compression only works if there is redundancy. A landscape snapshot, for example, has a lot of blue and green. You could sketch that photo from memory by recalling the positions of the sky, the green field and the river, and whether there was a cow. No need to memorize each pixel. That’s what the autoencoder does. A series of convolutions finds the patterns and where they occur. Context is important too. Is the cow on the field or floating in the sky? That’s why, in the middle of the network, a single self-attention block lets all parts of the image talk to each other. Then pixel-unshuffle folds spatial dimensions into channels.
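Pixel-unshuffle is easier to see on a tiny grid. The sketch below is a plain-Python stand-in for the rearrange pattern with pi=pj=2, written with nested lists instead of tensors; the function name pixel_unshuffle and the 4×4 toy input are assumptions for illustration.

```python
def pixel_unshuffle(x, p=2):
    """Fold each p×p spatial block into channels.

    x: nested list shaped [C][H][W]; returns a nested list shaped [C*p*p][H/p][W/p].
    """
    channels, height, width = len(x), len(x[0]), len(x[0][0])
    assert height % p == 0 and width % p == 0
    out = []
    for c in range(channels):
        for pi in range(p):          # row offset inside the block
            for pj in range(p):      # column offset inside the block
                out.append([
                    [x[c][i * p + pi][j * p + pj] for j in range(width // p)]
                    for i in range(height // p)
                ])
    return out

# one channel, 4×4, values 0..15 in row-major order
grid = [[[row * 4 + col for col in range(4)] for row in range(4)]]
folded = pixel_unshuffle(grid)

print(len(folded), len(folded[0]), len(folded[0][0]))   # channels, height, width
print(folded[0])   # top-left pixel of every 2×2 block
```

No value is lost or averaged: the same 16 numbers are just regrouped into 4 channels of 2×2.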

A 1024×1024px RGB photo becomes a 64×64 latent with 128 channels.

conv_in: [3, 1024, 1024] → [128, 1024, 1024] 
level 0: [128, 1024, 1024] → [128, 512, 512]
level 1: [128, 512, 512] → [256, 256, 256]
level 2: [256, 256, 256] → [512, 128, 128] ← Attention here
level 3: [512, 128, 128] → [512, 128, 128]
conv_out: [512, 128, 128] → [64, 128, 128] ← Last 32 channels (logvar) dropped after this one
unshuffle: [32, 128, 128] → [128, 64, 64]
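The element counts behind that table can be checked with a few lines of arithmetic. The helper names latent_shape and numel are made up for this sketch; the channel count and the factor of 16 come from the shapes above.

```python
def latent_shape(height, width, channels=128, factor=16):
    """Latent shape for an image whose sides are divisible by `factor`."""
    assert height % factor == 0 and width % factor == 0
    return (channels, height // factor, width // factor)

def numel(shape):
    """Total number of values in a tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n

pixels = numel((3, 1024, 1024))
latents = numel(latent_shape(1024, 1024))

print(latent_shape(1024, 1024))
print(pixels, latents, pixels / latents)
```

So the spatial grid shrinks 16× per side, but because the channel count grows from 3 to 128, the total number of values shrinks by a factor of 6.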

We just turned 3 million numbers into 500k numbers. Lossless, they claim. Let’s check.

Load Weights

The autoencoder is identical for FLUX.2 Klein and FLUX.2-dev. Download from the dev repo:

weight_path = huggingface_hub.hf_hub_download(
    repo_id="black-forest-labs/FLUX.2-dev",
    filename="ae.safetensors",
)

ae = AutoEncoder(AutoEncoderParams())

# load weights from file and into autoencoder
sd = load_sft(weight_path, device="cuda")
ae.load_state_dict(sd, strict=True, assign=True)

Image prep

The encoder expects pixels in [-1, 1]. Standard ToTensor() gives [0, 1], so normalize first:

# img is a PIL.Image
img_tensor = transforms.ToTensor()(img).to(device)
img_tensor = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img_tensor)

H and W must be divisible by 16, so crop the PIL image before converting it to a tensor:

crop_box = (0, 0) + tuple(x // 16 * 16 for x in (img.width, img.height))
img = img.crop(crop_box)
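The crop box is plain integer arithmetic, so it can be checked without loading an image. A small sketch, with crop_box_multiple_of as a hypothetical helper name:

```python
def crop_box_multiple_of(width, height, multiple=16):
    """PIL-style (left, upper, right, lower) box, rounded down to a multiple."""
    return (0, 0, width // multiple * multiple, height // multiple * multiple)

print(crop_box_multiple_of(1000, 750))
print(crop_box_multiple_of(1024, 1024))
```

The crop keeps the top-left corner and drops at most 15 pixels from the right and bottom edges.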

Encode

Image reconstruction breaks if you forget the normalize or clamp. Without a batch dimension, the net fails. Once the inputs and outputs are right, encoding and decoding are simple function calls: ae.encode and ae.decode.

# add expected batch dimension
img_tensor.unsqueeze_(0)

# encode to latent space
with torch.no_grad():
    img_latent = ae.encode(img_tensor)

Without torch.no_grad(), PyTorch keeps the intermediate activations around for backpropagation and your VRAM usage will skyrocket.

Converting back

After decoding, the output is in [-1, 1] again (approximately; it can drift slightly past the boundaries).

To get a PIL image:

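with torch.no_grad():
    decoded = ae.decode(img_latent)  # (1, 3, H, W)

# clamp first, then renormalize; do it the other way and you get artifacts
img_decoded = decoded.squeeze(0).clamp(-1, 1) * 0.5 + 0.5  # ToPILImage expects (3, H, W)
img_pil = transforms.ToPILImage()(img_decoded.float().cpu())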

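The order matters because the clamp range has to match the value range. Clamping to (-1, 1) after re-normalizing leaves values slightly above 1 in place, and those can wrap around in the 8-bit conversion and distort the colors. If you do want to clamp last, clamp to (0, 1) instead.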

Result

I can’t tell the difference.


The FLUX.2 autoencoder has considerably improved over the FLUX.1 autoencoder.

That’s what the authors say about their new Autoencoder.

Really? How bad was the FLUX.1 VAE?


Same.

The FLUX.2 AE clearly outperforms the old model on this image though:


Denoising

The goal: run the denoising transformer until the latent turns into something that matches the prompt.

The model’s forward call looks like this:

Signature:
model.forward(
    x: torch.Tensor,
    x_ids: torch.Tensor,
    timesteps: torch.Tensor,
    ctx: torch.Tensor,
    ctx_ids: torch.Tensor,
    guidance: torch.Tensor | None,
)

Six inputs. x is the noisy image, ctx is the encoded text prompt. timesteps I will explain in a second. But x_ids and ctx_ids, what are those?

x_ids and ctx_ids

The name is misleading. These aren’t identifiers, they are positional coordinates in 4D space.

Tokens are the “what”, x_ids and ctx_ids are the “where”. The transformer receives a flat sequence of inputs. It doesn’t know if patch #6 is top-left or bottom-right in the image. Or whether “horse” comes before or after “field”. To the attention mechanism, a shuffled sequence looks identical to the original.

For attention to be position-aware, every token needs a coordinate in a shared space. The FLUX space has four axes: (t, h, w, l).

  • Text tokens use the l dimension (h and w always zero).
  • Image tokens have a 2D location (h and w) but no sequence (l always zero).

For example:

text token 42      → (t=0, h=0, w=0, l=42)
image patch (3, 5) → (t=0, h=3, w=5, l=0)

prc_txt and prc_img build these coordinate tensors:

def prc_txt(x):
    # x: (512, 7680)
    x_ids = torch.cartesian_prod(
        torch.arange(1),    # t = 0
        torch.arange(1),    # h = 0 ← dummy
        torch.arange(1),    # w = 0 ← dummy
        torch.arange(512),  # l = 0..511
    )  # → (512, 4)
    return x, x_ids

def prc_img(x):
    # x: (128, 64, 64) for a 1024×1024 image
    x_ids = torch.cartesian_prod(
        torch.arange(1),   # t = 0
        torch.arange(64),  # h = 0..63
        torch.arange(64),  # w = 0..63
        torch.arange(1),   # l = 0 ← dummy
    )  # → (4096, 4)
    x = rearrange(x, "c h w -> (h w) c")  # flatten spatial grid → sequence
    return x, x_ids

Both return a flat sequence of tokens with 4D coordinates.
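To see which coordinate lands at which sequence index, here is a plain-Python sketch that builds the same tuples with itertools.product, which enumerates combinations in the same order as torch.cartesian_prod (last axis varies fastest). The function names text_ids and image_ids are made up for this illustration.

```python
from itertools import product

def text_ids(num_tokens=512):
    """(t, h, w, l) for every text token: only l counts up."""
    return list(product(range(1), range(1), range(1), range(num_tokens)))

def image_ids(height, width):
    """(t, h, w, l) for every latent patch, in row-major (h, then w) order."""
    return list(product(range(1), range(height), range(width), range(1)))

txt = text_ids()
img = image_ids(64, 64)

print(len(txt), txt[42])
print(len(img), img[3 * 64 + 5])   # patch in row 3, column 5
```

The row-major order of the coordinates matches the "(h w)" flattening of the latent, so token n and coordinate n always describe the same patch.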

The transformer uses these coordinates to modulate attention. Tokens that are spatially close attend to each other differently than tokens that are far apart. The mechanism is called RoPE but the details aren’t important for inference. And I don’t get it. Yet. Maybe another blog post in the future.

What is the t dimension for?

3 dimensions should be enough, you might think. There’s one more: t.

t = time? Video frames? But FLUX.2 Klein is an image model.

t is for reference images. The model supports an image-conditioned mode: give it a reference image (“generate something like this”) alongside the denoising target. The target sits at t=0. References get stamped with t=10, t=20, … , large spacing so RoPE encodes strong separation between reference and target frames.

For plain text-to-image, everything is t=0. The t dimension exists, but does nothing.

Timesteps scheduler

The obvious schedule: evenly spaced timesteps from 1.0 → 0.0. Step by step from high noise to low noise.

timesteps = torch.linspace(1, 0, num_steps + 1)  # num_steps = 10 → [1.0, 0.9, 0.8, …, 0.0]

[Figure: Timestep schedule, linear. x-axis: step (1 to 10); y-axis: % noise (1.00 down to 0.00). The curve falls in a straight line.]

That’s not what we want. First go slow, then go fast towards the end. The intuition: early denoising steps determine composition and structure. Those are the ones that matter. Late steps refine details.

[Figure: Timestep schedule, power curve. x-axis: step (1 to 10); y-axis: % noise (1.00 down to 0.00). The curve stays close to 1.00 during the early steps and drops steeply over the last few.]

A very simple power curve, loosely modeled on the ρ-schedule from Karras et al. 2022, gets us a schedule just like that:

def get_schedule(num_steps, rho=5):
    return torch.linspace(1, 0, num_steps + 1) ** (1/rho)

The official FLUX.2 repo goes further. It computes an adaptive mu parameter from image size and step count that shifts the curve toward the noisy end for large images. That’s complicated. The simple power schedule above is enough to generate coherent images.
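To check that the simple power schedule really goes slow first and fast at the end, here is a plain-Python version of it (lists instead of tensors; the function names are assumptions for this sketch) together with the step sizes it produces.

```python
def linear_schedule(num_steps):
    """num_steps + 1 evenly spaced values from 1.0 down to 0.0."""
    return [1 - k / num_steps for k in range(num_steps + 1)]

def power_schedule(num_steps, rho=5):
    """Same idea as get_schedule: raise the linear schedule to the power 1/rho."""
    return [t ** (1 / rho) for t in linear_schedule(num_steps)]

schedule = power_schedule(10)
step_sizes = [a - b for a, b in zip(schedule, schedule[1:])]

print([round(t, 3) for t in schedule])
print([round(s, 3) for s in step_sizes])
```

The step sizes grow monotonically, so most of the steps are spent at high noise levels and the last step covers the largest distance.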

The loop

Text to image is the name. What’s actually happening is random noise to image, guided by a text prompt.

Create some noise.

# Returns a tensor filled with random numbers from a normal distribution
img = torch.randn([128, img_h//16, img_w//16], dtype=dtype, device=device)


The denoising loop is five lines:

for t_curr, t_next in zip(timesteps, timesteps[1:]):
    timesteps_vec = torch.full((1,), t_curr, device=device, dtype=dtype)
    with torch.no_grad():
        pred = model.forward(
            x=img, x_ids=img_ids,
            ctx=txt, ctx_ids=txt_ids,
            timesteps=timesteps_vec,
            guidance=guidance_vec
        )
    img = img + (t_next - t_curr) * pred

This is an Euler step. The model predicts a velocity, and since t_next - t_curr is negative, adding the scaled prediction moves the latent towards the clean image. Each step in the loop removes a little noise.

Let’s say we want to reach the clean image in 10 evenly spaced steps. The clean image is 10 steps away. We feed the model with 100% noise and take 1 step towards the predicted direction. We are at 90% noise now (and 10% image), the model predicts again, we take another step. After 10 steps we reach the clean image.
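The ten-step walk can be reproduced on a single number. In the sketch below the transformer is replaced by an oracle that already knows the clean value, and the noisy sample is assumed to be x_t = (1 - t) * clean + t * noise, so the velocity is noise - clean. Both are assumptions chosen to make the example checkable; the real model only estimates this velocity.

```python
def euler_denoise(noise, velocity_fn, timesteps):
    """Walk from t=1 (pure noise) to t=0 with plain Euler steps."""
    x = noise
    for t_curr, t_next in zip(timesteps, timesteps[1:]):
        pred = velocity_fn(x, t_curr)
        x = x + (t_next - t_curr) * pred   # t_next < t_curr, so this moves against pred
    return x

clean = 0.25    # stand-in for the clean latent
noise = -1.7    # stand-in for the sampled noise

def oracle_velocity(x, t):
    # assumed convention: x_t = (1 - t) * clean + t * noise  =>  dx/dt = noise - clean
    return noise - clean

timesteps = [1 - k / 10 for k in range(11)]   # 10 evenly spaced steps
result = euler_denoise(noise, oracle_velocity, timesteps)

print(result)
```

With a perfect, constant velocity the loop lands on the clean value regardless of how the steps are spaced; the schedule matters because the real prediction changes from step to step.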

img stays flat throughout the loop. There’s no height or width while denoising. Before passing to the decoder, we unflatten. A list of 4096 tokens turns into a 64x64 grid. That’s the input for the decoder.

# Tensor [1 4096 128] -> [1 128 64 64]
latent = rearrange(img, "1 (h w) c -> 1 c h w", h=img_h//16)

# Latent -> pixel tensor [1 3 H W]; clamp and rescale as before to get a PIL image
decoded = ae.decode(latent)


All together

# 1. encode text
txt = encode_prompt(prompt).to(dtype)
txt, txt_ids = prc_txt(txt)

# 2. sample noise latent
img = torch.randn([128, img_h//16, img_w//16], dtype=dtype, device=device)
img, img_ids = prc_img(img)

# 3. add batch dim
txt, txt_ids = txt[None], txt_ids[None]
img, img_ids = img[None], img_ids[None]

# 4. denoise
timesteps = get_schedule(num_steps)
guidance_vec = torch.full((1,), 1.0, device=device, dtype=dtype)

for t_curr, t_next in tqdm(zip(timesteps, timesteps[1:])):
    timesteps_vec = torch.full((1,), t_curr, device=device, dtype=dtype)
    with torch.no_grad():
        pred = model.forward(
            x=img, x_ids=img_ids,
            ctx=txt, ctx_ids=txt_ids,
            timesteps=timesteps_vec,
            guidance=guidance_vec
        )
    img = img + (t_next - t_curr) * pred

# 5. decode
latent = rearrange(img, "1 (h w) c -> 1 c h w", h=img_h//16)
result = ae.decode(latent)

That’s the whole text-to-image pipeline.

Image-to-image

Image-to-image starts from the same pure noise, but passes reference images to the model. This is where the t dimension comes in. The denoising target sits at t=0. Reference images get t=10, t=20, …, spaced by 10 so the model knows not to denoise them, just attend to them.

Assign each image a different t coordinate.

img_refs, img_refs_ids = zip(*[
    prc_img(img_ref.squeeze(), t_coord=torch.tensor([(idx+1)*10]))
    for idx, img_ref in enumerate(img_refs)
])
# ref. img #1 → t=10, ref. img #2 → t=20, ...
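The coordinate bookkeeping can be sketched in plain Python. The helpers image_ids and reference_t are hypothetical stand-ins for what prc_img does with its t_coord argument; only the spacing rule (idx+1)*10 is taken from the snippet above.

```python
from itertools import product

def image_ids(height, width, t=0):
    """(t, h, w, l) for every patch of one image; t tells the images apart."""
    return list(product([t], range(height), range(width), range(1)))

def reference_t(index, spacing=10):
    """t coordinate of the reference image at position `index` (0-based)."""
    return (index + 1) * spacing

target_ids = image_ids(64, 64)   # the latent being denoised stays at t=0
ref_ids = [image_ids(64, 64, t=reference_t(i)) for i in range(2)]

print(target_ids[0], ref_ids[0][0], ref_ids[1][0])
```

The h and w coordinates are the same for all three images; only t differs, which is what lets the model tell the target and the references apart.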

In the loop, the model receives the same inputs as before. The reference images are not passed via a separate parameter but concatenated with the noise tokens as new x and x_ids. The text tokens remain unchanged.

pred = model.forward(
    # zip() returned tuples, so unpack them instead of doing list + tuple
    x=torch.cat([img, *img_refs], dim=1),
    x_ids=torch.cat([img_ids, *img_refs_ids], dim=1),
    ctx=txt,
    ctx_ids=txt_ids,
    ...
)
# model returns predictions for the full [img + refs] sequence → keep only img
pred = pred[:, :img.size(1)]

The model sees noise and references as one flat sequence and attends across all of them. The t coordinate is the only thing distinguishing “denoise this” from “use this as reference”. After the forward pass, strip the reference predictions. Only img gets updated.
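The concatenate-then-strip pattern, reduced to Python lists. The toy_model below is a stand-in that just doubles every token, and forward_with_references is a hypothetical wrapper name; the point is only the slicing.

```python
seen_lengths = []

def toy_model(tokens):
    """Stand-in for the transformer: returns one prediction per input token."""
    seen_lengths.append(len(tokens))
    return [2.0 * tok for tok in tokens]

def forward_with_references(model_fn, img_tokens, ref_token_lists):
    """Run the model on [img + refs], return predictions for img only."""
    sequence = list(img_tokens)
    for ref in ref_token_lists:
        sequence.extend(ref)
    pred = model_fn(sequence)
    return pred[:len(img_tokens)]   # the image tokens come first, so slice the front

img_tokens = [1.0, 2.0, 3.0]
refs = [[10.0, 11.0], [20.0]]
pred = forward_with_references(toy_model, img_tokens, refs)

print(seen_lengths, pred)
```

Because the target tokens are placed first in the sequence, keeping the first img.size(1) predictions is enough to drop everything that belongs to the references.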

Everything else is identical to text2img.

Remove and replace the tanks in image 1. Replace them with the flowers of image 2.
