继续。下面进入 第 4 讲:train.py 主训练流程精读,并提前把 llava_phi.py 的 forward 主线接上。你要把这部分看成 SegEarth-R2 的“总控制台”。
第 4 讲:train.py 主训练流程
train.py 决定了 6 件事:
1. 训练参数怎么解析
2. SegEarthR2 主模型怎么加载
3. Mask2Former 分割模块怎么初始化
4. tokenizer 怎么扩展 [SEG]
5. LoRA 怎么注入
6. 数据集和 Trainer 怎么构建
这个文件不是模型创新本身,但它是你复现必须吃透的文件。官方 train.py 里定义了 ModelArguments、DataArguments、TrainingArguments,然后在 train() 函数中完成参数解析、模型加载、视觉模块初始化、LoRA 注入、数据模块构造和训练启动。(GitHub)
一、先看三个参数类
1. ModelArguments
它控制模型相关参数。
核心字段包括:
model_name_or_path
vision_tower
vision_tower_mask
mask_config
freeze_backbone
train_clip_backbone
train_swin_backbone
load_mask2former
projector_outdim
mm_projector_type
你可以这样理解:
model_name_or_path → Mipha-3B / 语言模型主体
vision_tower → SigLIP / 给 LLM 使用的图像编码器
vision_tower_mask → Mask2Former / Swin 分割分支权重
mask_config → Mask2Former 配置文件
freeze_backbone → 是否冻结语言模型主体
train_clip_backbone → 是否训练 SigLIP
train_swin_backbone → 是否训练 Swin 分割骨干
load_mask2former → 是否加载 Mask2Former 预训练权重
这里最关键的是:
vision_tower 和 vision_tower_mask 不是同一个模块
vision_tower 是给多模态大模型看的,vision_tower_mask 是给分割解码器用的。官方参数里同时出现 vision_tower 和 vision_tower_mask,训练流程里也分别初始化并移动到对应 dtype/device。(GitHub)
2. DataArguments
它控制数据相关参数。
核心字段:
base_data_path
data_ratio
switch_bs
segmentation
image_aspect_ratio
fix_dataset_len
含义:
base_data_path → LaSeRS 数据集根目录
data_ratio → 多数据集混合比例
switch_bs → 多数据集训练时 batch 切换间隔
segmentation → 是否启用分割任务
fix_dataset_len → 是否固定数据集长度
当前官方训练主要使用 LaSeRS 数据,因此 data_ratio='1' 就够。make_unify_datamodule() 里会根据 data_ratio 构造 LaSeRSDataset,再封装成 UnifyDatasetSingleDatasetForBatch。(GitHub)
3. TrainingArguments
它继承 HuggingFace 的 transformers.TrainingArguments,控制训练超参数。
核心字段:
per_device_train_batch_size
gradient_accumulation_steps
gradient_checkpointing
deepspeed
output_dir
model_max_length
bf16 / fp16
lora_enable
lora_r
lora_alpha
lora_dropout
dataloader_num_workers
其中最重要的是:
lora_enable=True
lora_r
deepspeed
bf16
model_max_length=2048
SegEarth-R2 默认不是全量微调所有参数,而是 LoRA + 部分分割模块训练。train.py 中会根据 lora_enable 构造 LoraConfig,然后用 get_peft_model() 包装模型。(GitHub)
二、train() 函数主线
train() 是整个训练入口。你读源码时不要从第一行陷入细节,而是先按这条主线读:
parse_args
↓
get_mask_config
↓
SegEarthR2.from_pretrained
↓
initial_mask_module
↓
AutoTokenizer.from_pretrained
↓
initialize_vision_modules
↓
freeze / unfreeze modules
↓
tokenizer.add_tokens("[SEG]")
↓
LoRA 注入
↓
model.get_special_token
↓
SiglipImageProcessor
↓
make_unify_datamodule
↓
LLaVATrainer
↓
trainer.train()
↓
save_model
下面逐段拆。
三、参数解析
代码逻辑是:
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
含义:把 train.sh 里的命令行参数自动映射到三个 dataclass 中。
例如:
--model_name_or_path pretrained_model/mllm/Mipha-3B
--base_data_path /your/path/LaSeRS
--lora_r 4
会变成:
model_args.model_name_or_path
data_args.base_data_path
training_args.lora_r
所以你以后改实验,不应该直接改 Python 文件,而应该优先改 scripts/train.sh。
四、读取 Mask2Former 配置
代码主线是:
mask_cfg = get_mask_config(config=model_args.mask_config)
这个配置决定分割分支怎么构造,包括:
Swin backbone
pixel decoder
transformer predictor
criterion
loss weight
hidden_dim
query 机制
这一步很重要。因为 SegEarth-R2 的 mask 不是 LLM 直接生成,而是通过 Mask2Former 类分割头生成。
也就是说:
LLM 负责产生 [SEG] embedding
Mask2Former 负责把 [SEG] embedding 解码成 mask
五、加载 SegEarthR2 主模型
关键代码是:
model = SegEarthR2.from_pretrained(
model_args.model_name_or_path,
mask_decoder_cfg=mask_cfg,
add_cross_attn=True,
cache_dir=training_args.cache_dir,
)
这一句加载的是整个 SegEarth-R2 主体。它不是普通语言模型,而是基于 Mipha/Phi 系列结构扩展出来的多模态分割模型。train.py 中明确从 segearth_r2.model.language_model.llava_phi 导入 SegEarthR2,然后通过 from_pretrained() 加载。(GitHub)
你要理解:
SegEarthR2 = 语言模型主体 + 多模态图像接入 + 分割模块接口
但此时分割模块不一定已经加载好,所以后面还要调用:
model.initial_mask_module(mask2former_ckpt, model_args)
六、初始化 Mask2Former 分割模块
代码逻辑:
if not model.is_train_mask_decode:
mask2former_ckpt = model_args.vision_tower_mask if model_args.load_mask2former else None
model.initial_mask_module(mask2former_ckpt, model_args)
含义:
如果当前模型还没有训练/初始化分割解码模块
就加载 vision_tower_mask 指定的 Mask2Former 权重
并初始化 pixel_decoder、predictor、criterion 等模块
这里有一个很重要的判断:
load_mask2former=True → 使用预训练 Mask2Former 权重
load_mask2former=False → 不加载预训练分割权重
复现时建议保持 load_mask2former=True,否则分割分支从头开始训练,难度更大。
七、关闭 use_cache
model.config.use_cache = False
训练大语言模型时通常关闭 cache,因为 cache 是推理生成时为了加速自回归 decoding 用的。训练时需要完整反向传播,cache 反而可能造成显存和梯度问题。
后面训练结束又会设置:
model.config.use_cache = True
用于保存和后续推理。train.py 在训练完成后调用 trainer.save_state() 并保存模型前把 use_cache 重新打开。(GitHub)
八、冻结语言模型主体
if model_args.freeze_backbone:
model.model.requires_grad_(False)
如果 freeze_backbone=True,则冻结模型主体。
这里的 “backbone” 主要是语言模型主体,不是单纯 CNN backbone。
建议你第一次复现时不要改它,保持官方默认。等项目跑通后再做实验:
实验 A:冻结 LLM,只训 LoRA + 分割头
实验 B:LoRA 微调 LLM,冻结视觉塔
实验 C:打开 Swin backbone 训练
实验 D:打开 mm_projector 训练
九、gradient checkpointing
if training_args.gradient_checkpointing:
model.enable_input_require_grads()
gradient checkpointing 的作用是节省显存:前向时不保存全部中间激活,反向时重新计算一部分。
优点:
显存更低
可以支持更大图像或更长序列
缺点:
训练更慢
调试更复杂
部分模型可能出现输入梯度问题
RTX 5090 第一次复现,建议先:
--gradient_checkpointing False
跑通以后,如果显存不够,再打开。
十、加载 tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
)
作用:加载 Mipha-3B 对应 tokenizer。
关键点:
padding_side="right"
model_max_length=2048
use_fast=False
右 padding 对 causal LM 训练比较常见。model_max_length=2048 决定最长文本 token 序列。超长样本会被截断。官方代码在 tokenizer 没有 pad token 时,会通过 smart_tokenizer_and_embedding_resize() 添加 [PAD] 并同步 resize embedding。(GitHub)
十一、设置 conversation template
if model_args.version in conversation_lib.conv_templates:
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
else:
conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
这一步决定 prompt 格式。
它会影响:
human / gpt 对话怎么拼接
分隔符是什么
labels 哪些位置被 IGNORE_INDEX 屏蔽
模型学习回答时看到的上下文格式
你前面读 dataset.py 时看到的:
sources = [[
{'from': 'human', 'value': prefix_inst + '\n<refer> <|assistant|>'},
{'from': 'gpt', 'value': '\n' + answer}
]]
会通过 conversation template 转成最终训练文本。dataset.py 里 preprocess_llama2() 会根据对话轮次,把人类指令部分 label 置为 IGNORE_INDEX,只让 assistant 部分参与语言模型 loss。(GitHub)
十二、初始化两个视觉分支
核心代码:
model.get_model().initialize_vision_modules(
model_args=model_args,
fsdp=training_args.fsdp
)
vision_tower = model.get_vision_tower()
vision_tower_mask = model.model.get_vision_tower_mask()
vision_tower.to(dtype=..., device=training_args.device)
vision_tower_mask.to(dtype=..., device=training_args.device)
这段非常关键。它再次证明 SegEarth-R2 有两套视觉流:
vision_tower → SigLIP,服务多模态 LLM
vision_tower_mask → Swin / Mask2Former,服务像素分割
训练代码中分别获取 vision_tower 和 vision_tower_mask,并把它们移动到训练设备和指定 dtype。(GitHub)
你要牢记:
images_clip 进入 vision_tower
images 进入 vision_tower_mask / mask decoder
这是项目中最容易混淆的地方。
十三、冻结视觉塔和 projector
代码:
if not model_args.train_clip_backbone:
model.model.vision_tower.requires_grad_(False)
if not model_args.train_swin_backbone:
model.model.vision_tower_mask.requires_grad_(False)
if training_args.freeze_mm_mlp_adapter:
for p in model.get_model().mm_projector.parameters():
p.requires_grad = False
含义:
默认不训练 SigLIP
默认不训练 Swin / Mask2Former backbone
默认冻结 mm_projector
但是后面会手动打开一些模块:
train_module_list = [
"lm_head",
"pixel_decoder",
"predictor",
"SEG_token_projector",
]
所以实际训练的是:
LoRA 参数
lm_head
pixel_decoder
predictor
SEG_token_projector
如果设置:
--train_swin_backbone True
则 vision_tower_mask 也会加入训练模块列表。官方代码中 train_module_list 默认包括 lm_head、pixel_decoder、predictor、SEG_token_projector,并在 train_swin_backbone 为真时追加 vision_tower_mask。(GitHub)
十四、最关键:添加 [SEG] token
tokenizer.add_tokens("[SEG]")
model.resize_token_embeddings(len(tokenizer))
这是整个项目的核心之一。
为什么要加 [SEG]?
因为 SegEarth-R2 不是只生成文字,它要让语言模型在回答中生成一个特殊位置,这个位置代表:
我要在这里输出一个分割 mask
也就是说:
answer = "The building is [SEG]."
模型训练时不仅学习输出这句话,还会额外取 [SEG] 对应 hidden state 去预测 mask。
如果不加 [SEG],则 tokenizer 无法把 [SEG] 当成一个独立 token,后面的:
input_ids == self.SEG_token_id
就可能找不到正确位置。
train.py 中明确在视觉模块初始化和冻结逻辑之后添加 [SEG] 并 resize embedding,随后把 [SEG] 的 token id 传给模型内部。(GitHub)
十五、LoRA 注入逻辑
代码主线:
lora_target_modules = find_linear_layers(model, train_module_list=train_module_list)
lora_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
find_linear_layers() 的逻辑是:遍历模型中的 torch.nn.Linear,寻找名字里包含 q_proj 或 v_proj 的线性层,但排除 train_module_list、vision_tower 和 vision_tower_mask 相关模块。(GitHub)
也就是说,LoRA 主要插入到语言模型注意力层中的:
q_proj
v_proj
这是一种典型的大模型微调方式。
你可以理解为:
大语言模型主体大部分被冻结
只在注意力 Q/V 投影上加可训练低秩分支
再额外训练分割相关模块
十六、手动打开分割相关模块训练
LoRA 注入后,代码还会执行:
for n, p in model.named_parameters():
if any([x in n for x in train_module_list]):
p.requires_grad = True
这一步非常关键。
因为 LoRA 只负责大语言模型部分的低秩微调,但分割输出能力依赖这些模块:
lm_head
pixel_decoder
predictor
SEG_token_projector
其中最重要的是:
SEG_token_projector
它负责把 LLM 的 [SEG] hidden state 映射到 Mask2Former predictor 可使用的维度。
所以整体训练策略是:
LLM:LoRA 微调
SigLIP:默认冻结
Swin backbone:默认冻结
mm_projector:默认冻结
Mask2Former pixel decoder:训练
Mask2Former predictor:训练
SEG_token_projector:训练
lm_head:训练
这也是你未来做博士改进时最应该动的区域。
十七、把 [SEG] id 注册到模型里
model.get_special_token(
SEG=tokenizer("[SEG]", return_tensors='pt', add_special_tokens=False)['input_ids'],
EOS=tokenizer.eos_token_id
)
这一步告诉模型:
哪个 token id 是 [SEG]
哪个 token id 是 EOS
为什么不只在 dataset 里用 [SEG]?
因为模型 forward 里面还要根据 [SEG] 做生成、定位和 mask 解码。dataset.py 里会生成 SEG_token_embedding_indices,而 llava_phi.py 里会通过这些 index 从 hidden_states 取出 [SEG] embedding。(GitHub)
十八、构建 SigLIP 图像处理器
clip_image_processor = SiglipImageProcessor.from_pretrained(model_args.vision_tower)
这一步对应 dataset.py 的 DataCollatorForCOCODatasetV2。
注意:
dataset.py 里的 preprocess_image() 处理的是 images
SiglipImageProcessor 处理的是 images_clip
DataCollatorForCOCODatasetV2 中会用 clip_image_processor.preprocess() 读取原图并构造 images_clip,而 LaSeRSDataset.__getitem__() 里提前构造的 image 会成为 images。(GitHub)
十九、构建数据模块
data_module = make_unify_datamodule(
clip_image_processor=clip_image_processor,
tokenizer=tokenizer,
data_args=data_args,
training_args=training_args
)
make_unify_datamodule() 会做三件事:
1. 解析 data_ratio
2. 构建 LaSeRSDataset
3. 构建 DataCollatorForCOCODatasetV2
最后返回:
{
"train_dataset": train_dataset,
"eval_dataset": None,
"data_collator": data_collator
}
也就是说,这个项目训练时没有在 train.py 里单独构建 eval dataset,评估是后面通过 eval.py 单独跑。(GitHub)
二十、构建 LLaVATrainer
trainer = LLaVATrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
**data_module
)
它不是直接用 HuggingFace 原生 Trainer,而是自定义了 LLaVATrainer。核心原因是:模型返回的不只是一个总 loss,还包括:
loss_llm
loss_mask
loss_dice
loss_attention
自定义 Trainer 的 compute_loss() 会把 global_step 加入 inputs,然后调用 model(**inputs),再从 outputs 中取出 loss,并额外记录各类 loss。(GitHub)
这对调试很重要。
你训练时不要只看总 loss,还要看:
loss_llm 是否正常
loss_mask 是否正常
loss_dice 是否正常
loss_attention 是否异常
如果 loss_llm 正常但 loss_mask 不降,说明语言模型部分能学,但分割分支或 mask 对齐有问题。
如果 loss_mask 为 0 或 NaN,优先检查:
mask 是否读到
[SEG] 数量是否等于 mask 数量
seg_info 是否为空
mask dtype / shape 是否正确
二十一、训练启动和断点恢复
代码逻辑:
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
含义:
如果 output_dir 里已有 checkpoint-*,自动断点恢复
否则从头训练
这很好用,但也容易坑你。
如果你改了模型结构、改了 [SEG]、改了 tokenizer、改了 LoRA 配置,却还让它从旧 checkpoint 恢复,可能会出现各种 shape mismatch。
因此调试时建议:
rm -rf outputs/debug_5090_lora_r4/checkpoint-*
或者每次换实验名:
outputs/exp001_debug
outputs/exp002_lora_r8
outputs/exp003_train_swin
二十二、保存模型
训练结束后:
trainer.save_state()
model.config.use_cache = True
safe_save_model_for_hf_trainer(...)
safe_save_model_for_hf_trainer() 对 DeepSpeed 做了特殊处理:如果 trainer 使用 DeepSpeed,则先同步 CUDA,然后调用 trainer.save_model(output_dir)。(GitHub)
你要注意:如果使用 LoRA,训练输出通常不是一个可直接推理的完整模型,还需要后续合并 LoRA 权重。
完整流程是:
train.sh
↓
得到 LoRA / checkpoint
↓
merge_lora_weights.sh
↓
得到合并后的模型
↓
test.sh / eval.py
二十三、到这里,你应该形成第二张总图
scripts/train.sh
↓
HfArgumentParser
↓
ModelArguments / DataArguments / TrainingArguments
↓
get_mask_config
↓
SegEarthR2.from_pretrained
↓
initial_mask_module
↓
AutoTokenizer.from_pretrained
↓
add [PAD] if needed
↓
initialize_vision_modules
↓
freeze SigLIP / Swin / mm_projector
↓
tokenizer.add_tokens("[SEG]")
↓
resize_token_embeddings
↓
LoRA 注入 q_proj / v_proj
↓
手动打开 lm_head / pixel_decoder / predictor / SEG_token_projector
↓
get_special_token([SEG], EOS)
↓
SiglipImageProcessor
↓
LaSeRSDataset + DataCollator
↓
LLaVATrainer
↓
model.forward
↓
loss_llm + loss_mask + loss_dice + loss_attention
二十四、提前接上 llava_phi.py 的 forward 主线
现在你已经知道 train.py 如何把 batch 送进模型。下一步要理解 model.forward() 里发生什么。
llava_phi.py 的 forward 参数包括:
input_ids
attention_mask
labels
images
images_clip
seg_info
token_refer_id
SEG_token_embedding_indices
mask_num
dataset_type
你看,正好对应前面 dataset.py 和 DataCollator 产生的字段。
1. 两路图像进入模型
在 forward 中:
image_features = self.get_vision_tower_feature(images)
这里的 images 是给分割分支用的,它会得到 Swin / Mask2Former 相关图像特征。随后:
prepare_inputs_labels_for_multimodal(..., images_clip, ...)
这里的 images_clip 是给多模态 LLM 用的,用于把图像 token 拼入语言模型输入序列。llava_phi.py forward 中先通过 get_vision_tower_feature(images) 获取分割图像特征,再通过 prepare_inputs_labels_for_multimodal(..., images_clip, ...) 处理多模态输入。(GitHub)
所以你必须牢牢记住:
images → 分割视觉特征
images_clip → LLM 多模态输入
2. LLM 前向传播
代码主线:
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states)
含义:
多模态 token 序列进入语言模型
得到每个 token 的 hidden state
再通过 lm_head 得到词表 logits
这一步用于语言建模 loss,也为 [SEG] 提供 hidden state。
3. 取出 [SEG] hidden state
关键代码:
SEG_embedding = self.SEG_token_projector(
self.get_SEG_embedding(hidden_states, SEG_token_embedding_indices)
)
get_SEG_embedding() 会根据 SEG_token_embedding_indices 找到 hidden_states 中 [SEG] 对应位置,然后拼起来。(GitHub)
这就是项目最核心的桥梁:
[SEG] token hidden state
↓
SEG_token_projector
↓
Mask2Former query embedding
你可以把它看成:
语言模型说:“我要分割这个目标”
分割模型接收:“这个目标的查询向量是 SEG_embedding”
4. Mask2Former 解码 mask
关键代码主线:
mask_features, transformer_encoder_features, multi_scale_features =
self.pixel_decoder.forward_features(image_features)
mask_outputs = self.predictor(
multi_scale_features,
mask_features,
None,
None,
SEG_embedding
)
含义:
Swin 图像特征
↓
pixel_decoder
↓
multi_scale_features + mask_features
↓
predictor 使用 SEG_embedding 作为条件
↓
pred_masks
代码中还会根据 mask_num 对 mask_features 和 multi_scale_features 做 repeat_interleave,因为一个 batch 中每张图可能有不同数量的 [SEG] / mask,需要把图像特征复制到与 mask query 数量对齐。(GitHub)
二十五、loss 是怎么组成的
核心逻辑:
总 loss = 语言建模 loss + mask loss + 0.01 × attention loss
L = L_{\mathrm{LLM}} + L_{\mathrm{mask}} + 0.01 L_{\mathrm{attention}}
代码中先计算语言模型的交叉熵 loss:把 logits[..., :-1, :] 和 labels[..., 1:] 对齐,让前一个 token 预测下一个 token;然后根据 seg_info 构造 targets,调用 self.criterion(mask_outputs, targets) 得到 mask/dice 类损失;最后用 [SEG] 对图像 token 的 attention 和下采样 mask 计算 attention loss,并组合成总 loss。(GitHub)
更细一点:
loss_llm:
语言模型交叉熵
让模型学会生成 answer,包括 [SEG]
loss_mask:
Mask2Former 的 mask 分类/像素损失
loss_dice:
分割中常用的 Dice loss,用于提升区域重叠
loss_attention:
约束 [SEG] token 对图像区域的注意力更接近真实 mask 区域
最终返回里包括:
loss
loss_mask
loss_dice
loss_llm
loss_attention
LLaVATrainer.compute_loss() 会把这些 loss 记录出来,所以训练日志里应该能看到各项损失。(GitHub)
二十六、目前你应该吃透的核心机制
到这里,你应该能完整讲出:
1. train.sh 把路径和参数传给 train.py
2. train.py 加载 SegEarthR2
3. train.py 初始化 Mask2Former 分割模块
4. train.py 加载 tokenizer 并添加 [SEG]
5. train.py 冻结 SigLIP、Swin、mm_projector
6. train.py 注入 LoRA 到 q_proj / v_proj
7. train.py 手动训练 lm_head / pixel_decoder / predictor / SEG_token_projector
8. dataset.py 构造 input_ids、images、images_clip、seg_info、SEG_token_embedding_indices
9. llava_phi.py 用 images_clip 拼接多模态 LLM 输入
10. llava_phi.py 用 images 提取分割图像特征
11. llava_phi.py 从 hidden_states 中取出 [SEG] embedding
12. SEG_embedding 进入 predictor,生成 pred_masks
13. 总 loss = LLM loss + mask/dice loss + attention loss
这已经是 SegEarth-R2 的主干了。
二十七、复现检查点:训练前必须打印这些信息
在 train.py 构建完 model 后,建议你临时加:
for name, p in model.named_parameters():
if p.requires_grad:
print(name, p.shape)
你应该看到主要可训练参数属于:
lora
lm_head
pixel_decoder
predictor
SEG_token_projector
如果你看到 vision_tower 大量参数可训练,说明 train_clip_backbone 可能被打开了。
如果你看到 vision_tower_mask 大量参数可训练,说明 train_swin_backbone 被打开了。
第一次复现不建议打开这两个。
二十八、训练第一个 batch 前必须检查
在 LLaVATrainer.compute_loss() 里,或者 dataset collator 后临时打印:
print("input_ids", inputs["input_ids"].shape)
print("labels", inputs["labels"].shape)
print("images", inputs["images"].shape)
print("images_clip", inputs["images_clip"].shape)
print("SEG sum", inputs["SEG_token_embedding_indices"].sum())
print("mask_num", inputs["mask_num"])
print("seg_info len", len(inputs["seg_info"]))
重点看:
SEG sum == sum(mask_num)
seg_info len == sum(mask_num)
images shape 是 [B, 3, 1024, 1024]
images_clip shape 是 SigLIP 输入尺寸
如果不满足:
优先查 dataset.py
其次查 answer 中 [SEG] 数量
最后查 mask RLE 是否正确解码
二十九、本讲结论
你现在应该把 train.py 理解成:
训练调度器
它本身不是主要创新点,但它决定了:
模型加载方式
特殊 token
可训练模块
LoRA 范围
视觉塔冻结策略
数据输入格式
Trainer 行为
保存方式
而真正的创新点在下一讲:
llava_phi.py
尤其是:
prepare_inputs_labels_for_multimodal
concat_image_seg_cls_embeds
get_SEG_embedding
forward
eval_seg
下一讲继续拆 llava_phi.py:模型结构与 forward 全流程。这部分会真正解释 SegEarth-R2 为什么能做到“语言指令 → 遥感分割 mask”。

509

被折叠的 条评论
为什么被折叠?



