Pytorch模型转ONNX时cross操作不支持的解决方法
概述
Pytorch很灵活,支持各种OP和Python的动态语法。但是转换到onnx的时候,有些OP(目前)并不支持,比如torch.cross
。这里以一个最小化的例子来演示这个过程,以及对应的解决办法。
一个例子
考虑下面这个简单的Pytorch转ONNX的例子:
1 | # file name: pytorch_cross_to_onnx.py |
运行这个脚本,会报下面的错误:
1 | $ python3 pytorch_cross_to_onnx.py |
注意最后一句的报错:
1 | RuntimeError: Exporting the operator cross to ONNX opset version 14 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub. |
也就是说目前版本是不支持torch.cross
转onnx的,同时提示你”feel free” 去Pytorch 的 GitHub 上提交/贡献一个转换操作。不过2020年03月就有人提了issue,至今仍没有g官方的解决方案。
解决办法
上面的issue里有人给出了解决思路,就是用元素相乘替代cross
操作。具体来说,实现如下:
1 | def my_cross(x, y, dim=1): |
注意:这里是以dim=1为例写的实现,如果是在别的维度进行cross操作,需要修改dim参数,同时修改对应stack的维度。
同时在Pytorch doc网站上看到,如果torch.cross
不指定dim
参数的话,默认是从前往后找第一个维度为3的维度,因此这个可能是你所不期望的,建议显式指定这个参数。
因此总结下来,下面是修改后的代码:
1 | import torch |
为了验证我们的实现与Pytorch的实现是否一致,可以用下面的函数验证:
1 | def test_torch_cross_and_my_cross(): |
执行后输出如下:
1 | my_cross == torch.cross: True |
说明这个实现是正确的。