onnx2pytorch和onnx-simplifer新版介绍

1,505次阅读
没有评论
onnx2pytorch和onnx-simplifer新版介绍

【GiantPandaCV导语】本文是ONNX2Pytorch思路分享以及onnx-simplifier新版简要介绍。ONNX2Pytorch工具已经测试了onnx model zoo中的大量分类模型并转换正确,欢迎使用,github地址:https://github.com/BBuf/onnx2nn。GiantPandaCV几个月前遭受恶意举报,今天终于解除封印了。感谢众多粉丝们的长期等待和支持,我们会在此继续分享学习经验。

0x0. 背景

ONNX作为微软的神经网络模型的开放格式被各个框架广泛应用,包括Pytroch,TensorFlow,OneFlow,Keras,Paddle等多种深度学习训练框架。因此,之前一直在思考一个问题,一个TensorFlow/MxNet/Keras导出来的ONNX模型是否可以借助ONNX被Pytorch框架使用呢?ONNX的理想是作为所有训练框架模型的中间表示,那么我们只需要再实现ONNX到各个框架的逆转就可以完成这件事情了。本工程的目的即是尝试支持ONNX转换到Pytorch,主要为了锻炼算子对齐和更深入的了解ONNX。先放一下github地址:https://github.com/BBuf/onnx2nn,欢迎关注。这个工程复用了https://github.com/ToriML/onnx2pytorch的整体逻辑,解决了原始工程中遗留的大量BUG,支持了更多OP,实现了输入一个ONNX模型,返回一个torch.nn.Module对象,并将这个torch.nn.Module对应的Pytorch模型保存下来。

0x1. 思路

首先需要说明的是,在执行转换之前需要先过一遍onnx-simplifer对原始的ONNX模型进行简化,工程地址为:https://github.com/daquexian/onnx-simplifier 。为了使用方便,我将这个工具直接接入到了本工程,在后面的使用方法中可以看到。

然后这和项目的思路是非常简单的,直接遍历ONNX模型的计算节点(也即OP),把每个OP一对一的转换到Pytorch就可以了。核心代码地址为:https://github.com/BBuf/onnx2nn/blob/master/onnx2pytorch/convert/operations.py#L20-L181。简单截图说明一下:

def convert_operations(onnx_model, batch_dim=0):
    """
    Convert onnx model operations. Yields onnx's operator_id, opeartor_name and
    converted pytorch operator.

    Parameters
    ----------
    onnx_model: onnx.ModelProto
        Loaded onnx model.
    batch_dim: int
        Usually 0 for computer vision models and 1 for NLP models.

    Returns
    -------
    iterator: (op_id, op_name, op)
    """
    weights = {tensor.name: tensor for tensor in onnx_model.graph.initializer}

    for i, node in enumerate(onnx_model.graph.node):
        # extract only useful inputs
        params = [weights[par_name] for par_name in node.input if par_name in weights]

        if node.op_type == "Conv":
            op = convert_layer(node, "Conv", params)
        elif node.op_type == "Relu":
            op = nn.ReLU(inplace=True)
        elif node.op_type == "LeakyRelu":
            op = nn.LeakyReLU(**extract_attributes(node), inplace=True)
        elif node.op_type == "Sigmoid":
            op = nn.Sigmoid()
        elif node.op_type == "MaxPool":
            op = convert_layer(node, "MaxPool")
        elif node.op_type == "AveragePool":
            op = convert_layer(node, "AvgPool")
        elif node.op_type == "Flatten":
            op = Flatten(**extract_attributes(node))
        elif node.op_type == "Gemm":
            op = convert_linear_layer(node, params)
            op.feature_dim = batch_dim + 1  # Necessary for transformers
        elif node.op_type == "BatchNormalization":
            op = convert_batch_norm_layer(node, params=params)
        elif node.op_type == "InstanceNormalization":
            op = convert_instance_norm_layer(node, params=params)
        elif node.op_type == "Concat":
            op = Concat(**extract_attributes(node))
        else
         pass
        op_name = "{}_{}".format(node.op_type, node.output[0])
        op_id = node.output[0]
        yield op_id, op_name, op

可以看到通过遍历ONNX模型的所有计算节点并获取每个节点的信息(输入参数以及各种attribute)之后将其用Pytorch的对应OP写出来就完成了转换过程。里面涉及到的每个OP的具体转换过程比如权重,attribute参数的提取以及对应Pytorch的实现等可以直接查看源码,这里不详细展开。

在获得每个ONNX计算节点对应的Pytorch OP之后,我们需要根据ONNX的计算节点反应的拓扑关系把所有的Pytorch OP组合成一个完整的Pytorch的模型,这部分的代码实现在:https://github.com/BBuf/onnx2nn/blob/master/onnx2pytorch/convert/model.py#L36-L131

0x2. 一些需要注意的点

在执行ONNX2Pytorch的过程中需要注意一些由于Pytorch和ONNX OP实现不一致而导致模型转换失败的情况,下面列举一下:

  • 非对称Padding问题。在对alexnet和google-net进行转换时发现它们的卷积或者Max Pooling层经常会出现非对称Padding的情况,由于Pytorch的卷积和最大池化操作不支持不对称Padding操作,所以这个时候为了保证转换的等价,需要将这个非对称Padding的OP拆成nn.ConstantPad2d+无Padding的原始OP。
  • count_include_pad问题。在对inception-net进行转换时发现到了最后一个Avg Pooling层时出现了精度严重下降,经过Debug发现,Pytorch的Avg Pooling层的count_include_pad默认为True。如果这个时候也是非对称的Padding,那么按照上面的处理方法拆分成ConstantPad2d+Avg Pooling之后会丢失精度,因为这种情况下Avg Pooling无法知晓自己Padding了多少元素。如下图所示:

这个时候可以通过修改Kernel尺寸的方法来规避这个问题,在上面的例子中我们可以直接让kernel_shape等于(7-1=6,7-1=6)并且省掉新增常量Pad的操作。

这两点的代码实现在:https://github.com/BBuf/onnx2nn/blob/master/onnx2pytorch/convert/layer.py#L30-L91

  • LRN层。在alexnet模型中有一个LRN层,这个层的参数长这样:

然后我们看一下Pytorch的LRN层的API:

对比一下API的参数可以发现ONNX里面的bias对应的是Pytorch LRN里面的参数k,所以这里需要特殊处理一下,获取这个attribute的bias参数的值之后将其设为Pytorch LRN层里面的k参数的值。具体实现在:https://github.com/BBuf/onnx2nn/blob/master/onnx2pytorch/convert/attribute.py#L132-L139

0x3. onnx2nn工程介绍

0x3.1 代码结构

- onnx2pytorch onnx转pytorch代码实现
- onnx2pytorch.py onnx转pytorch测试代码
- convert_models.md 转换ONNX Model Zoo里面的模型对应的命令和结果记录
- README.md 

0x3.2 运行环境

  • pytorch >= 1.1.0
  • onnx>=1.8.1
  • onnxruntime>=1.6.0
  • onnxoptimizer>=0.2.3

0x3.3 使用方法

使用下面的命令将各个训练框架导出的ONNX模型转换成Pytorch模型

python .\onnx2pytorch.py ...

参数列表如下:

  • --onnx_path 字符串,必选参数,代表onnx模型的路径
  • --pytorch_path 字符串,必选参数,代表转换出的Pytorch模型保存路径
  • --simplify_path 字符串,可选参数,代表ONNX模型简化(例如删除Dropout和常量OP)后保存的ONNX模型路径
  • --input_shape 字符串,必选参数,代表ONNX模型的输入数据层的名字和维度信息

0x3.4使用示例

python .\onnx2pytorch.py --onnx_path .\models\mobilenetv2-7.onnx --simplify_path .\models\mobilenetv2-7-simplify.onnx --pytorch_path .\models\mobilenetv2-7.pth --input_shape input:1,3,224,224

0x3.5 模型转换失败处理方法

  • onnx2pytorch.py里面的model = convert.ConvertModel(onnx_model, debug=False)这行代码里面的debug设置False重新运行模型即可定位到转换失败的OP,然后你可以在工程提出issue或者自己解决然后给本工程PR。

0x3.6 已支持的ONNX OP

  • Conv
  • BatchNormalization
  • GlobalAvgragePool
  • AvgPool
  • MaxPool
  • BatchNorm
  • Flatten
  • Reshape
  • Relu
  • Add
  • Gemm
  • Sigmoid
  • Mul
  • Concat
  • Resize (还有一些问题需要解决,当前版本支持固定倍数方法)
  • Transpose
  • LRN
  • Clip
  • Pad2d
  • Split
  • ReduceMean

0x3.7 已验证支持的模型

基于ONNXRuntime和Pytorch推理之后特征值mse小于1e-7,视为转换成功

分类模型

  • zfnet512-9.onnx
  • resnet50-v2-7.onnx
  • mobilenetv2-7.onnx
  • mobilenetv2-1.0.onnx
  • bvlcalexnet-9.onnx
  • googlenet-9.onnx
  • squeezenet1.1-7.onnx
  • shufflenet-v2-10.onnx
  • inception-v1-9.onnx
  • inception-v2-9.onnx
  • vgg19-caffe2-9.onnx
  • rcnn-ilsvrc13-9.onnx

检测模型

  • yolov5s-simple.onnx

0x3.7 TODO

  • 支持更多模型
  • 重构工程,并解决某些模型转为Pytorch模型之后Netron可视化看不到某些OP的问题
  • 一些部署工作,比如Keras导出的ONNX转为Pytorch模型后,二次导出ONNX递交给NCNN推理

0x4. onnx-simplifer最近更新

onnx-simplifer最近迎来了一次更新,这次更新是和onnxruntime一起更新的,小伙伴们要使用最新版本记得把onnxruntime更新到1.6.0哦。然后我去阅读了一下最新的onnx-simplifer,在上次的ONNX初探基础上,增加了一个递归函数fixed_point,功能就是递归执行func_a和fun_b直到模型稳定,代码如下:

# 递归执行func_a和func_b直到模型稳定
def fixed_point(x: T, func_a: Callable[[T], T], func_b: Callable[[T], T]) -> T:
    """
    Run `func_a` and `func_b` on `x` until func_b(func_a(x)) == x
    :param x: 
    :param func_a: A function satisfying func_a(func_a(x)) == func_a(x)
    :param func_b: A function satisfying func_b(func_b(x)) == func_b(x)
    :return: the x that satisfies func_b(func_a(x)) == x
    """
    x = func_a(x)
    x = func_b(x)
    while True:
        y = func_a(x)
        if y == x:
            # Since func_b(func_b(x)) == func_b(x),
            # we are already at the fixed point if
            # `y == x`
            return x
        x = y
        y = func_b(x)
        if y == x:
            return x
        x = y

我们看一下它是怎么应用的就可以了,注释如下:

def simplify(model: Union[str, onnx.ModelProto],
             check_n: int = 0,
             perform_optimization: bool = True,
             skip_fuse_bn: bool = False,
             input_shapes: Optional[TensorShapesWithOptionalKey] = None,
             skipped_optimizers: Optional[Sequence[str]] = None,
             skip_shape_inference=False,
             input_data: Optional[Tensors] = None,
             dynamic_input_shape: bool = False,
             custom_lib: Optional[str] = None) -> Tuple[onnx.ModelProto, bool]:
    """
    :param model: onnx ModelProto object or file path
    :param check_n: The simplified model will be checked for `check_n` times by random inputs
    :param perform_optimization: Whether to run onnx optimizer on the model
    :param skip_fuse_bn: Skip fuse_bn_into_conv onnx optimizer
    :param input_shapes: If the model has dynamic input shape, user must pass a fixed input shape 
            for generating random inputs and checking equality. (Also see "dynamic_input_shape" param)
    :param skipped_optimizers: Skip some specific onnx optimizers
    :param skip_shape_inference: Skip shape inference (sometimes shape inference will crash)
    :param input_data: Feed custom input data for checking if needed
    :param dynamic_input_shape: Indicates whether the input shape should be dynamic. Note that
            input_shapes is also needed even if dynamic_input_shape is True,
            the value of input_shapes will be used when generating random inputs for checking equality.
            If 'dynamic_input_shape' is False, the input shape in simplified model will be overwritten
            by the value of 'input_shapes' param.
    :param custom_lib: onnxruntime custom ops's shared library
    :return: A tuple (simplified model, success(True) or failed(False))
    """
    if input_shapes is None:
        input_shapes = {}
    if input_data is None:
        input_data = {}

    if type(model) == str:
        # 加载ONNX模型
        model = onnx.load(model)
    assert(isinstance(model, onnx.ModelProto))
    # 检查ONNX模型格式是否正确,图结构是否完整,节点是否正确等
    onnx.checker.check_model(model)
    # 深拷贝一份原始ONNX模型
    model_ori = copy.deepcopy(model)


    input_names = get_input_names(model)
    for input_name, data in input_data.items():
        if input_name not in input_names:
            raise RuntimeError(
                'The model doesn\'t have input named "{}"'.format(input_name))

        shape = list(input_data[input_name].shape)

        # special case for single constant variables (with shape [])
        if len(shape) == 0:
            shape = [input_data[input_name].size]
        if input_name in input_shapes and shape != input_shapes[input_name]:
            raise RuntimeError('The shape of input_data[{}] is not the same with input_shape[{}]'.format(
                input_name, input_name))
        elif input_name not in input_shapes:
            input_shapes[input_name] = shape

    # 检查核对输入节点
    updated_input_shapes = check_and_update_input_shapes(model, input_shapes)


    def infer_shapes_and_optimize(model: onnx.ModelProto) -> onnx.ModelProto:
        # 做ONNX模型节点形状推断
        def infer_shapes_if_applicable(model: onnx.ModelProto) -> onnx.ModelProto:
            if not skip_shape_inference:
                model = infer_shapes(model)
            return model
        # 对ONNX模型进行optimizer
        def optimize_if_applicable(model: onnx.ModelProto) -> onnx.ModelProto:
            if perform_optimization:
                model = optimize(model, skip_fuse_bn, skipped_optimizers)
            return model
        # 递归执行infer_shapes_if_applicable和optimize_if_applicable直到模型稳定
        return fixed_point(model, infer_shapes_if_applicable, optimize_if_applicable)

    def constant_folding(model: onnx.ModelProto) -> onnx.ModelProto:
        # 获取模型的常量OP
        const_nodes = get_constant_nodes(
            model, dynamic_input_shape=dynamic_input_shape)
        # 获取所有的常量OP以及原始输出OP的特征值
        res = forward_for_node_outputs(model,
                                       const_nodes,
                                       input_shapes=updated_input_shapes,
                                       input_data=input_data,
                                       custom_lib=custom_lib)
        # 清洗那些没有被onnxruntime推理的静态节点
        const_nodes = clean_constant_nodes(const_nodes, res)
        # 移除常量OP,获得简化后的ONNX模型
        model = eliminate_const_nodes(model, const_nodes, res)
        # 检查ONNX模型格式是否正确,图结构是否完整,节点是否正确等
        onnx.checker.check_model(model)
        return model

    # 递归执行infer_shapes_and_optimize和constant_folding直到模型稳定
    model = fixed_point(model, infer_shapes_and_optimize, constant_folding)

    # 重写模型的输入shape
    if not dynamic_input_shape:
        for name, input_shape in updated_input_shapes.items():
            for ipt in model.graph.input:
                if ipt.name == name:
                    for i, dim in enumerate(ipt.type.tensor_type.shape.dim):
                        dim.dim_value = input_shape[i]
    # 检查核对输入节点
    check_ok = check(model_ori, model, check_n,
                     input_shapes=updated_input_shapes)

    return model, check_ok

现在onnx-simplifer在简化过程中会递归的去推断shape,折叠常量,以及optimizer。所以这个程序比较依赖各个操作都不出错,如果某一步发生错误,可能有qia住的风险哦。使用最新版onnx-simplifer前切记更新onnxruntime到最新版本,否则使用model zoo里面的mobilenet模型就会引发qia住这一现象。

了解更多onnx-simplifer,比如执行流程,每一步再干什么请看ONNX初探的文章以及大老师发布的onnx simplifier 和 optimizer

BBuf只是API搬运工,onnxoptimizer和onnx-simplifer的作者大老师才是yyds

0x5. 推荐学习

之前写过和整理一些ONNX学习笔记,现在汇总如下,如果你是从模型部署来看ONNX,其实我个人认为看这些了解就差不多了,当然有新的想法我也会继续更新的(鸽。

0x6. 相关链接

正文完
可以使用微信扫码关注公众号(ID:xzluomor)
post-qrcode
 0
评论(没有评论)

文心AIGC

2024 年 1 月
1234567
891011121314
15161718192021
22232425262728
293031  
文心AIGC
文心AIGC
人工智能ChatGPT,AIGC指利用人工智能技术来生成内容,其中包括文字、语音、代码、图像、视频、机器人动作等等。被认为是继PGC、UGC之后的新型内容创作方式。AIGC作为元宇宙的新方向,近几年迭代速度呈现指数级爆发,谷歌、Meta、百度等平台型巨头持续布局
文章搜索
热门文章
潞晨尤洋:日常办公没必要上私有模型,这三类企业才需要 | MEET2026

潞晨尤洋:日常办公没必要上私有模型,这三类企业才需要 | MEET2026

潞晨尤洋:日常办公没必要上私有模型,这三类企业才需要 | MEET2026 Jay 2025-12-22 09...
“昆山杯”第二十七届清华大学创业大赛决赛举行

“昆山杯”第二十七届清华大学创业大赛决赛举行

“昆山杯”第二十七届清华大学创业大赛决赛举行 一水 2025-12-22 17:04:24 来源:量子位 本届...
MiniMax海螺视频团队首次开源:Tokenizer也具备明确的Scaling Law

MiniMax海螺视频团队首次开源:Tokenizer也具备明确的Scaling Law

MiniMax海螺视频团队首次开源:Tokenizer也具备明确的Scaling Law 一水 2025-12...
清库存!DeepSeek突然补全R1技术报告,训练路径首次详细公开

清库存!DeepSeek突然补全R1技术报告,训练路径首次详细公开

清库存!DeepSeek突然补全R1技术报告,训练路径首次详细公开 Jay 2026-01-08 20:18:...
最新评论
ufabet ufabet มีเกมให้เลือกเล่นมากมาย: เกมเดิมพันหลากหลาย ครบทุกค่ายดัง
tornado crypto mixer tornado crypto mixer Discover the power of privacy with TornadoCash! Learn how this decentralized mixer ensures your transactions remain confidential.
ดูบอลสด ดูบอลสด Very well presented. Every quote was awesome and thanks for sharing the content. Keep sharing and keep motivating others.
ดูบอลสด ดูบอลสด Pretty! This has been a really wonderful post. Many thanks for providing these details.
ดูบอลสด ดูบอลสด Pretty! This has been a really wonderful post. Many thanks for providing these details.
ดูบอลสด ดูบอลสด Hi there to all, for the reason that I am genuinely keen of reading this website’s post to be updated on a regular basis. It carries pleasant stuff.
Obrazy Sztuka Nowoczesna Obrazy Sztuka Nowoczesna Thank you for this wonderful contribution to the topic. Your ability to explain complex ideas simply is admirable.
ufabet ufabet Hi there to all, for the reason that I am genuinely keen of reading this website’s post to be updated on a regular basis. It carries pleasant stuff.
ufabet ufabet You’re so awesome! I don’t believe I have read a single thing like that before. So great to find someone with some original thoughts on this topic. Really.. thank you for starting this up. This website is something that is needed on the internet, someone with a little originality!
ufabet ufabet Very well presented. Every quote was awesome and thanks for sharing the content. Keep sharing and keep motivating others.
热评文章
摩尔线程的野心,不藏了

摩尔线程的野心,不藏了

摩尔线程的野心,不藏了 量子位的朋友们 2025-12-22 10:11:58 来源:量子位 上市后的仅15天...
摩尔线程的野心,不藏了

摩尔线程的野心,不藏了

摩尔线程的野心,不藏了 量子位的朋友们 2025-12-22 10:11:58 来源:量子位 上市后的仅15天...
AI体育教练来了!中国团队打造SportsGPT,完成从数值评估到专业指导的智能转身

AI体育教练来了!中国团队打造SportsGPT,完成从数值评估到专业指导的智能转身

AI体育教练来了!中国团队打造SportsGPT,完成从数值评估到专业指导的智能转身 量子位的朋友们 2025...
AI体育教练来了!中国团队打造SportsGPT,完成从数值评估到专业指导的智能转身

AI体育教练来了!中国团队打造SportsGPT,完成从数值评估到专业指导的智能转身

AI体育教练来了!中国团队打造SportsGPT,完成从数值评估到专业指导的智能转身 量子位的朋友们 2025...
真正面向大模型的AI Infra,必须同时懂模型、系统、产业|商汤大装置宣善明@MEET2026

真正面向大模型的AI Infra,必须同时懂模型、系统、产业|商汤大装置宣善明@MEET2026

真正面向大模型的AI Infra,必须同时懂模型、系统、产业|商汤大装置宣善明@MEET2026 量子位的朋友...