博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
使用 ONNX 将模型从 PyTorch 迁移到 Caffe2
阅读量:6040 次
发布时间:2019-06-20

本文共 3452 字,大约阅读时间需要 11 分钟。

hot3.png

1. PyTorch及ONNX环境准备

为了正常运行ONNX,我们需要安装最新的Pytorch

git clone --recursive https://github.com/pytorch/pytorchcd pytorchmkdir build && cd buildsudo cmake .. -DPYTHON_INCLUDE_DIR=/usr/include/python3.5  -DUSE_MPI=OFFmake installexport PYTHONPATH=$PYTHONPATH:/opt/pytorch/build

上面的"/opt/pytorch/build"是你前面build pytorch的目录,写对路径即可。

通过整个PyTorch的源码安装,PyTorch支持的相关ONNX库也会随之安装好。安装路径在:/usr/local/lib/python3.5/dist-packages/torch

运行如下命令安装ONNX的库:

conda install -c conda-forge onnx

此外,还需要安装onnx-caffe2,一个纯Python库,它为ONNX提供了一个caffe2的编译器。你可以用pip安装onnx-caffe2:

pip3 install onnx-caffe2

2. 准备好把PyTorch转换成ONNX的代码

在 上面的pytorch2caffe2.py就是一段参考代码,把DeblurGAN训练好的模型转换成ONNX 。代码解释如下:

import osimport sysimport torchimport torch.onnximport torch.utils.model_zoofrom torch.autograd import Variablesys.path.append("../DeblurGAN")from models.models import create_modelimport models.networks as networksfrom options.test_options import TestOptionsimport shutilimport onnxfrom onnx_caffe2.backend import Caffe2Backendbatch_size = 1    # just a random number# Load the pretrained model weightsmodel_path = './model/char_deblur/latest_net_G.pth'onnx_model_path = "./deblurring.onnx.pb"state_dict = torch.utils.model_zoo.load_url(model_path, model_dir="./model/char_deblur")# Load the DeblurnGAN neural networkgan_opt = TestOptions().parse()gan_opt.name = "char_deblur"gan_opt.checkpoints_dir = "./model/"gan_opt.model = "test"gan_opt.dataset_mode = "single"gan_opt.dataroot = "/tmp/gan/"try:    shutil.rmtree(gan_opt.dataroot)except:    passos.mkdir(gan_opt.dataroot)gan_opt.loadSizeX = 64gan_opt.loadSizeY = 64gan_opt.fineSize = 64gan_opt.learn_residual = Truegan_opt.nThreads = 1  # test code only supports nThreads = 1gan_opt.batchSize = 1  # test code only supports batchSize = 1gan_opt.serial_batches = True  # no shufflegan_opt.no_flip = True  # no flip#torch_model = create_model(gan_opt)gpus = []torch_model = networks.define_G(gan_opt.input_nc, gan_opt.output_nc, gan_opt.ngf,                                gan_opt.which_model_netG, gan_opt.norm, not gan_opt.no_dropout, gpus, False,                                gan_opt.learn_residual)torch_model.load_state_dict(state_dict)#torch_model.load_state_dict(state_dict)# set the train mode to false since we will only run the forward pass.torch_model.train(False)# Input to the modelx = Variable(torch.randn(batch_size, 3, 60, 60), requires_grad=True)x = x.float()# Export the modeltorch_out = torch.onnx._export(torch_model,             # model being run                               x,                       # model input (or a tuple for multiple inputs)                               onnx_model_path, # where to save the model (can be a file or file-like object)                               verbose=True, export_params=True, training=False)      # store the trained parameter weights inside the model fileonnx_model = onnx.load(onnx_model_path)onnx.checker.check_model(onnx_model)model_name = onnx_model_path.replace('.onnx.pb','')init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model.graph, device="CUDA")with open(model_name + "_init.pb", "wb") as f:    f.write(init_net.SerializeToString())with open(model_name + "_predict.pb", "wb") as f:    f.write(predict_net.SerializeToString())

基于这个例子中,用户需要自己修改的部分有如下几个:

  • 训练好的PyTorch模型的路径:如在这个例子中“./model/char_deblur/latest_net_G.pth”需要修改成自己的模型路径
  •  

通过上面的代码,将生成两个Caffe2的pb文件,deblurring_init.pb 和 deblurring_predict.pb。

 

转载于:https://my.oschina.net/u/1431433/blog/2878668

你可能感兴趣的文章
Android系统Intent的使用
查看>>
nginx+video-thumbextractor生成视频缩略图
查看>>
linux的shell编程初探---变量
查看>>
考研英语(1-10)转自何凯文老师
查看>>
监控和消耗内存资源
查看>>
萤石云API (4年前分享)
查看>>
Java多线程——重进入(Reentrancy)机制
查看>>
Apache和Nginx设置伪静态(URL Rewrite)的方法
查看>>
crontab 使用时间
查看>>
远程密令临时开启ssh端口
查看>>
【Visual C++】游戏开发笔记之九 游戏地图制作(一)平面地图贴图
查看>>
ACCP学习旅程之----- CSS样式库
查看>>
Apache日志Shell分析
查看>>
freemarker中日期的比较
查看>>
特殊用法
查看>>
Linux service管理自定义脚本
查看>>
mysql创建date数据类型
查看>>
linux开机图形界面和字符界面切换
查看>>
sphinx的安装
查看>>
scsi_cnmd.h
查看>>