Segearth-R2-04

继续。下面进入 第 4 讲:train.py 主训练流程精读,并提前把 llava_phi.py 的 forward 主线接上。你要把这部分看成 SegEarth-R2 的“总控制台”。


第 4 讲:train.py 主训练流程

train.py 决定了 6 件事:

1. 训练参数怎么解析
2. SegEarthR2 主模型怎么加载
3. Mask2Former 分割模块怎么初始化
4. tokenizer 怎么扩展 [SEG]
5. LoRA 怎么注入
6. 数据集和 Trainer 怎么构建

这个文件不是模型创新本身,但它是你复现必须吃透的文件。官方 train.py 里定义了 ModelArgumentsDataArgumentsTrainingArguments,然后在 train() 函数中完成参数解析、模型加载、视觉模块初始化、LoRA 注入、数据模块构造和训练启动。(GitHub)


一、先看三个参数类

1. ModelArguments

它控制模型相关参数。

核心字段包括:

model_name_or_path
vision_tower
vision_tower_mask
mask_config
freeze_backbone
train_clip_backbone
train_swin_backbone
load_mask2former
projector_outdim
mm_projector_type

你可以这样理解:

model_name_or_path   → Mipha-3B / 语言模型主体
vision_tower         → SigLIP / 给 LLM 使用的图像编码器
vision_tower_mask    → Mask2Former / Swin 分割分支权重
mask_config          → Mask2Former 配置文件
freeze_backbone      → 是否冻结语言模型主体
train_clip_backbone  → 是否训练 SigLIP
train_swin_backbone  → 是否训练 Swin 分割骨干
load_mask2former     → 是否加载 Mask2Former 预训练权重

这里最关键的是:

vision_tower 和 vision_tower_mask 不是同一个模块

vision_tower 是给多模态大模型看的,vision_tower_mask 是给分割解码器用的。官方参数里同时出现 vision_towervision_tower_mask,训练流程里也分别初始化并移动到对应 dtype/device。(GitHub)


2. DataArguments

它控制数据相关参数。

核心字段:

base_data_path
data_ratio
switch_bs
segmentation
image_aspect_ratio
fix_dataset_len

含义:

base_data_path    → LaSeRS 数据集根目录
data_ratio        → 多数据集混合比例
switch_bs         → 多数据集训练时 batch 切换间隔
segmentation      → 是否启用分割任务
fix_dataset_len   → 是否固定数据集长度

当前官方训练主要使用 LaSeRS 数据,因此 data_ratio='1' 就够。make_unify_datamodule() 里会根据 data_ratio 构造 LaSeRSDataset,再封装成 UnifyDatasetSingleDatasetForBatch。(GitHub)


3. TrainingArguments

它继承 HuggingFace 的 transformers.TrainingArguments,控制训练超参数。

核心字段:

per_device_train_batch_size
gradient_accumulation_steps
gradient_checkpointing
deepspeed
output_dir
model_max_length
bf16 / fp16
lora_enable
lora_r
lora_alpha
lora_dropout
dataloader_num_workers

其中最重要的是:

lora_enable=True
lora_r
deepspeed
bf16
model_max_length=2048

SegEarth-R2 默认不是全量微调所有参数,而是 LoRA + 部分分割模块训练。train.py 中会根据 lora_enable 构造 LoraConfig,然后用 get_peft_model() 包装模型。(GitHub)


二、train() 函数主线

train() 是整个训练入口。你读源码时不要从第一行陷入细节,而是先按这条主线读:

parse_args
    ↓
get_mask_config
    ↓
SegEarthR2.from_pretrained
    ↓
initial_mask_module
    ↓
AutoTokenizer.from_pretrained
    ↓
initialize_vision_modules
    ↓
freeze / unfreeze modules
    ↓
tokenizer.add_tokens("[SEG]")
    ↓
LoRA 注入
    ↓
model.get_special_token
    ↓
SiglipImageProcessor
    ↓
make_unify_datamodule
    ↓
LLaVATrainer
    ↓
trainer.train()
    ↓
save_model

下面逐段拆。


三、参数解析

代码逻辑是:

parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

含义:把 train.sh 里的命令行参数自动映射到三个 dataclass 中。

例如:

--model_name_or_path pretrained_model/mllm/Mipha-3B
--base_data_path /your/path/LaSeRS
--lora_r 4

会变成:

model_args.model_name_or_path
data_args.base_data_path
training_args.lora_r

所以你以后改实验,不应该直接改 Python 文件,而应该优先改 scripts/train.sh


四、读取 Mask2Former 配置

代码主线是:

mask_cfg = get_mask_config(config=model_args.mask_config)

这个配置决定分割分支怎么构造,包括:

Swin backbone
pixel decoder
transformer predictor
criterion
loss weight
hidden_dim
query 机制

这一步很重要。因为 SegEarth-R2 的 mask 不是 LLM 直接生成,而是通过 Mask2Former 类分割头生成。

也就是说:

LLM 负责产生 [SEG] embedding
Mask2Former 负责把 [SEG] embedding 解码成 mask

五、加载 SegEarthR2 主模型

关键代码是:

model = SegEarthR2.from_pretrained(
    model_args.model_name_or_path,
    mask_decoder_cfg=mask_cfg,
    add_cross_attn=True,
    cache_dir=training_args.cache_dir,
)

这一句加载的是整个 SegEarth-R2 主体。它不是普通语言模型,而是基于 Mipha/Phi 系列结构扩展出来的多模态分割模型。train.py 中明确从 segearth_r2.model.language_model.llava_phi 导入 SegEarthR2,然后通过 from_pretrained() 加载。(GitHub)

你要理解:

SegEarthR2 = 语言模型主体 + 多模态图像接入 + 分割模块接口

但此时分割模块不一定已经加载好,所以后面还要调用:

model.initial_mask_module(mask2former_ckpt, model_args)

六、初始化 Mask2Former 分割模块

代码逻辑:

if not model.is_train_mask_decode:
    mask2former_ckpt = model_args.vision_tower_mask if model_args.load_mask2former else None
    model.initial_mask_module(mask2former_ckpt, model_args)

含义:

如果当前模型还没有训练/初始化分割解码模块
就加载 vision_tower_mask 指定的 Mask2Former 权重
并初始化 pixel_decoder、predictor、criterion 等模块

这里有一个很重要的判断:

load_mask2former=True → 使用预训练 Mask2Former 权重
load_mask2former=False → 不加载预训练分割权重

复现时建议保持 load_mask2former=True,否则分割分支从头开始训练,难度更大。


七、关闭 use_cache

model.config.use_cache = False

训练大语言模型时通常关闭 cache,因为 cache 是推理生成时为了加速自回归 decoding 用的。训练时需要完整反向传播,cache 反而可能造成显存和梯度问题。

后面训练结束又会设置:

model.config.use_cache = True

用于保存和后续推理。train.py 在训练完成后调用 trainer.save_state() 并保存模型前把 use_cache 重新打开。(GitHub)


八、冻结语言模型主体

if model_args.freeze_backbone:
    model.model.requires_grad_(False)

如果 freeze_backbone=True,则冻结模型主体。

这里的 “backbone” 主要是语言模型主体,不是单纯 CNN backbone。

建议你第一次复现时不要改它,保持官方默认。等项目跑通后再做实验:

实验 A:冻结 LLM,只训 LoRA + 分割头
实验 B:LoRA 微调 LLM,冻结视觉塔
实验 C:打开 Swin backbone 训练
实验 D:打开 mm_projector 训练

九、gradient checkpointing

if training_args.gradient_checkpointing:
    model.enable_input_require_grads()

gradient checkpointing 的作用是节省显存:前向时不保存全部中间激活,反向时重新计算一部分。

优点:

显存更低
可以支持更大图像或更长序列

缺点:

训练更慢
调试更复杂
部分模型可能出现输入梯度问题

RTX 5090 第一次复现,建议先:

--gradient_checkpointing False

跑通以后,如果显存不够,再打开。


十、加载 tokenizer

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    model_max_length=training_args.model_max_length,
    padding_side="right",
    use_fast=False,
)

作用:加载 Mipha-3B 对应 tokenizer。

关键点:

padding_side="right"
model_max_length=2048
use_fast=False

右 padding 对 causal LM 训练比较常见。model_max_length=2048 决定最长文本 token 序列。超长样本会被截断。官方代码在 tokenizer 没有 pad token 时,会通过 smart_tokenizer_and_embedding_resize() 添加 [PAD] 并同步 resize embedding。(GitHub)


十一、设置 conversation template

if model_args.version in conversation_lib.conv_templates:
    conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
else:
    conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]

这一步决定 prompt 格式。

它会影响:

human / gpt 对话怎么拼接
分隔符是什么
labels 哪些位置被 IGNORE_INDEX 屏蔽
模型学习回答时看到的上下文格式

你前面读 dataset.py 时看到的:

sources = [[
    {'from': 'human', 'value': prefix_inst + '\n<refer> <|assistant|>'},
    {'from': 'gpt', 'value': '\n' + answer}
]]

会通过 conversation template 转成最终训练文本。dataset.pypreprocess_llama2() 会根据对话轮次,把人类指令部分 label 置为 IGNORE_INDEX,只让 assistant 部分参与语言模型 loss。(GitHub)


十二、初始化两个视觉分支

核心代码:

model.get_model().initialize_vision_modules(
    model_args=model_args,
    fsdp=training_args.fsdp
)

vision_tower = model.get_vision_tower()
vision_tower_mask = model.model.get_vision_tower_mask()

vision_tower.to(dtype=..., device=training_args.device)
vision_tower_mask.to(dtype=..., device=training_args.device)

这段非常关键。它再次证明 SegEarth-R2 有两套视觉流:

vision_tower      → SigLIP,服务多模态 LLM
vision_tower_mask → Swin / Mask2Former,服务像素分割

训练代码中分别获取 vision_towervision_tower_mask,并把它们移动到训练设备和指定 dtype。(GitHub)

你要牢记:

images_clip 进入 vision_tower
images 进入 vision_tower_mask / mask decoder

这是项目中最容易混淆的地方。


十三、冻结视觉塔和 projector

代码:

if not model_args.train_clip_backbone:
    model.model.vision_tower.requires_grad_(False)

if not model_args.train_swin_backbone:
    model.model.vision_tower_mask.requires_grad_(False)

if training_args.freeze_mm_mlp_adapter:
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = False

含义:

默认不训练 SigLIP
默认不训练 Swin / Mask2Former backbone
默认冻结 mm_projector

但是后面会手动打开一些模块:

train_module_list = [
    "lm_head",
    "pixel_decoder",
    "predictor",
    "SEG_token_projector",
]

所以实际训练的是:

LoRA 参数
lm_head
pixel_decoder
predictor
SEG_token_projector

如果设置:

--train_swin_backbone True

vision_tower_mask 也会加入训练模块列表。官方代码中 train_module_list 默认包括 lm_headpixel_decoderpredictorSEG_token_projector,并在 train_swin_backbone 为真时追加 vision_tower_mask。(GitHub)


十四、最关键:添加 [SEG] token

tokenizer.add_tokens("[SEG]")
model.resize_token_embeddings(len(tokenizer))

这是整个项目的核心之一。

为什么要加 [SEG]

因为 SegEarth-R2 不是只生成文字,它要让语言模型在回答中生成一个特殊位置,这个位置代表:

我要在这里输出一个分割 mask

也就是说:

answer = "The building is [SEG]."

模型训练时不仅学习输出这句话,还会额外取 [SEG] 对应 hidden state 去预测 mask。

如果不加 [SEG],则 tokenizer 无法把 [SEG] 当成一个独立 token,后面的:

input_ids == self.SEG_token_id

就可能找不到正确位置。

train.py 中明确在视觉模块初始化和冻结逻辑之后添加 [SEG] 并 resize embedding,随后把 [SEG] 的 token id 传给模型内部。(GitHub)


十五、LoRA 注入逻辑

代码主线:

lora_target_modules = find_linear_layers(model, train_module_list=train_module_list)

lora_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=lora_target_modules,
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)

find_linear_layers() 的逻辑是:遍历模型中的 torch.nn.Linear,寻找名字里包含 q_projv_proj 的线性层,但排除 train_module_listvision_towervision_tower_mask 相关模块。(GitHub)

也就是说,LoRA 主要插入到语言模型注意力层中的:

q_proj
v_proj

这是一种典型的大模型微调方式。

你可以理解为:

大语言模型主体大部分被冻结
只在注意力 Q/V 投影上加可训练低秩分支
再额外训练分割相关模块

十六、手动打开分割相关模块训练

LoRA 注入后,代码还会执行:

for n, p in model.named_parameters():
    if any([x in n for x in train_module_list]):
        p.requires_grad = True

这一步非常关键。

因为 LoRA 只负责大语言模型部分的低秩微调,但分割输出能力依赖这些模块:

lm_head
pixel_decoder
predictor
SEG_token_projector

其中最重要的是:

SEG_token_projector

它负责把 LLM 的 [SEG] hidden state 映射到 Mask2Former predictor 可使用的维度。

所以整体训练策略是:

LLM:LoRA 微调
SigLIP:默认冻结
Swin backbone:默认冻结
mm_projector:默认冻结
Mask2Former pixel decoder:训练
Mask2Former predictor:训练
SEG_token_projector:训练
lm_head:训练

这也是你未来做博士改进时最应该动的区域。


十七、把 [SEG] id 注册到模型里

model.get_special_token(
    SEG=tokenizer("[SEG]", return_tensors='pt', add_special_tokens=False)['input_ids'],
    EOS=tokenizer.eos_token_id
)

这一步告诉模型:

哪个 token id 是 [SEG]
哪个 token id 是 EOS

为什么不只在 dataset 里用 [SEG]

因为模型 forward 里面还要根据 [SEG] 做生成、定位和 mask 解码。dataset.py 里会生成 SEG_token_embedding_indices,而 llava_phi.py 里会通过这些 index 从 hidden_states 取出 [SEG] embedding。(GitHub)


十八、构建 SigLIP 图像处理器

clip_image_processor = SiglipImageProcessor.from_pretrained(model_args.vision_tower)

这一步对应 dataset.pyDataCollatorForCOCODatasetV2

注意:

dataset.py 里的 preprocess_image() 处理的是 images
SiglipImageProcessor 处理的是 images_clip

DataCollatorForCOCODatasetV2 中会用 clip_image_processor.preprocess() 读取原图并构造 images_clip,而 LaSeRSDataset.__getitem__() 里提前构造的 image 会成为 images。(GitHub)


十九、构建数据模块

data_module = make_unify_datamodule(
    clip_image_processor=clip_image_processor,
    tokenizer=tokenizer,
    data_args=data_args,
    training_args=training_args
)

make_unify_datamodule() 会做三件事:

1. 解析 data_ratio
2. 构建 LaSeRSDataset
3. 构建 DataCollatorForCOCODatasetV2

最后返回:

{
    "train_dataset": train_dataset,
    "eval_dataset": None,
    "data_collator": data_collator
}

也就是说,这个项目训练时没有在 train.py 里单独构建 eval dataset,评估是后面通过 eval.py 单独跑。(GitHub)


二十、构建 LLaVATrainer

trainer = LLaVATrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    **data_module
)

它不是直接用 HuggingFace 原生 Trainer,而是自定义了 LLaVATrainer。核心原因是:模型返回的不只是一个总 loss,还包括:

loss_llm
loss_mask
loss_dice
loss_attention

自定义 Trainer 的 compute_loss() 会把 global_step 加入 inputs,然后调用 model(**inputs),再从 outputs 中取出 loss,并额外记录各类 loss。(GitHub)

这对调试很重要。

你训练时不要只看总 loss,还要看:

loss_llm 是否正常
loss_mask 是否正常
loss_dice 是否正常
loss_attention 是否异常

如果 loss_llm 正常但 loss_mask 不降,说明语言模型部分能学,但分割分支或 mask 对齐有问题。

如果 loss_mask 为 0 或 NaN,优先检查:

mask 是否读到
[SEG] 数量是否等于 mask 数量
seg_info 是否为空
mask dtype / shape 是否正确

二十一、训练启动和断点恢复

代码逻辑:

if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
    trainer.train(resume_from_checkpoint=True)
else:
    trainer.train()

含义:

如果 output_dir 里已有 checkpoint-*,自动断点恢复
否则从头训练

这很好用,但也容易坑你。

如果你改了模型结构、改了 [SEG]、改了 tokenizer、改了 LoRA 配置,却还让它从旧 checkpoint 恢复,可能会出现各种 shape mismatch。

因此调试时建议:

rm -rf outputs/debug_5090_lora_r4/checkpoint-*

或者每次换实验名:

outputs/exp001_debug
outputs/exp002_lora_r8
outputs/exp003_train_swin

二十二、保存模型

训练结束后:

trainer.save_state()
model.config.use_cache = True
safe_save_model_for_hf_trainer(...)

safe_save_model_for_hf_trainer() 对 DeepSpeed 做了特殊处理:如果 trainer 使用 DeepSpeed,则先同步 CUDA,然后调用 trainer.save_model(output_dir)。(GitHub)

你要注意:如果使用 LoRA,训练输出通常不是一个可直接推理的完整模型,还需要后续合并 LoRA 权重。

完整流程是:

train.sh
    ↓
得到 LoRA / checkpoint
    ↓
merge_lora_weights.sh
    ↓
得到合并后的模型
    ↓
test.sh / eval.py

二十三、到这里,你应该形成第二张总图

scripts/train.sh
    ↓
HfArgumentParser
    ↓
ModelArguments / DataArguments / TrainingArguments
    ↓
get_mask_config
    ↓
SegEarthR2.from_pretrained
    ↓
initial_mask_module
    ↓
AutoTokenizer.from_pretrained
    ↓
add [PAD] if needed
    ↓
initialize_vision_modules
    ↓
freeze SigLIP / Swin / mm_projector
    ↓
tokenizer.add_tokens("[SEG]")
    ↓
resize_token_embeddings
    ↓
LoRA 注入 q_proj / v_proj
    ↓
手动打开 lm_head / pixel_decoder / predictor / SEG_token_projector
    ↓
get_special_token([SEG], EOS)
    ↓
SiglipImageProcessor
    ↓
LaSeRSDataset + DataCollator
    ↓
LLaVATrainer
    ↓
model.forward
    ↓
loss_llm + loss_mask + loss_dice + loss_attention

二十四、提前接上 llava_phi.py 的 forward 主线

现在你已经知道 train.py 如何把 batch 送进模型。下一步要理解 model.forward() 里发生什么。

llava_phi.py 的 forward 参数包括:

input_ids
attention_mask
labels
images
images_clip
seg_info
token_refer_id
SEG_token_embedding_indices
mask_num
dataset_type

你看,正好对应前面 dataset.pyDataCollator 产生的字段。


1. 两路图像进入模型

在 forward 中:

image_features = self.get_vision_tower_feature(images)

这里的 images 是给分割分支用的,它会得到 Swin / Mask2Former 相关图像特征。随后:

prepare_inputs_labels_for_multimodal(..., images_clip, ...)

这里的 images_clip 是给多模态 LLM 用的,用于把图像 token 拼入语言模型输入序列。llava_phi.py forward 中先通过 get_vision_tower_feature(images) 获取分割图像特征,再通过 prepare_inputs_labels_for_multimodal(..., images_clip, ...) 处理多模态输入。(GitHub)

所以你必须牢牢记住:

images      → 分割视觉特征
images_clip → LLM 多模态输入

2. LLM 前向传播

代码主线:

outputs = self.model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)

hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states)

含义:

多模态 token 序列进入语言模型
得到每个 token 的 hidden state
再通过 lm_head 得到词表 logits

这一步用于语言建模 loss,也为 [SEG] 提供 hidden state。


3. 取出 [SEG] hidden state

关键代码:

SEG_embedding = self.SEG_token_projector(
    self.get_SEG_embedding(hidden_states, SEG_token_embedding_indices)
)

get_SEG_embedding() 会根据 SEG_token_embedding_indices 找到 hidden_states 中 [SEG] 对应位置,然后拼起来。(GitHub)

这就是项目最核心的桥梁:

[SEG] token hidden state
        ↓
SEG_token_projector
        ↓
Mask2Former query embedding

你可以把它看成:

语言模型说:“我要分割这个目标”
分割模型接收:“这个目标的查询向量是 SEG_embedding”

4. Mask2Former 解码 mask

关键代码主线:

mask_features, transformer_encoder_features, multi_scale_features =
    self.pixel_decoder.forward_features(image_features)

mask_outputs = self.predictor(
    multi_scale_features,
    mask_features,
    None,
    None,
    SEG_embedding
)

含义:

Swin 图像特征
    ↓
pixel_decoder
    ↓
multi_scale_features + mask_features
    ↓
predictor 使用 SEG_embedding 作为条件
    ↓
pred_masks

代码中还会根据 mask_nummask_featuresmulti_scale_featuresrepeat_interleave,因为一个 batch 中每张图可能有不同数量的 [SEG] / mask,需要把图像特征复制到与 mask query 数量对齐。(GitHub)


二十五、loss 是怎么组成的

核心逻辑:

总 loss = 语言建模 loss + mask loss + 0.01 × attention loss

L = L_{\mathrm{LLM}} + L_{\mathrm{mask}} + 0.01 L_{\mathrm{attention}}

代码中先计算语言模型的交叉熵 loss:把 logits[..., :-1, :]labels[..., 1:] 对齐,让前一个 token 预测下一个 token;然后根据 seg_info 构造 targets,调用 self.criterion(mask_outputs, targets) 得到 mask/dice 类损失;最后用 [SEG] 对图像 token 的 attention 和下采样 mask 计算 attention loss,并组合成总 loss。(GitHub)

更细一点:

loss_llm:
    语言模型交叉熵
    让模型学会生成 answer,包括 [SEG]

loss_mask:
    Mask2Former 的 mask 分类/像素损失

loss_dice:
    分割中常用的 Dice loss,用于提升区域重叠

loss_attention:
    约束 [SEG] token 对图像区域的注意力更接近真实 mask 区域

最终返回里包括:

loss
loss_mask
loss_dice
loss_llm
loss_attention

LLaVATrainer.compute_loss() 会把这些 loss 记录出来,所以训练日志里应该能看到各项损失。(GitHub)


二十六、目前你应该吃透的核心机制

到这里,你应该能完整讲出:

1. train.sh 把路径和参数传给 train.py
2. train.py 加载 SegEarthR2
3. train.py 初始化 Mask2Former 分割模块
4. train.py 加载 tokenizer 并添加 [SEG]
5. train.py 冻结 SigLIP、Swin、mm_projector
6. train.py 注入 LoRA 到 q_proj / v_proj
7. train.py 手动训练 lm_head / pixel_decoder / predictor / SEG_token_projector
8. dataset.py 构造 input_ids、images、images_clip、seg_info、SEG_token_embedding_indices
9. llava_phi.py 用 images_clip 拼接多模态 LLM 输入
10. llava_phi.py 用 images 提取分割图像特征
11. llava_phi.py 从 hidden_states 中取出 [SEG] embedding
12. SEG_embedding 进入 predictor,生成 pred_masks
13. 总 loss = LLM loss + mask/dice loss + attention loss

这已经是 SegEarth-R2 的主干了。


二十七、复现检查点:训练前必须打印这些信息

train.py 构建完 model 后,建议你临时加:

for name, p in model.named_parameters():
    if p.requires_grad:
        print(name, p.shape)

你应该看到主要可训练参数属于:

lora
lm_head
pixel_decoder
predictor
SEG_token_projector

如果你看到 vision_tower 大量参数可训练,说明 train_clip_backbone 可能被打开了。

如果你看到 vision_tower_mask 大量参数可训练,说明 train_swin_backbone 被打开了。

第一次复现不建议打开这两个。


二十八、训练第一个 batch 前必须检查

LLaVATrainer.compute_loss() 里,或者 dataset collator 后临时打印:

print("input_ids", inputs["input_ids"].shape)
print("labels", inputs["labels"].shape)
print("images", inputs["images"].shape)
print("images_clip", inputs["images_clip"].shape)
print("SEG sum", inputs["SEG_token_embedding_indices"].sum())
print("mask_num", inputs["mask_num"])
print("seg_info len", len(inputs["seg_info"]))

重点看:

SEG sum == sum(mask_num)
seg_info len == sum(mask_num)
images shape 是 [B, 3, 1024, 1024]
images_clip shape 是 SigLIP 输入尺寸

如果不满足:

优先查 dataset.py
其次查 answer 中 [SEG] 数量
最后查 mask RLE 是否正确解码

二十九、本讲结论

你现在应该把 train.py 理解成:

训练调度器

它本身不是主要创新点,但它决定了:

模型加载方式
特殊 token
可训练模块
LoRA 范围
视觉塔冻结策略
数据输入格式
Trainer 行为
保存方式

而真正的创新点在下一讲:

llava_phi.py

尤其是:

prepare_inputs_labels_for_multimodal
concat_image_seg_cls_embeds
get_SEG_embedding
forward
eval_seg

下一讲继续拆 llava_phi.py:模型结构与 forward 全流程。这部分会真正解释 SegEarth-R2 为什么能做到“语言指令 → 遥感分割 mask”。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值