05_yolox_s的后处理截断并导出onnx

目的是得到以下模型:
在这里插入图片描述

1、
官方yolox_s的源码和yolox_s.pth获取
https://github.com/Megvii-BaseDetection/YOLOX
在这里插入图片描述
2、
修改yolo_head.py的forward,替换为以下

    def forward(self, xin, labels=None, imgs=None):
        outputs = []

        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
                zip(self.cls_convs, self.reg_convs, self.strides, xin)
        ):
            x = self.stems[k](x)

            cls_feat = cls_conv(x)
            reg_feat = reg_conv(x)

            cls_output = self.cls_preds[k](cls_feat)  # [B, C, H, W]
            reg_output = self.reg_preds[k](reg_feat)  # [B, 4, H, W]
            obj_output = self.obj_preds[k](reg_feat)  # [B, 1, H, W]

            # 🚨 关键:不要 decode,不要 concat
            outputs.append(reg_output)
            outputs.append(obj_output)
            outputs.append(cls_output)

        return outputs

3、
修改export_onnx.py的main()为以下

def main():
    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    if args.ckpt is None:
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
    else:
        ckpt_file = args.ckpt

    # load the model state dict
    ckpt = torch.load(ckpt_file, map_location="cpu")

    model.eval()
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)
    model = replace_module(model, nn.SiLU, SiLU)
    model.head.decode_in_inference = False

    logger.info("loading checkpoint done.")
    dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])

    output_names = []

    output_names = [
        "reg1", "obj1", "cls1",
        "reg2", "obj2", "cls2",
        "reg3", "obj3", "cls3",
    ]
    torch.onnx._export(
        model,
        dummy_input,
        args.output_name,
        input_names=[args.input],
        output_names=output_names,
        dynamic_axes={args.input: {0: 'batch'},
                      **{name: {0: 'batch'} for name in output_names}} if args.dynamic else None,
        opset_version=args.opset,
    )
    logger.info("generated onnx model named {}".format(args.output_name))

    if not args.no_onnxsim:
        import onnx
        from onnxsim import simplify

        # use onnx-simplifier to reduce reduent model.
        onnx_model = onnx.load(args.output_name)
        model_simp, check = simplify(onnx_model)
        assert check, "Simplified ONNX model could not be validated"
        onnx.save(model_simp, args.output_name)
        logger.info("generated simplified onnx model named {}".format(args.output_name))

4、
导出指令

python tools/export_onnx.py  -f exps/default/yolox_s.py  -c yolox_s.pth  --output-name yolox_s.onnx  --opset 12 --output .

上述完成就可得到需要的onnx

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值