add fp16 option when freeze graph (#103)
Cjkkkk authored Nov 4, 2020
1 parent 81220ad commit 8ae20c6
Showing 3 changed files with 77 additions and 8 deletions.
58 changes: 58 additions & 0 deletions models/tensorflow/nnf_tf_freezer/convert_graph_fp16.py
@@ -0,0 +1,58 @@
import tensorflow as tf
from tensorflow.core.framework import types_pb2, graph_pb2, attr_value_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format
import numpy as np


def convert_graph_to_fp16(source_graph_def, target_type='fp16', input_name=None, output_names=None, keep_fp32_node_name=[]):
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT

    target_graph_def = graph_pb2.GraphDef()
    target_graph_def.versions.CopyFrom(source_graph_def.versions)
    for node in source_graph_def.node:
        # replicate the node
        new_node = target_graph_def.node.add()
        new_node.op = node.op
        new_node.name = node.name
        new_node.input.extend(node.input)
        attrs = list(node.attr.keys())
        # replace DT_FLOAT in the node attrs with the target dtype
        for attr in attrs:
            new_node.attr[attr].CopyFrom(node.attr[attr])
            # keep nodes listed in keep_fp32_node_name in fp32
            if node.name in keep_fp32_node_name:
                continue
            if node.attr[attr].type == types_pb2.DT_FLOAT:
                # modify the node dtype
                new_node.attr[attr].type = dtype
            if attr == "value":
                tensor = node.attr[attr].tensor
                if tensor.dtype == types_pb2.DT_FLOAT:
                    # scalar constants live in float_val
                    if tensor.float_val:
                        float_val = tf.make_ndarray(node.attr[attr].tensor)
                        new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(float_val, dtype=dtype))
                        continue
                    # larger constants are serialized in tensor_content
                    if tensor.tensor_content:
                        tensor_shape = [x.size for x in tensor.tensor_shape.dim]
                        tensor_weights = tf.make_ndarray(tensor)
                        # reshape tensor back to its original shape
                        tensor_weights = np.reshape(tensor_weights, tensor_shape)
                        new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(tensor_weights, dtype=dtype))
                        continue
    # strip nodes made unreachable by the rewrite
    if output_names:
        if not input_name:
            input_name = []
        transforms = ["strip_unused_nodes"]
        target_graph_def = TransformGraph(target_graph_def, input_name, output_names, transforms)
    # the caller is responsible for writing the converted graph_def to disk
    print("Converting done ...")
    return target_graph_def
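
For readers who want to drive the converter on its own, here is a minimal sketch of loading an existing fp32 GraphDef and rewriting it (TF 1.x API; the file paths and node names below are hypothetical):

import tensorflow as tf
from convert_graph_fp16 import convert_graph_to_fp16

# load an existing fp32 GraphDef (path is hypothetical)
graph_def = tf.GraphDef()
with tf.gfile.GFile("model_fp32.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# rewrite to fp16, pinning a numerically sensitive node to fp32
fp16_graph = convert_graph_to_fp16(graph_def, target_type='fp16',
                                   input_name=["input:0"], output_names=["output:0"],
                                   keep_fp32_node_name=["softmax"])
tf.train.write_graph(fp16_graph, '.', 'model_fp16.pb', as_text=False)

Here input_name and output_names only feed the strip_unused_nodes transform; passing tensor names with the :0 port mirrors what the freezer itself does below.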
20 changes: 16 additions & 4 deletions models/tensorflow/nnf_tf_freezer/nnf_tf_freezer.py
@@ -15,10 +15,11 @@
from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto
from tensorflow.tools import graph_transforms
from typing import List
from convert_graph_fp16 import *

class nnf_tf_freezer(object):
    def __init__(self, frozen_graph= "frozen_graph.pb", const_folding=True, run_graph=True, xla=False, parallel=0,
                 warmup=5, num_iter=10, run_const_folded_graph=False, debug=False, is_training=False):
                 warmup=5, num_iter=10, run_const_folded_graph=False, debug=False, is_training=False, to_fp16=False):
        self.frozen_graph = frozen_graph
        self.const_folding = const_folding
        self.run_graph = run_graph
@@ -30,7 +31,8 @@ def __init__(self, frozen_graph= "frozen_graph.pb", const_folding=True, run_grap
        self.run_const_folded_graph = run_const_folded_graph
        self.debug = debug
        self.is_training = is_training
        self.to_fp16 = to_fp16

    def execute(self, inputs : List[tf.placeholder], outputs : List[tf.identity], optimizer : tf.train.Optimizer=None):
        self.freeze(inputs, outputs, optimizer)
        if self.const_folding:
@@ -83,7 +85,18 @@ def freeze(self, inputs : List[tf.placeholder], outputs : List[tf.identity], opt
            except:
                print('Not using existing checkpoint.')
                pass

            saver_path = saver.save(sess, "/tmp/save/model.ckpt")

            if self.to_fp16:
                # convert graph to fp16 model
                print('convert to fp16 model')
                input_name = [input.name for input in inputs]
                output_names = [output.name for output in outputs]

                new_graph = convert_graph_to_fp16(sess.graph_def, target_type='fp16', input_name=input_name, output_names=output_names)
                tf.train.write_graph(new_graph, '/tmp/save', 'model.pbtxt')

            freeze_graph.freeze_graph(
                input_graph="/tmp/save/model.pbtxt",
                input_checkpoint="/tmp/save/model.ckpt",
@@ -98,7 +111,6 @@ def freeze(self, inputs : List[tf.placeholder], outputs : List[tf.identity], opt
                variable_names_blacklist = varlist)
            '''
            self.graphdef_to_json(self.frozen_graph, self.frozen_graph + ".json.gz")
            ops_used = subprocess.getoutput("zgrep -v tensorContent " + self.frozen_graph + ".json.gz | grep '\"op\":' | sort | uniq | awk -F'\"' '{print $4}' | xargs echo").split()
            os.system('zgrep -v tensorContent ' + self.frozen_graph + '.json.gz > ' + self.frozen_graph + '.json.thin')
            print('>> Ops used by Graph `%s`:' % self.frozen_graph)
@@ -107,7 +119,7 @@ def freeze(self, inputs : List[tf.placeholder], outputs : List[tf.identity], opt
            for op in ops_used:
                fp.write(op + '\n')
            '''

    def tf_run_const_folding(self, file):
        print("run const folding----------------------------")
        tf.reset_default_graph()
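
Putting the two changes together, here is a minimal sketch of driving the freezer with the new option from Python (the toy graph is hypothetical, TF 1.x style; all other constructor arguments keep the defaults from __init__ above):

import tensorflow as tf
from nnf_tf_freezer import nnf_tf_freezer

# a toy fp32 graph, for illustration only
x = tf.placeholder(tf.float32, shape=[None, 4], name="x")
w = tf.Variable(tf.ones([4, 2]), name="w")
y = tf.identity(tf.matmul(x, w), name="y")

# freeze with the fp16 conversion enabled
freezer = nnf_tf_freezer(frozen_graph="frozen_graph.pb", to_fp16=True)
freezer.execute([x], [y])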
7 changes: 3 additions & 4 deletions models/tensorflow/nnf_tf_freezer/tf_freeze_graph_example.py
@@ -42,7 +42,8 @@
                    help='Print log.')
parser.add_argument('--is_training', action='store_true',
                    help='Is training graph.')
parser.add_argument('--to_fp16', action='store_true',
                    help='Whether to save the frozen graph in fp16 format.')

args = parser.parse_args()

@@ -225,8 +226,6 @@

if __name__ == "__main__":
    freezer = nnf_tf_freezer(args.frozen_graph, args.const_folding, args.run_graph, args.xla, args.parallel,
                             args.warmup, args.num_iter, args.run_const_folded_graph, args.debug, args.is_training)
                             args.warmup, args.num_iter, args.run_const_folded_graph, args.debug, args.is_training, args.to_fp16)
    freezer.execute(inputs, outputs, optimizer)
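
With the flag wired through, invoking the example script with --to_fp16 routes sess.graph_def through convert_graph_to_fp16 before freeze_graph runs, so the constants in the resulting frozen graph are stored as DT_HALF rather than DT_FLOAT.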