|
| 1 | +import numpy as np |
| 2 | +from PIL import Image |
| 3 | +import tensorflow as tf |
| 4 | +import keras |
| 5 | +from keras.models import Model |
| 6 | +import sys |
| 7 | +from keras.models import load_model |
| 8 | +from keras.layers import Input, Flatten |
| 9 | +import matplotlib.pyplot as plt |
| 10 | +from kito import reduce_keras_model # Ensure kito is installed |
| 11 | + |
| 12 | +# Load the model (output of training - checkpoint) |
| 13 | +model=load_model(sys.argv[1]) |
| 14 | + |
| 15 | +# Fold batch norms |
| 16 | +model_reduced = reduce_keras_model(model) |
| 17 | +model_reduced.save('bilinear_bnoptimized_munet.h5') # Use this model in PC |
| 18 | + |
| 19 | +# Flatten output and save model (Optimize for phone) |
| 20 | +output = model_reduced.output |
| 21 | +newout=Flatten()(output) |
| 22 | +new_model=Model(model_reduced.input,newout) |
| 23 | + |
| 24 | +new_model.save('bilinear_fin_munet.h5') |
| 25 | + |
| 26 | + |
| 27 | +# For Float32 Model |
| 28 | +converter = tf.lite.TFLiteConverter.from_keras_model_file('bilinear_fin_munet.h5') |
| 29 | +tflite_model = converter.convert() |
| 30 | +open("bilinear_fin_munet.tflite", "wb").write(tflite_model) |
| 31 | + |
| 32 | + |
| 33 | +#For UINT8 Quantization |
| 34 | +converter = tf.lite.TFLiteConverter.from_keras_model_file('bilinear_fin_munet.h5') |
| 35 | +#converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] |
| 36 | +converter.post_training_quantize=True |
| 37 | +tflite_model = converter.convert() |
| 38 | +open("bilinear_fin_munet_uint8.tflite", "wb").write(tflite_model) |
| 39 | + |
| 40 | + |
| 41 | +#For Float16 Quantization (Requires TF 1.15 or above) |
| 42 | +converter = tf.lite.TFLiteConverter.from_keras_model_file('bilinear_fin_munet.h5') |
| 43 | +converter.optimizations = [tf.lite.Optimize.DEFAULT] |
| 44 | +converter.target_spec.supported_types = [tf.lite.constants.FLOAT16] |
| 45 | +tflite_model = converter.convert() |
| 46 | +open("bilinear_fin_munet_fp16.tflite", "wb").write(tflite_model) |
| 47 | + |
| 48 | + |
| 49 | +# Sample run: python export.py checkpoints/up_super_model-102-0.06.hdf5 |
| 50 | + |
0 commit comments