Switching between TensorFlow and PyTorch with ONNX
Use your favorite AI frameworks without limits
As machine learning models get larger and more complex, the time and material costs of building them often demand that a project stay consistent within a single framework. TensorFlow and PyTorch are two of the most popular frameworks in artificial intelligence and machine learning, especially in the deep learning community.
The choice between TensorFlow and PyTorch often comes down to your familiarity with each framework's development and production workflow, or to company and industry standards. However, each framework has definite advantages, from ease of use and deployment infrastructure to the available ecosystem support.
TensorFlow is older and has broader support for other languages (TensorFlow.js, Swift), along with built-in APIs and production tools such as TF Serving, TFX, and TFLite that ease the deployment of models across local, cloud, IoT, and mobile platforms. TensorFlow also has a more extensive ecosystem of use-case-specific deep learning applications (text, images, audio, and video) that integrate across platforms.
PyTorch is notably easier to learn and use, at least for Python programmers. It offers a fast model development process, with a CUDA backend and efficient memory usage. This has made it the preference of many researchers and open-source organizations, such as OpenAI, for building their models. This, in turn, means more pre-trained state-of-the-art models are PyTorch exclusive.
The general opinion is that if you are building models in a research capacity, PyTorch is the framework for you, while if you are working in industry or want to deploy your model, TensorFlow is ideal. However, it is possible to develop with PyTorch and deploy with TensorFlow using ONNX.
This article showcases the ability to convert your models between frameworks with ONNX.
Introducing ONNX
The Open Neural Network Exchange (ONNX) is an open-source ecosystem that enables AI developers to use the most effective tools for their projects without worrying about production compatibility and downstream deployments. It provides a general format for your artificial intelligence models, whether deep learning or otherwise, enabling interoperability between different frameworks and hardware optimization.
ONNX abstracts the similarities between frameworks to create a standard model definition with built-in operators and data types. The platform currently offers support for several AI frameworks to build, deploy, optimize and visualize your model.
Using ONNX
Suppose you are a researcher working on a computer vision project that classifies plants at various stages of health. You might wish to take advantage of PyTorch's general ease of use and speed while iterating on your deep learning model. However, your current backend deployment process is restricted to TensorFlow-specific formats so that it can exploit the extensiveness of TensorFlow Serving.
You can train your model using PyTorch and save it as usual with the torch.save function:
torch.save(your_model.state_dict(), 'your_model.pth')
Then use PyTorch's built-in ONNX exporter to represent your model in the ONNX format, capturing both the model architecture and its parameters:
trained_model = model()
trained_model.load_state_dict(torch.load('your_model.pth'))

dummy_input = torch.randn(1, 3, 128, 128)
torch.onnx.export(trained_model, dummy_input, 'your_model.onnx')
Note: ONNX is not just a convenient model format, you can utilize your models in this format with the ONNXRuntime for faster re-training and inferencing processes.
Your next step is to convert your model from its ONNX format to a TensorFlow model format. You can utilize the ONNX TensorFlow Backend package (onnx-tf), which enables ONNX-to-TensorFlow compatibility.
import onnx
from onnx_tf.backend import prepare

onnx_model = onnx.load('your_model.onnx')
tf_rep = prepare(onnx_model)
This produces a TensorFlow representation of the model that can then be used for inference or deployment.
Note: Here you have seen the transfer from PyTorch to ONNX to TensorFlow; the reverse is possible as well, with TensorFlow-to-ONNX converters such as tf2onnx and ONNX-to-PyTorch tools.
Conclusion
This article gave you a brief introduction to ONNX and how it enables interoperability between AI frameworks and tools. It also showcased, in fewer than ten lines of code, the process of exchanging a model between PyTorch and TensorFlow via ONNX.
You can find the full code sample from PyTorch training to Tensorflow inference in this GitHub repository.