fix(LTX AV): pass correct timesteps to cross-attention AdaLN modulation#14097
Conversation
📝 WalkthroughWalkthroughThis PR adjusts cross-attention ADaLN timestep preparation in 🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@comfy/ldm/lightricks/av_model.py`:
- Around line 782-788: The gate timestep construction uses tensor.max() without
a dim which reduces across the entire tensor and couples the batch; change the
logic in the calls to av_ca_a2v_gate_adaln_single and
av_ca_v2a_gate_adaln_single to compute a per-batch max (e.g. use
a_timestep_scaled.max(dim=1, keepdim=True) / timestep_scaled.max(dim=1,
keepdim=True) and then broadcast/expand to a_timestep_flat / timestep_flat) so
each sample gets its own gate timestep instead of a global scalar; update uses
of a_timestep_scaled.max() and timestep_scaled.max() accordingly.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 99a343b1-96c6-44e9-a89e-0e99c481e928
📒 Files selected for processing (1)
comfy/ldm/lightricks/av_model.py
| a_timestep_scaled.max().expand_as(timestep_flat) * av_ca_factor, | ||
| {"resolution": None, "aspect_ratio": None}, | ||
| batch_size=batch_size, | ||
| hidden_dtype=hidden_dtype, | ||
| ) | ||
| av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( | ||
| timestep.max().expand_as(a_timestep_flat) * av_ca_factor, | ||
| timestep_scaled.max().expand_as(a_timestep_flat) * av_ca_factor, |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Search for timestep_scaled tensor creation and usage patterns
rg -nP -C3 'timestep_scaled\s*=' --type=py -g 'comfy/ldm/lightricks/*.py'
# Check how compute_prompt_timestep handles max operation for comparison
rg -nP -A5 'def compute_prompt_timestep' --type=py -g 'comfy/ldm/lightricks/*.py'Repository: Comfy-Org/ComfyUI
Length of output: 2670
Flag global .max() when building gate timestep (possible batch coupling)
In comfy/ldm/lightricks/av_model.py the A2V/V2A gate timesteps are built with a_timestep_scaled.max() / timestep_scaled.max() (lines 782/788). Without a dim, PyTorch reduces to a single scalar over all elements, so every batch element would receive the same gate timestep if timesteps vary across the batch.
This is inconsistent with compute_prompt_timestep() in comfy/ldm/lightricks/model.py, which computes maxima per batch segment (max(dim=1, ...)) and then broadcasts. Consider switching the gate timestep construction to a per-batch max (as in the proposed diff) to avoid unintended batch coupling when token/timestep values differ per sample.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@comfy/ldm/lightricks/av_model.py` around lines 782 - 788, The gate timestep
construction uses tensor.max() without a dim which reduces across the entire
tensor and couples the batch; change the logic in the calls to
av_ca_a2v_gate_adaln_single and av_ca_v2a_gate_adaln_single to compute a
per-batch max (e.g. use a_timestep_scaled.max(dim=1, keepdim=True) /
timestep_scaled.max(dim=1, keepdim=True) and then broadcast/expand to
a_timestep_flat / timestep_flat) so each sample gets its own gate timestep
instead of a global scalar; update uses of a_timestep_scaled.max() and
timestep_scaled.max() accordingly.
|
@izorinLightricks Hey, here you say "Each modality's scale/shift should be conditioned on its own timestep." But when LTX-2.3 dropped, the Lightricks/LTX-2 reference code was explicitly changed to condition each modality's scale/shift on the other modality's timestep, with Is the reference code also incorrect here, or am I misunderstanding something? |
Hi! Thank you for being attentive to this matter. You are correct, reference code is also incorrect and was fixed internally (will be released later). This fix brings inference behavior in line with training. Now reference code: |
@izorinLightricks Thanks for the reply and clarification! |
Fix two bugs in the cross-attention AdaLN timestep inputs in
LTXAVModel:a_timestep/timestepinstead of the_scaledvariants that the rest of the model relies on.