
ONNX Notes

2022-10-20 17:00 https://my.oschina.net/u/3768341/blog/5585440 算法之名

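ONNX (Open Neural Network Exchange) is an open format for exchanging neural network models between frameworks. It serializes models in protobuf's binary format (for protobuf serialization, see the post "Netty整合Protobuffer"), which gives good transfer performance. Official GitHub: GitHub - onnx/onnx at f2daca5e9b9315a2034da61c662d2a7ac28a9488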

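ONNX treats every layer of a network, or more precisely every operator, as a Node. These nodes are assembled into a Graph, which corresponds to the network. Finally the Graph is combined with the model's other metadata to produce a Model, which is the final ONNX model.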

Creating an ONNX Model

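There are two ways to create an ONNX model. The first is to convert a model from another framework such as PyTorch or PaddlePaddle. For PyTorch to ONNX, see the section "Pytorch 权重 pth 转换 onnx" (converting PyTorch .pth weights to ONNX) in the post "模型部署篇"; for PaddlePaddle to ONNX, see the Paddle2ONNX section of the post "PaddleOCR使用指南".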

First, let's generate an ONNX file.

import torch
import torch.nn as nn


class Network(nn.Module):
    """A minimal network: one 1x1 convolution followed by ReLU."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.conv(x))


if __name__ == '__main__':

    net = Network()
    # Dummy input used to trace the network. torch.autograd.Variable is
    # deprecated; a plain tensor works the same way.
    dummy_input = torch.randn(1, 1, 1, 1)
    torch.onnx.export(net, dummy_input, 'net.onnx', opset_version=10)

Next, print the structure of this ONNX file.

import onnx

if __name__ == '__main__':

    # Load net.onnx (exported in the previous step) and print its protobuf structure.
    print(onnx.load("./net.onnx"))

Output

ir_version: 5
producer_name: "pytorch"
producer_version: "1.12.1"
graph {
  node {
    input: "input.1"
    input: "conv.weight"
    input: "conv.bias"
    output: "input"
    name: "Conv_0"
    op_type: "Conv"
    attribute {
      name: "dilations"
      ints: 1
      ints: 1
      type: INTS
    }
    attribute {
      name: "group"
      i: 1
      type: INT
    }
    attribute {
      name: "kernel_shape"
      ints: 1
      ints: 1
      type: INTS
    }
    attribute {
      name: "pads"
      ints: 0
      ints: 0
      ints: 0
      ints: 0
      type: INTS
    }
    attribute {
      name: "strides"
      ints: 1
      ints: 1
      type: INTS
    }
  }
  node {
    input: "input"
    output: "4"
    name: "Relu_1"
    op_type: "Relu"
  }
  name: "torch_jit"
  initializer {
    dims: 1
    dims: 1
    dims: 1
    dims: 1
    data_type: 1
    name: "conv.weight"
    raw_data: "\014\317B?"
  }
  initializer {
    dims: 1
    data_type: 1
    name: "conv.bias"
    raw_data: "\344\n\026\277"
  }
  input {
    name: "input.1"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
        }
      }
    }
  }
  output {
    name: "4"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 1
          }
        }
      }
    }
  }
}
opset_import {
  version: 10
}

The dump starts with the ONNX IR version, here ir_version: 5. Next comes the framework the model was exported from, here PyTorch (producer_name: "pytorch"), and its version, producer_version: "1.12.1".

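Next comes graph->node. The first node is a 2D convolution and the second is a ReLU activation. op_type is the operator type of a node; the full list of operators is at https://github.com/onnx/onnx/blob/f2daca5e9b9315a2034da61c662d2a7ac28a9488/docs/Operators.md. name is the node's name, which is distinct from op_type. attribute holds the node's attributes; for Conv_0 these are the settings of the 2D convolution, such as "group" (the number of groups for grouped convolution) and "kernel_shape" (the kernel size). initializer holds the constant tensors the graph starts with, here the convolution weight and bias. input describes the graph input, including its shape, and output describes the graph output, including its shape. opset_import lists the operator domain and version that the model file depends on.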
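The raw_data fields of the two initializers in the dump can be decoded by hand. The sketch below assumes the payload is little-endian float32 (data_type: 1 is FLOAT) and copies the two byte strings from the printed output; decode_float32 and conv1x1_relu are illustrative names, not part of the onnx package.

```python
import struct

# raw_data strings copied from the printed initializers
# (protobuf text format shows non-printable bytes as octal escapes).
weight_raw = b"\014\317B?"
bias_raw = b"\344\n\026\277"


def decode_float32(raw):
    """Unpack little-endian float32 values from a raw_data payload."""
    count = len(raw) // 4
    return list(struct.unpack(f"<{count}f", raw))


weight = decode_float32(weight_raw)[0]
bias = decode_float32(bias_raw)[0]


def conv1x1_relu(x):
    """What Conv_0 followed by Relu_1 computes for one 1x1x1x1 input value."""
    return max(0.0, weight * x + bias)


print(weight, bias)
print(conv1x1_relu(1.0), conv1x1_relu(0.0))
```

With a single input channel, a single output channel and a 1x1 kernel, the whole network reduces to max(0, weight * x + bias), which is why one weight and one bias are all the initializers contain.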

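Finally, we validate the model with onnx.checker.check_model. It runs without raising an error, so the model is well-formed.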

import onnx

if __name__ == '__main__':

    model = onnx.load("./net.onnx")
    # Raises an exception if the model is malformed; returns nothing otherwise.
    onnx.checker.check_model(model)

The second way is to build the model directly with ONNX's own helper API.

import onnx
import onnx.helper as helper
import numpy as np

if __name__ == '__main__':

    # Graph input and output: name, element type and static shape.
    input = helper.make_tensor_value_info(name='input', elem_type=onnx.TensorProto.FLOAT, shape=[1, 3, 244, 244])
    output = helper.make_tensor_value_info(name='output', elem_type=onnx.TensorProto.FLOAT, shape=[1, 3, 244, 244])
    # Initializers: random weight and bias for a 3-in, 3-out 1x1 convolution.
    weight = helper.make_tensor(name='weight', data_type=onnx.TensorProto.FLOAT, dims=[3, 3, 1, 1], vals=np.random.randn(3, 3, 1, 1))
    bias = helper.make_tensor(name='bias', data_type=onnx.TensorProto.FLOAT, dims=[3], vals=np.random.randn(3))
    # A single Conv node wired from 'input' to 'output'.
    node = helper.make_node(op_type='Conv', inputs=['input', 'weight', 'bias'], outputs=['output'], kernel_shape=[1, 1], strides=[1, 1],
                            group=1, pads=[0, 0, 0, 0])
    graph = helper.make_graph(name='graph', nodes=[node], inputs=[input], outputs=[output], initializer=[weight, bias])

    # Wrap the graph in a model, validate it, print it and save it to disk.
    model = helper.make_model(graph)
    onnx.checker.check_model(model)
    print(model)
    onnx.save_model(model, 'model.onnx')

Output

ir_version: 8
graph {
  node {
    input: "input"
    input: "weight"
    input: "bias"
    output: "output"
    op_type: "Conv"
    attribute {
      name: "group"
      i: 1
      type: INT
    }
    attribute {
      name: "kernel_shape"
      ints: 1
      ints: 1
      type: INTS
    }
    attribute {
      name: "pads"
      ints: 0
      ints: 0
      ints: 0
      ints: 0
      type: INTS
    }
    attribute {
      name: "strides"
      ints: 1
      ints: 1
      type: INTS
    }
  }
  name: "graph"
  initializer {
    dims: 3
    dims: 3
    dims: 1
    dims: 1
    data_type: 1
    float_data: 0.45837152004241943
    float_data: 0.10209446400403976
    float_data: 1.0382566452026367
    float_data: -0.09292714297771454
    float_data: 1.58871591091156
    float_data: 0.3746287226676941
    float_data: -0.35588690638542175
    float_data: 0.7165427207946777
    float_data: 0.10244251787662506
    name: "weight"
  }
  initializer {
    dims: 3
    data_type: 1
    float_data: -0.36782845854759216
    float_data: 2.305680513381958
    float_data: -0.13051341474056244
    name: "bias"
  }
  input {
    name: "input"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 244
          }
          dim {
            dim_value: 244
          }
        }
      }
    }
  }
  output {
    name: "output"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 244
          }
          dim {
            dim_value: 244
          }
        }
      }
    }
  }
}
opset_import {
  version: 17
}

Setting batch_size Dynamically

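In the output above, every dimension of input is a fixed value, [1,3,244,244]. We now want to turn this fixed shape into one that accepts dynamic input. First, let's get the model running.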

import onnx
import onnxruntime
import numpy as np

if __name__ == '__main__':

    model = onnx.load("./model.onnx")
    sess = onnxruntime.InferenceSession('./model.onnx')
    input = np.random.randn(1, 3, 244, 244).astype(np.float32)
    print(sess.run(['output'], {'input': input}))

Output

[array([[[[-7.4062514e-01,  2.5951520e-01, -3.5876265e-01, ...,
          -2.0852795e+00, -1.0078001e-01, -4.9386564e-01],
         [-6.0379845e-01,  9.2830718e-01, -4.2096943e-02, ...,
          -1.9139317e-01,  1.6547061e+00,  1.4468774e+00],
         [ 2.6494553e+00, -9.6209788e-01,  8.2099646e-02, ...,
          -1.5899204e+00, -1.3295431e+00,  1.1512205e-01],
         ...,
         [ 1.4135087e+00,  6.4077592e-01, -5.6514746e-01, ...,
           2.1367333e+00,  2.6012421e+00, -1.3565271e+00],
         [ 6.9879985e-01,  1.2454928e+00,  6.0045028e-01, ...,
          -6.1302024e-01, -4.3026954e-02, -7.2975445e-01],
         [-2.1020520e+00, -1.2499222e+00, -9.3896770e-01, ...,
          -4.6129468e-01,  5.4580927e-01, -7.4599540e-01]],

        [[ 5.6230574e+00,  2.6218858e+00,  7.1071947e-01, ...,
           3.6510468e-02,  2.5771899e+00,  2.0060635e+00],
         [ 4.2759910e+00,  2.5261867e+00,  1.0787441e+00, ...,
           3.3373690e+00,  4.5090003e+00,  3.5535808e+00],
         [ 1.6522924e+00,  1.5206050e+00,  3.6905313e+00, ...,
           1.5963824e+00,  5.1875353e-02,  3.4248161e+00],
         ...,
         [ 1.0295208e+00,  4.5397396e+00,  4.3366423e+00, ...,
           1.2408195e+00,  3.1239326e+00,  1.7476916e+00],
         [ 9.7080982e-01,  1.9692242e+00,  3.7690439e+00, ...,
          -1.6770840e-01,  1.1871569e+00,  4.2690439e+00],
         [ 4.4730301e+00,  1.5573008e+00,  7.2707558e+00, ...,
           4.7898588e+00,  2.9080591e+00,  7.2294927e-01]],

        [[ 1.3509388e+00, -1.9160898e-01, -1.3318433e+00, ...,
          -1.0562456e+00,  1.0652192e-01, -4.4993240e-01],
         [ 7.3106253e-01, -4.0714890e-03, -5.3625894e-01, ...,
          -6.2385768e-02,  3.3464909e-01,  2.7667671e-01],
         [-7.8517151e-01, -7.1918708e-01,  5.5366117e-01, ...,
          -4.7982591e-01, -1.0322813e+00,  8.0901492e-01],
         ...,
         [-1.0904443e+00,  4.7577775e-01,  9.5288980e-01, ...,
          -9.8435390e-01, -5.1632053e-01,  2.4581529e-01],
         [-6.4627886e-01, -9.8449951e-01,  1.6146483e-01, ...,
          -1.2009792e+00, -7.3006052e-01,  7.0891309e-01],
         [ 1.3855783e+00, -8.9338100e-01,  2.4704218e+00, ...,
           6.8950468e-01,  1.7709453e-01, -7.6678610e-01]]]],
      dtype=float32)]
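Because both the weights and the input are random, the numbers above are hard to check by eye. One way to sanity-check them is a NumPy reference of what this particular Conv node computes. The sketch below assumes kernel_shape=[1, 1], strides=[1, 1], pads=[0, 0, 0, 0] and group=1, as in the model above; conv1x1_reference is an illustrative name, and a small 4x4 input is used to keep the example fast.

```python
import numpy as np


def conv1x1_reference(x, weight, bias):
    """Reference for a 1x1, stride-1, unpadded, ungrouped Conv node.

    x:      (N, C_in, H, W)
    weight: (C_out, C_in, 1, 1)
    bias:   (C_out,)
    """
    # A 1x1 kernel only mixes channels, so drop the two kernel axes
    # and contract over C_in at every pixel.
    mixed = np.einsum('oc,nchw->nohw', weight[:, :, 0, 0], x)
    return mixed + bias[None, :, None, None]


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 4, 4)).astype(np.float32)
weight = rng.standard_normal((3, 3, 1, 1)).astype(np.float32)
bias = rng.standard_normal(3).astype(np.float32)

y = conv1x1_reference(x, weight, bias)
print(y.shape)
```

To compare against onnxruntime, the real weight and bias can be read from model.graph.initializer with onnx.numpy_helper.to_array and passed to the same function together with the input that was fed to the session.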

Now let's change the input batch_size to 2.

import onnx
import onnxruntime
import numpy as np

if __name__ == '__main__':

    model = onnx.load("./model.onnx")
    inputs = model.graph.input
    outputs = model.graph.output
    for i in inputs:
        i.type.tensor_type.shape.dim[0].dim_value = 2
    for o in outputs:
        o.type.tensor_type.shape.dim[0].dim_value = 2
    onnx.checker.check_model(model)
    onnx.save_model(model, 'dynamic_model.onnx')
    sess = onnxruntime.InferenceSession('./dynamic_model.onnx')
    input = np.random.randn(2, 3, 244, 244).astype(np.float32)
    print(sess.run(['output'], {'input': input}))

Output

[array([[[[-2.10871696e-02, -1.32871771e+00, -1.22335061e-01, ...,
           4.77721721e-01, -4.10815179e-01, -1.37511027e+00],
         [-1.09181249e+00, -2.02204657e+00,  1.54176390e+00, ...,
          -1.88722742e+00, -2.00726366e+00,  4.24929589e-01],
         [-7.14685619e-01,  3.82802397e-01, -2.30412316e+00, ...,
           7.06834435e-01, -2.36892438e+00, -2.11947155e+00],
         ...,
         [-9.51929450e-01, -1.22408187e+00, -1.35213524e-01, ...,
           5.55669367e-02, -5.95110297e-01, -2.15206313e+00],
         [ 8.90325904e-01, -1.89442956e+00,  8.34725618e-01, ...,
          -2.34860206e+00, -1.09965193e+00, -4.96994108e-01],
         [ 1.56639183e+00,  5.97145438e-01, -5.28750658e-01, ...,
           5.77995658e-01, -1.46205699e+00,  2.80693078e+00]],

        [[ 3.09728765e+00, -1.42589498e+00,  7.58970022e-01, ...,
           3.48910093e+00,  2.95971513e+00,  1.96736765e+00],
         [ 2.76622701e+00,  1.58350587e+00,  2.41761374e+00, ...,
           3.68322372e-01,  3.05963039e-01,  2.99718475e+00],
         [-1.75151324e+00,  2.79870439e+00, -3.03543806e-01, ...,
           2.86027908e+00,  1.78771615e+00,  4.79569674e+00],
         ...,
         [ 1.30739605e+00,  1.83714139e+00,  4.55001736e+00, ...,
           1.44066858e+00,  4.87037659e+00,  2.10291076e+00],
         [ 9.44083452e-01, -8.11131001e-02,  2.89160919e+00, ...,
           2.34788847e+00,  1.95467031e+00,  3.87145948e+00],
         [ 2.71238947e+00,  1.46723819e+00,  7.61192560e-01, ...,
           2.69581342e+00,  2.11386037e+00,  4.08577728e+00]],

        [[ 3.52043629e-01, -1.83945060e+00, -9.97831583e-01, ...,
          -2.60245442e-01,  3.69277894e-01,  1.17505208e-01],
         [ 2.62015522e-01, -6.50106370e-01, -7.36498535e-01, ...,
          -3.72626394e-01, -9.92001474e-01,  1.87904552e-01],
         [-2.00427341e+00, -2.67415404e-01, -1.00334084e+00, ...,
           8.22970718e-02,  1.41485706e-01,  1.49001801e+00],
         ...,
         [-9.85595703e-01, -9.74879414e-03,  1.27501774e+00, ...,
          -7.10564435e-01,  1.17551017e+00, -4.15902734e-01],
         [-9.80473995e-01, -1.07735765e+00,  2.39617974e-02, ...,
          -1.93872005e-02,  2.48361230e-02,  7.19040394e-01],
         [-6.61614537e-02, -4.85614896e-01, -7.31452227e-01, ...,
          -9.65917259e-02, -2.94267178e-01,  1.87805906e-01]]],


       [[[-1.05941308e+00,  8.10959578e-01, -9.29054856e-01, ...,
          -1.33419132e+00, -5.62950134e-01,  3.15277368e-01],
         [-2.45844007e+00, -5.31174302e-01,  8.06264520e-01, ...,
          -1.37343729e+00, -1.26287377e+00, -1.79255664e+00],
         [ 5.01155496e-01,  2.53203034e+00, -9.11398768e-01, ...,
          -2.61194611e+00, -6.27550602e-01, -1.04612875e+00],
         ...,
         [ 5.64767838e-01,  1.82380235e+00, -9.87865806e-01, ...,
          -1.48546624e+00,  5.00284791e-01, -1.14099467e+00],
         [-1.48488015e-01, -3.75306606e-03,  2.05217457e+00, ...,
          -4.82964367e-01,  6.37757182e-01,  5.87742925e-01],
         [-7.62285709e-01,  5.78535438e-01, -9.07517672e-01, ...,
          -1.40203249e+00,  3.13063234e-01,  9.46564317e-01]],

        [[ 2.21778965e+00,  1.17825162e+00,  1.17773283e+00, ...,
           4.21785736e+00,  1.93207061e+00,  6.90674305e+00],
         [ 5.16840172e+00,  4.03573513e-02,  3.72957373e+00, ...,
           2.57324958e+00,  3.23857665e-01,  8.98278236e-01],
         [ 1.18916261e+00,  4.03137350e+00,  1.54717636e+00, ...,
           5.73142242e+00,  2.54209590e+00,  3.02691102e+00],
         ...,
         [ 2.02949071e+00,  4.00444984e+00,  3.55739307e+00, ...,
           5.54533482e-01,  3.57894540e+00,  7.03547835e-01],
         [ 2.57975435e+00,  2.32062602e+00,  4.18669128e+00, ...,
           2.15663671e+00,  2.39567637e+00,  7.93485880e-01],
         [ 3.32399893e+00,  3.12817383e+00,  3.60134292e+00, ...,
           1.70791423e+00,  7.71586776e-01,  3.58140349e+00]],

        [[-6.90246701e-01, -8.55753422e-01, -1.35433823e-01, ...,
           9.99482393e-01, -2.96287388e-01,  2.49611807e+00],
         [ 1.56937921e+00, -9.95752215e-01,  5.38442284e-02, ...,
           1.63274094e-01, -9.27845955e-01, -6.64922059e-01],
         [-5.40241778e-01,  2.26666585e-01, -2.95405626e-01, ...,
           1.90356636e+00,  4.94795978e-01,  1.35599896e-01],
         ...,
         [-4.09579694e-01,  1.26961544e-01,  5.97525239e-01, ...,
          -9.00853217e-01,  8.11160445e-01, -8.88532698e-01],
         [-5.75763881e-01, -1.15364529e-01,  2.42510274e-01, ...,
           1.83168098e-01, -3.83193374e-01, -1.10992551e+00],
         [ 9.75027233e-02,  1.07848495e-02,  2.93477297e-01, ...,
          -2.67393768e-01, -8.09366763e-01,  6.60410523e-03]]]],
      dtype=float32)]

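However, batch_size is still a fixed value: feeding an input whose first dimension is anything other than 2 raises an error. To accept an arbitrary batch_size, the fixed dim_value has to be replaced with a symbolic dim_param, as follows.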

import onnx
import onnxruntime
import numpy as np

if __name__ == '__main__':

    model = onnx.load("./model.onnx")
    inputs = model.graph.input
    outputs = model.graph.output
    for i in inputs:
        i.type.tensor_type.shape.dim[0].dim_param = 'batchsize'
    for o in outputs:
        o.type.tensor_type.shape.dim[0].dim_param = 'batchsize'
    onnx.checker.check_model(model)
    onnx.save_model(model, 'dynamic_model.onnx')
    sess = onnxruntime.InferenceSession('./dynamic_model.onnx')
    input = np.random.randn(3, 3, 244, 244).astype(np.float32)
    print(sess.run(['output'], {'input': input}))

Output

[array([[[[-5.6308472e-01, -2.8269453e+00, -2.7103744e+00, ...,
           4.2550400e-01,  6.5147376e-01, -4.7779888e-02],
         [-2.5536952e+00,  1.1469245e-01,  3.4514198e-01, ...,
          -1.8919052e+00, -5.7445437e-01, -1.5864235e+00],
         [-1.7443299e-02, -8.9739335e-01, -2.9766396e-01, ...,
           2.7872375e-01, -8.8234627e-01, -2.3681331e+00],
         ...,
         [-1.3148707e+00, -5.4888296e-01,  4.1061863e-01, ...,
          -1.0763314e+00, -9.6379507e-01,  1.3077673e+00],
         [-6.3514382e-02, -5.1493609e-01, -1.5793841e+00, ...,
          -2.2589236e-02, -2.2170777e+00,  1.2437304e+00],
         [ 7.4394345e-01,  7.8581774e-01,  2.0062235e-01, ...,
          -1.4014708e+00,  5.5377036e-02,  3.6608991e-01]],

        [[ 5.4129419e+00,  1.7448205e+00,  3.4165416e+00, ...,
           1.0320716e+00,  1.6988618e+00,  5.1501741e+00],
         [-1.3918903e+00,  1.7199724e+00,  2.1343894e+00, ...,
           8.0553353e-01,  4.7985373e+00,  2.5783958e+00],
         [ 2.3555427e+00,  6.3222194e-01,  2.9314611e+00, ...,
           4.3459427e-01,  1.3417060e+00,  1.6852837e+00],
         ...,
         [ 5.7537341e-01,  3.0654173e+00, -5.7629395e-01, ...,
           1.0968879e+00,  3.7861698e+00,  1.4928346e+00],
         [ 3.1267416e+00,  2.0358701e+00,  2.2204084e+00, ...,
           5.2084265e+00,  3.9166064e+00,  6.4575119e+00],
         [-7.2486067e-01,  2.3311584e+00,  2.0912974e+00, ...,
           1.8693907e+00,  3.2796674e+00,  3.8991761e+00]],

        [[ 1.1898929e+00,  8.9648962e-03,  8.5148907e-01, ...,
          -4.7057205e-01, -7.9108685e-01,  1.0573645e+00],
         [-1.5732453e+00, -3.8554335e-01, -1.6086581e-01, ...,
          -8.1125468e-01,  1.2085729e+00,  5.6812420e-02],
         [-3.0767348e-01, -8.5083431e-01,  4.9003422e-02, ...,
          -7.8210533e-01, -5.2408022e-01, -2.3199841e-02],
         ...,
         [-7.3540843e-01, -2.9384446e-01, -1.6465921e+00, ...,
          -3.8980949e-01,  7.1137357e-01, -8.0783540e-01],
         [ 2.3953258e-01, -1.7050017e-01, -1.3933203e-01, ...,
           1.6591790e+00,  1.0759927e+00,  1.7683787e+00],
         [-1.6956003e+00, -4.6602386e-01, -3.4259117e-01, ...,
          -1.0014131e-01,  2.6990986e-01,  8.6363363e-01]]],


       [[[ 6.4478827e-01, -8.1067204e-01, -1.2237258e+00, ...,
          -1.2951733e+00, -6.2070227e-01, -1.2906476e+00],
         [-6.6038930e-01, -2.8674665e-01, -1.0612940e+00, ...,
           4.6769258e-01,  4.8500946e-01, -5.6188315e-01],
         [ 1.0600269e-02, -1.4934481e+00,  9.1430867e-01, ...,
          -6.1285675e-01, -3.0706315e+00, -9.9033105e-01],
         ...,
         [ 1.7771789e+00, -1.3830042e+00, -1.4351614e+00, ...,
          -2.6786397e+00,  3.7956804e-02,  6.7189908e-01],
         [-2.1517308e+00, -5.8123243e-01, -7.7163374e-01, ...,
           1.6774191e+00,  7.2239363e-01,  1.3373801e+00],
         [-8.6465418e-01, -1.3932706e+00, -2.2982714e+00, ...,
           1.9587449e+00, -6.2718022e-01, -1.1754386e+00]],

        [[ 5.2605295e+00,  6.8119764e-01,  1.6433215e+00, ...,
           1.4899890e+00,  7.7494907e-01,  1.0885936e+00],
         [ 1.7135508e+00,  1.7890544e+00,  1.5538380e+00, ...,
           4.2714515e+00,  3.4532502e+00,  4.0540075e+00],
         [ 3.2757509e-01,  2.8093519e+00,  4.4473543e+00, ...,
           1.6302650e+00,  2.0791094e+00, -2.7314346e+00],
         ...,
         [ 3.1872306e+00,  2.1063502e+00,  4.4839258e+00, ...,
           8.6034179e-01,  3.7707591e+00,  3.9809742e+00],
         [-2.0055294e-02, -4.3134212e-02,  2.1313593e+00, ...,
           3.0318618e+00,  2.2852294e+00,  3.9968524e+00],
         [ 2.1781492e+00,  3.6937137e+00,  1.5003638e+00, ...,
           3.5955300e+00,  1.7056749e+00,  1.9585730e+00]],

        [[ 1.1795213e+00, -7.1754062e-01, -1.3523299e-01, ...,
          -5.6350648e-01, -1.3417213e+00, -5.0127864e-02],
         [-5.1167816e-01, -7.7823803e-02, -3.1461412e-01, ...,
           7.8631788e-01,  5.9256524e-01,  6.9275266e-01],
         [-1.3142396e+00,  7.9331988e-01,  5.0062788e-01, ...,
          -6.4525604e-03, -3.3234254e-02, -2.1546085e+00],
         ...,
         [-2.0651843e-01,  1.1771068e-02,  1.3835690e+00, ...,
           2.8883666e-03,  4.5511311e-01,  2.9804629e-01],
         [-9.0822458e-01, -1.3634090e+00, -4.2348909e-01, ...,
           2.9903316e-01, -5.9180021e-01,  5.1938176e-01],
         [-3.7974668e-01,  6.5785772e-01, -4.8025602e-01, ...,
           1.5578230e-01, -8.5666311e-01, -8.2990326e-02]]],


       [[[ 2.7270940e-01,  1.6803369e-01,  6.4784336e-01, ...,
          -8.6817765e-01,  2.4317000e+00,  9.9560642e-01],
         [-1.0902294e+00, -1.5418210e+00, -6.4213789e-01, ...,
           3.8346985e-01, -2.2009264e-01, -1.4083362e+00],
         [-1.2999996e+00, -1.0029310e+00, -8.0927563e-01, ...,
          -9.6844232e-01,  4.7647089e-02, -1.7528368e+00],
         ...,
         [ 7.9181468e-01, -7.1245348e-01, -1.2355906e+00, ...,
          -4.4910422e-01,  7.0296872e-01, -1.8157486e+00],
         [ 8.5229218e-01, -3.9036795e-01,  3.7029549e-01, ...,
          -2.0579123e+00,  9.2259049e-03, -1.2485095e+00],
         [-1.0421257e+00,  9.6360290e-01, -1.9165359e+00, ...,
          -1.5525728e+00, -2.7757692e+00,  5.9844279e-01]],

        [[ 2.0120070e+00,  2.5763493e+00,  2.5311258e+00, ...,
           2.0375581e+00,  1.6430848e+00,  4.5296006e+00],
         [-6.4119029e-01,  3.2270002e-01,  2.7286339e+00, ...,
           3.4792902e+00,  4.8433290e+00,  1.8760866e+00],
         [ 5.2160606e+00,  5.8354855e-01,  1.9910555e+00, ...,
           3.8761294e-01,  3.4568546e+00,  2.2840927e+00],
         ...,
         [ 2.4697292e+00,  3.1099756e+00,  4.5984769e+00, ...,
           3.1638999e+00,  1.7895203e+00,  5.1426482e-01],
         [ 2.0174649e+00,  3.7343421e+00,  1.3838698e+00, ...,
           6.8948352e-01,  1.9830887e+00, -1.2911747e+00],
         [ 2.2970469e+00,  2.8243198e+00,  8.7906146e-01, ...,
           3.2837601e+00,  1.0420291e+00,  4.1244802e+00]],

        [[-4.5218289e-01, -1.1248827e-02, -3.9010030e-01, ...,
          -2.1441557e-01, -8.8925439e-01,  1.0432711e+00],
         [-1.5277631e+00, -6.0763943e-01,  8.2450414e-01, ...,
           5.1565582e-01,  9.1227055e-01, -4.1257131e-01],
         [ 1.1678007e+00, -8.4806198e-01, -4.1370481e-01, ...,
          -9.4888353e-01,  2.4556525e-01,  2.7058780e-02],
         ...,
         [-2.6444227e-01,  6.4803612e-01,  1.4935874e+00, ...,
           1.9097075e-02, -6.0670388e-01, -3.2186458e-01],
         [-5.2368152e-01,  6.9923353e-01, -6.0641676e-01, ...,
          -3.2536793e-01, -3.0933461e-01, -1.7596698e+00],
         [-5.7884902e-01, -9.0141267e-02, -4.4471401e-01, ...,
           3.5021925e-01, -1.7998603e-01,  6.4285696e-01]]]],
      dtype=float32)]

The first dimension of input, that is batch_size, can now be set to any value and the program still runs. Let's print the model information at this point.

import onnx

if __name__ == '__main__':

    model = onnx.load("./model.onnx")
    inputs = model.graph.input
    outputs = model.graph.output
    # Mark the first (batch) dimension of every input and output as symbolic.
    for i in inputs:
        i.type.tensor_type.shape.dim[0].dim_param = 'batchsize'
    for o in outputs:
        o.type.tensor_type.shape.dim[0].dim_param = 'batchsize'
    print(model)

Output

ir_version: 8
graph {
  node {
    input: "input"
    input: "weight"
    input: "bias"
    output: "output"
    op_type: "Conv"
    attribute {
      name: "group"
      i: 1
      type: INT
    }
    attribute {
      name: "kernel_shape"
      ints: 1
      ints: 1
      type: INTS
    }
    attribute {
      name: "pads"
      ints: 0
      ints: 0
      ints: 0
      ints: 0
      type: INTS
    }
    attribute {
      name: "strides"
      ints: 1
      ints: 1
      type: INTS
    }
  }
  name: "graph"
  initializer {
    dims: 3
    dims: 3
    dims: 1
    dims: 1
    data_type: 1
    float_data: 0.45837152004241943
    float_data: 0.10209446400403976
    float_data: 1.0382566452026367
    float_data: -0.09292714297771454
    float_data: 1.58871591091156
    float_data: 0.3746287226676941
    float_data: -0.35588690638542175
    float_data: 0.7165427207946777
    float_data: 0.10244251787662506
    name: "weight"
  }
  initializer {
    dims: 3
    data_type: 1
    float_data: -0.36782845854759216
    float_data: 2.305680513381958
    float_data: -0.13051341474056244
    name: "bias"
  }
  input {
    name: "input"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_param: "batchsize"
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 244
          }
          dim {
            dim_value: 244
          }
        }
      }
    }
  }
  output {
    name: "output"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_param: "batchsize"
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 244
          }
          dim {
            dim_value: 244
          }
        }
      }
    }
  }
}
opset_import {
  version: 17
}

Here we can see that the first dim of input (and likewise of output) has become dim_param: "batchsize".

Adding and Removing Nodes

  • Adding a node
import onnx
import onnx.helper as helper
import onnxruntime
import numpy as np

if __name__ == '__main__':

    model = onnx.load('./model.onnx')
    nodes = model.graph.node
    new_node = helper.make_node(op_type='Relu', name='relu1', inputs=['conv1'], outputs=['output'])
    nodes.append(new_node)
    nodes[0].output[0] = 'conv1'
    onnx.checker.check_model(model)
    onnx.save_model(model, 'add_model.onnx')

    input = np.random.randn(1, 3, 244, 244).astype(np.float32)
    sess = onnxruntime.InferenceSession('./add_model.onnx')
    print(sess.run(['output'], {'input': input}))

Output

[array([[[[1.5453527 , 0.        , 0.        , ..., 0.04255658,
          0.        , 0.40214583],
         [0.        , 0.5019511 , 0.        , ..., 0.34235588,
          0.36859825, 0.        ],
         [0.        , 0.        , 0.34334645, ..., 0.        ,
          0.        , 0.        ],
         ...,
         [1.1857387 , 1.0710502 , 0.        , ..., 0.        ,
          1.8497316 , 0.        ],
         [0.37889728, 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ],
         [0.73697627, 0.        , 0.4978644 , ..., 0.        ,
          0.        , 0.32394186]],

        [[1.2723072 , 0.        , 0.66669345, ..., 5.6399436 ,
          1.4827138 , 2.7300682 ],
         [4.5705633 , 2.9856906 , 2.9005556 , ..., 3.505543  ,
          4.7502317 , 0.        ],
         [1.5251542 , 3.3182473 , 3.8036246 , ..., 0.        ,
          1.6024959 , 1.4051957 ],
         ...,
         [1.7204559 , 4.551407  , 4.172427  , ..., 0.9121852 ,
          3.3593512 , 4.6163626 ],
         [0.2845726 , 0.13289118, 3.3601975 , ..., 3.9331636 ,
          0.3700601 , 1.5711328 ],
         [3.3283763 , 2.128338  , 2.1621299 , ..., 1.7635765 ,
          0.        , 2.1479769 ]],

        [[0.        , 0.        , 0.        , ..., 1.4292918 ,
          0.        , 0.46683455],
         [1.0534286 , 0.        , 0.02258705, ..., 0.4342987 ,
          1.1339298 , 0.        ],
         [0.        , 0.50237906, 0.20627443, ..., 0.        ,
          0.        , 0.        ],
         ...,
         [0.        , 0.78040606, 1.003104  , ..., 0.        ,
          0.        , 1.0389903 ],
         [0.        , 0.        , 0.        , ..., 0.74816215,
          0.        , 0.02678718],
         [0.26068228, 0.        , 0.        , ..., 0.        ,
          0.        , 0.        ]]]], dtype=float32)]

Let's print the model information again.

import onnx
import onnx.helper as helper

if __name__ == '__main__':

    model = onnx.load('./model.onnx')
    nodes = model.graph.node
    # New Relu node that consumes 'conv1' and produces the graph output.
    new_node = helper.make_node(op_type='Relu', name='relu1', inputs=['conv1'], outputs=['output'])
    nodes.append(new_node)
    # Rename the Conv output so that it feeds the new Relu node.
    nodes[0].output[0] = 'conv1'
    print(model)

Output

ir_version: 8
graph {
  node {
    input: "input"
    input: "weight"
    input: "bias"
    output: "conv1"
    op_type: "Conv"
    attribute {
      name: "group"
      i: 1
      type: INT
    }
    attribute {
      name: "kernel_shape"
      ints: 1
      ints: 1
      type: INTS
    }
    attribute {
      name: "pads"
      ints: 0
      ints: 0
      ints: 0
      ints: 0
      type: INTS
    }
    attribute {
      name: "strides"
      ints: 1
      ints: 1
      type: INTS
    }
  }
  node {
    input: "conv1"
    output: "output"
    name: "relu1"
    op_type: "Relu"
  }
  name: "graph"
  initializer {
    dims: 3
    dims: 3
    dims: 1
    dims: 1
    data_type: 1
    float_data: 0.45837152004241943
    float_data: 0.10209446400403976
    float_data: 1.0382566452026367
    float_data: -0.09292714297771454
    float_data: 1.58871591091156
    float_data: 0.3746287226676941
    float_data: -0.35588690638542175
    float_data: 0.7165427207946777
    float_data: 0.10244251787662506
    name: "weight"
  }
  initializer {
    dims: 3
    data_type: 1
    float_data: -0.36782845854759216
    float_data: 2.305680513381958
    float_data: -0.13051341474056244
    name: "bias"
  }
  input {
    name: "input"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 244
          }
          dim {
            dim_value: 244
          }
        }
      }
    }
  }
  output {
    name: "output"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 1
          }
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 244
          }
          dim {
            dim_value: 244
          }
        }
      }
    }
  }
}
opset_import {
  version: 17
}

Here we can see that a relu1 node has been added: the first node's output is now conv1, and the second node takes conv1 as its input and produces output.
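Appending works here because the new Relu node belongs after the Conv node. ONNX expects the nodes of a graph to be listed in topological order, so a node inserted in the middle of a network has to be placed before its consumers (for example with nodes.insert). The sketch below illustrates the ordering rule on plain tuples; is_topologically_sorted is a hypothetical helper written for this example, not an onnx function.

```python
def is_topologically_sorted(nodes, graph_inputs, initializers):
    """Check that every node input is produced before it is consumed.

    nodes: list of (name, inputs, outputs) tuples in graph order.
    """
    available = set(graph_inputs) | set(initializers)
    for name, inputs, outputs in nodes:
        for tensor in inputs:
            # An empty string marks an omitted optional input.
            if tensor and tensor not in available:
                return False
        available.update(outputs)
    return True


conv = ('conv', ['input', 'weight', 'bias'], ['conv1'])
relu = ('relu1', ['conv1'], ['output'])

print(is_topologically_sorted([conv, relu], ['input'], ['weight', 'bias']))  # True
print(is_topologically_sorted([relu, conv], ['input'], ['weight', 'bias']))  # False
```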

  • Removing a node
import onnx
import onnxruntime
import numpy as np

if __name__ == '__main__':

    model = onnx.load('./add_model.onnx')
    nodes = model.graph.node
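    # Iterate over a copy so that removing a node does not disturb the loop.
    for node in list(nodes):
        if node.name == 'relu1':
            nodes.remove(node)
    # Reconnect the Conv node directly to the graph output.
    nodes[0].output[0] = 'output'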
    onnx.checker.check_model(model)
    onnx.save_model(model, 'del_model.onnx')

    input = np.random.randn(1, 3, 244, 244).astype(np.float32)
    sess = onnxruntime.InferenceSession('./del_model.onnx')
    print(sess.run(['output'], {'input': input}))

Output

[array([[[[-8.5923064e-01, -4.2249173e-01,  3.8687822e-01, ...,
          -4.8348337e-02,  3.1652334e-01, -5.7166600e-01],
         [ 3.1469372e-01, -9.4796360e-01, -2.4245100e+00, ...,
           4.1007617e-01, -1.4098099e+00,  6.7472184e-01],
         [-1.2910874e+00,  1.6070822e-01, -1.0217074e+00, ...,
           7.1467435e-01,  1.5835044e-01, -6.4228356e-01],
         ...,
         [-2.5442154e+00, -8.8969648e-01,  1.1389736e+00, ...,
           1.7202379e+00, -1.1968368e+00, -3.3861694e-01],
         [-9.0216339e-01,  4.8469666e-01, -9.5050204e-01, ...,
           4.0511075e-01, -1.0113320e-01,  1.8743831e+00],
         [ 3.2901958e-01,  4.3780953e-02,  1.4250931e+00, ...,
          -1.4544667e+00,  9.0659869e-01,  1.7170597e+00]],

        [[ 1.3439684e+00,  3.0856354e+00,  2.7811766e+00, ...,
           4.1714394e-01, -3.3547878e-02,  1.1771207e+00],
         [ 2.1574910e+00,  2.1122241e+00, -5.8333945e-01, ...,
           1.9629711e+00,  3.4840956e+00,  6.1747317e+00],
         [ 5.2136226e+00,  4.8688288e+00,  1.4613919e+00, ...,
           4.1095753e+00,  1.4553337e+00,  3.5171165e+00],
         ...,
         [ 7.3736429e-02,  8.4109855e-01,  5.7113109e+00, ...,
           3.6336284e+00,  4.4551125e+00,  3.4602299e+00],
         [ 1.1054695e+00,  2.7417006e+00,  4.9065466e+00, ...,
           2.1775680e+00,  4.4132576e+00,  2.3781679e+00],
         [-1.2788355e+00,  2.5300267e+00,  3.2560487e+00, ...,
           2.2025514e+00,  4.2551570e+00,  3.5148311e+00]],

        [[-8.5124874e-01,  3.1858414e-01,  3.3686757e-03, ...,
          -1.1497847e+00, -1.1996644e+00, -9.6176589e-01],
         [-4.2057925e-01, -1.8098265e-01, -7.4302059e-01, ...,
          -3.5920531e-01,  7.0454830e-01,  1.8304255e+00],
         [ 1.4177717e+00,  8.4456313e-01, -1.6396353e-01, ...,
           4.2133337e-01, -4.6482396e-01,  6.6906375e-01],
         ...,
         [-1.0060047e+00, -1.2088763e+00,  1.2608007e+00, ...,
           5.1739502e-01,  8.9526463e-01,  7.2866821e-01],
         [-3.5698372e-01, -5.9943002e-01,  1.0040566e+00, ...,
           3.1322885e-01,  3.4513384e-01, -6.2404698e-01],
         [-2.0622578e+00,  3.9633280e-01,  2.1701033e-01, ...,
           2.6992482e-01,  4.4787437e-01,  2.1187775e-01]]]],
      dtype=float32)]

 
