Skip to content

fix(LTX AV): pass correct timesteps to cross-attention AdaLN modulation#14097

Merged
comfyanonymous merged 1 commit into
Comfy-Org:masterfrom
izorinLightricks:fix/adaln
May 26, 2026
Merged

fix(LTX AV): pass correct timesteps to cross-attention AdaLN modulation#14097
comfyanonymous merged 1 commit into
Comfy-Org:masterfrom
izorinLightricks:fix/adaln

Conversation

@izorinLightricks

Copy link
Copy Markdown
Contributor

Fix two bugs in the cross-attention AdaLN timestep inputs in LTXAVModel:

  1. Scale/shift modulation was driven by the wrong modality's sigma. The audio scale/shift AdaLN was being fed the video sigma, and the video scale/shift AdaLN was being fed the audio sigma. Each modality's scale/shift should be conditioned on its own timestep.
  2. Gate modulation used the unscaled timestep. The a2v and v2a gate AdaLNs used a_timestep / timestep instead of the _scaled variants that the rest of the model relies on.

@coderabbitai

coderabbitai Bot commented May 25, 2026

Copy link
Copy Markdown

Review Change Stack

📝 Walkthrough

Walkthrough

This PR adjusts cross-attention ADaLN timestep preparation in LTXAVModel._prepare_timestep. The video scale-shift path now uses timestep_flat directly instead of deriving it from audio timestep expansion. The A2V and V2A gate-noise paths now compute their expanded maxima from the scaled timestep tensors (a_timestep_scaled.max() and timestep_scaled.max()) rather than unscaled values before multiplying by av_ca_factor.

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title directly and accurately summarizes the main change: fixing incorrect timesteps passed to cross-attention AdaLN modulation in the LTX AV model.
Description check ✅ Passed The description is directly related to the changeset, clearly explaining the two bugs being fixed in the cross-attention AdaLN timestep inputs.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@comfy/ldm/lightricks/av_model.py`:
- Around line 782-788: The gate timestep construction uses tensor.max() without
a dim which reduces across the entire tensor and couples the batch; change the
logic in the calls to av_ca_a2v_gate_adaln_single and
av_ca_v2a_gate_adaln_single to compute a per-batch max (e.g. use
a_timestep_scaled.max(dim=1, keepdim=True) / timestep_scaled.max(dim=1,
keepdim=True) and then broadcast/expand to a_timestep_flat / timestep_flat) so
each sample gets its own gate timestep instead of a global scalar; update uses
of a_timestep_scaled.max() and timestep_scaled.max() accordingly.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 99a343b1-96c6-44e9-a89e-0e99c481e928

📥 Commits

Reviewing files that changed from the base of the PR and between 0077d78 and b496d40.

📒 Files selected for processing (1)
  • comfy/ldm/lightricks/av_model.py

Comment on lines +782 to +788
a_timestep_scaled.max().expand_as(timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
timestep_scaled.max().expand_as(a_timestep_flat) * av_ca_factor,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for timestep_scaled tensor creation and usage patterns
rg -nP -C3 'timestep_scaled\s*=' --type=py -g 'comfy/ldm/lightricks/*.py'

# Check how compute_prompt_timestep handles max operation for comparison
rg -nP -A5 'def compute_prompt_timestep' --type=py -g 'comfy/ldm/lightricks/*.py'

Repository: Comfy-Org/ComfyUI

Length of output: 2670


Flag global .max() when building gate timestep (possible batch coupling)

In comfy/ldm/lightricks/av_model.py the A2V/V2A gate timesteps are built with a_timestep_scaled.max() / timestep_scaled.max() (lines 782/788). Without a dim, PyTorch reduces to a single scalar over all elements, so every batch element would receive the same gate timestep if timesteps vary across the batch.

This is inconsistent with compute_prompt_timestep() in comfy/ldm/lightricks/model.py, which computes maxima per batch segment (max(dim=1, ...)) and then broadcasts. Consider switching the gate timestep construction to a per-batch max (as in the proposed diff) to avoid unintended batch coupling when token/timestep values differ per sample.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@comfy/ldm/lightricks/av_model.py` around lines 782 - 788, The gate timestep
construction uses tensor.max() without a dim which reduces across the entire
tensor and couples the batch; change the logic in the calls to
av_ca_a2v_gate_adaln_single and av_ca_v2a_gate_adaln_single to compute a
per-batch max (e.g. use a_timestep_scaled.max(dim=1, keepdim=True) /
timestep_scaled.max(dim=1, keepdim=True) and then broadcast/expand to
a_timestep_flat / timestep_flat) so each sample gets its own gate timestep
instead of a global scalar; update uses of a_timestep_scaled.max() and
timestep_scaled.max() accordingly.

@comfyanonymous comfyanonymous merged commit 57414da into Comfy-Org:master May 26, 2026
14 checks passed
@drozbay

drozbay commented May 26, 2026

Copy link
Copy Markdown
Contributor

@izorinLightricks Hey, here you say "Each modality's scale/shift should be conditioned on its own timestep." But when LTX-2.3 dropped, the Lightricks/LTX-2 reference code was explicitly changed to condition each modality's scale/shift on the other modality's timestep, with cross_timestep = cross_modality.sigma that then feeds into the scale/shift AdaLN: https://github.com/Lightricks/LTX-2/blob/822ce3c4b/packages/ltx-core/src/ltx_core/model/transformer/transformer_args.py#L260-L261

Is the reference code also incorrect here, or am I misunderstanding something?

@izorinLightricks

Copy link
Copy Markdown
Contributor Author

@izorinLightricks Hey, here you say "Each modality's scale/shift should be conditioned on its own timestep." But when LTX-2.3 dropped, the Lightricks/LTX-2 reference code was explicitly changed to condition each modality's scale/shift on the other modality's timestep, with cross_timestep = cross_modality.sigma that then feeds into the scale/shift AdaLN: https://github.com/Lightricks/LTX-2/blob/822ce3c4b/packages/ltx-core/src/ltx_core/model/transformer/transformer_args.py#L260-L261

Is the reference code also incorrect here, or am I misunderstanding something?

Hi! Thank you for being attentive to this matter. You are correct, reference code is also incorrect and was fixed internally (will be released later). This fix brings inference behavior in line with training.

Now reference code:

  cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
            modality_timesteps=modality.timesteps,
            cross_modality_sigma=cross_modality.sigma,
            timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
            batch_size=transformer_args.x.shape[0],
            hidden_dtype=modality.latent.dtype,
        )

    def _prepare_cross_attention_timestep(
        self,
        modality_timesteps: torch.Tensor,
        cross_modality_sigma: torch.Tensor,
        timestep_scale_multiplier: int,
        batch_size: int,
        hidden_dtype: torch.dtype,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare A-V cross-attention AdaLN inputs."""
        av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier

        scale_shift_timestep, _ = self.cross_scale_shift_adaln(
            (modality_timesteps * timestep_scale_multiplier).flatten(),
            hidden_dtype=hidden_dtype,
        )
        scale_shift_timestep = scale_shift_timestep.view(batch_size, -1, scale_shift_timestep.shape[-1])

        gate_noise_timestep, _ = self.cross_gate_adaln(
            (cross_modality_sigma * timestep_scale_multiplier * av_ca_factor).flatten(),
            hidden_dtype=hidden_dtype,
        )
        gate_noise_timestep = gate_noise_timestep.view(batch_size, -1, gate_noise_timestep.shape[-1])

        return scale_shift_timestep, gate_noise_timestep

@drozbay

drozbay commented May 27, 2026

Copy link
Copy Markdown
Contributor

Hi! Thank you for being attentive to this matter. You are correct, reference code is also incorrect and was fixed internally (will be released later). This fix brings inference behavior in line with training.

@izorinLightricks Thanks for the reply and clarification!

AkaneTendo25 added a commit to AkaneTendo25/musubi-tuner that referenced this pull request May 27, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants