5.JAX Library - MO436-MC934/notebooks GitHub Wiki
5. The JAX Library
In the fourth project you will have to implement the ResNet-18 model using the JAX and Flax libraries [1,2], for inference only. The Flax library, discussed in the second lecture, provides primitives for stacking different kinds of layers to form a neural network architecture. You can load the network weights from an ONNX model available at [3]. Installing the onnx Python package [3] gives you an API to manipulate ONNX files. The code below shows a starting point for loading the weights from an ONNX file.
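import onnx
from onnx import numpy_helper

# Load the serialized ONNX model from disk.
model = onnx.load("resnet.onnx")

# Each initializer is a stored tensor (weights, biases, BatchNorm statistics).
for initializer in model.graph.initializer:
    # Convert the ONNX TensorProto into a NumPy array.
    array = numpy_helper.to_array(initializer)
    print(f"- Tensor: {initializer.name!r:45} shape={array.shape}")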
The code above produces the following output:
- Tensor: 'resnetv15_conv0_weight' shape=(64, 3, 7, 7)
- Tensor: 'resnetv15_batchnorm0_gamma' shape=(64,)
- Tensor: 'resnetv15_batchnorm0_beta' shape=(64,)
- Tensor: 'resnetv15_batchnorm0_running_mean' shape=(64,)
- Tensor: 'resnetv15_batchnorm0_running_var' shape=(64,)
- Tensor: 'resnetv15_stage1_conv0_weight' shape=(64, 64, 3, 3)
- Tensor: 'resnetv15_stage1_batchnorm0_gamma' shape=(64,)
- Tensor: 'resnetv15_stage1_batchnorm0_beta' shape=(64,)
- Tensor: 'resnetv15_stage1_batchnorm0_running_mean' shape=(64,)
- Tensor: 'resnetv15_stage1_batchnorm0_running_var' shape=(64,)
...
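The shapes listed above follow the ONNX layout, which differs from what Flax layers usually expect. As a minimal sketch, assuming you use `flax.linen.Conv` (kernel stored as `(kH, kW, in, out)`) and `flax.linen.BatchNorm` (`scale` and `bias` in the parameters, `mean` and `var` in the `batch_stats` collection), the conversion could look like the code below. The helper names `onnx_conv_to_flax` and `group_batchnorm` are hypothetical, and the `tensors` dictionary is filled with random values that only stand in for the arrays read from the ONNX file.

```python
import numpy as np


def onnx_conv_to_flax(kernel_oihw):
    """Transpose a conv kernel from ONNX (out, in, kH, kW) to (kH, kW, in, out)."""
    return np.transpose(kernel_oihw, (2, 3, 1, 0))


def group_batchnorm(tensors, prefix):
    """Split the four BatchNorm tensors of `prefix` into Flax-style groups."""
    params = {
        "scale": tensors[f"{prefix}_gamma"],
        "bias": tensors[f"{prefix}_beta"],
    }
    stats = {
        "mean": tensors[f"{prefix}_running_mean"],
        "var": tensors[f"{prefix}_running_var"],
    }
    return params, stats


# Stand-in data: random values with the shapes printed above (assumption:
# in the project these arrays come from numpy_helper.to_array).
rng = np.random.default_rng(0)
tensors = {
    "resnetv15_conv0_weight": rng.standard_normal((64, 3, 7, 7)).astype(np.float32),
    "resnetv15_batchnorm0_gamma": np.ones(64, dtype=np.float32),
    "resnetv15_batchnorm0_beta": np.zeros(64, dtype=np.float32),
    "resnetv15_batchnorm0_running_mean": np.zeros(64, dtype=np.float32),
    "resnetv15_batchnorm0_running_var": np.ones(64, dtype=np.float32),
}

conv0_kernel = onnx_conv_to_flax(tensors["resnetv15_conv0_weight"])
bn0_params, bn0_stats = group_batchnorm(tensors, "resnetv15_batchnorm0")

print(conv0_kernel.shape)  # (7, 7, 3, 64)
```

How these arrays are nested into the final variables dictionary depends on the names of the submodules in your own model, so that part is left out here.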
After you have implemented the model and loaded the weights, it is time to test your code on a few test images. You may use the images provided in [5]. Finally, your job is to obtain as much performance improvement as possible from your network using JAX transformations: jit, vmap, pmap, etc. You may use your CPU, a GPU, or a TPU (in Google Colab). In the end, you should write a report (PDF) describing how you implemented your model, which transformations you applied, and why. If you could not apply some transformation, discuss the problems you found while trying to use it. Finally, include a performance analysis showing how much faster your model has become compared to the non-transformed model (tables and graphs are welcome).
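As a rough illustration of how two of these transformations compose and how a timing comparison might be set up, here is a minimal sketch. The `forward` function is a hypothetical two-layer stand-in, not ResNet-18, and the parameter names, image size, and the `time_it` helper are assumptions made only for this example. `pmap` is left out because its use depends on how many devices are available.

```python
import time

import jax
import jax.numpy as jnp


def forward(params, image):
    """Hypothetical single-image forward pass (a stand-in, not ResNet-18)."""
    hidden = jnp.tanh(image.reshape(-1) @ params["w1"])
    return hidden @ params["w2"]


key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w1": jax.random.normal(key1, (3 * 32 * 32, 128)) * 0.01,
    "w2": jax.random.normal(key2, (128, 10)) * 0.01,
}
images = jax.random.normal(key3, (16, 3, 32, 32))

# vmap: map the single-image function over the batch axis of `images`,
# while sharing the same `params` across all images.
batched_forward = jax.vmap(forward, in_axes=(None, 0))

# jit: compile the batched function with XLA on its first call.
fast_forward = jax.jit(batched_forward)


def time_it(fn, *args, repeats=10):
    """Average seconds per call, measured after one warm-up call."""
    fn(*args).block_until_ready()  # warm-up; triggers compilation if jitted
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously, so wait for the result before timing.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


baseline = time_it(batched_forward, params, images)
compiled = time_it(fast_forward, params, images)
print(f"vmap only: {baseline * 1e3:.3f} ms, jit(vmap): {compiled * 1e3:.3f} ms")
```

The warm-up call and `block_until_ready()` matter for a fair comparison: without them the measurement would either include compilation time or stop before the computation has actually finished.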