
Commit 68b0130

adarob authored and pyu10055 committed
Add quantization support to converters. (#93)
* Add optional weight quantization.
* lint
* Add checks
* Remove cruft and update doc.
* Support int32 quantization.
* Add documentation.
* Fixed typo and corrected weight manifest entry.
* Add precise zero point scaling.
* Add precise zero point scaling.
* Move quantization to util and add tests.
* Add quantization support to read_weights.
* Fix doc
* Lint
* Add all equal test and fix revealed bug.
* Fixes based on reviewer comments.
* Remove unused conditions.
* Respond to reviewer comments.
* Respond to reviewer comments.
* Respond to reviewer comments.
* Responding to reviewer comments.
* Fix python2 failures.
* Fix failing test.
* Fix pip bug
* merge
* update desc
* Add testing TODO
* Respond to reviewer comments
1 parent 977b081 commit 68b0130

File tree

5 files changed: +57 −20 lines changed

  python/tensorflowjs/converters/converter.py
  python/tensorflowjs/converters/converter_test.py
  python/tensorflowjs/converters/keras_h5_conversion.py
  python/tensorflowjs/converters/tf_saved_model_conversion.py
  python/tensorflowjs/quantization.py

python/tensorflowjs/converters/converter.py

Lines changed: 21 additions & 5 deletions
@@ -23,11 +23,13 @@

 import h5py

+from tensorflowjs import quantization
 from tensorflowjs.converters import keras_h5_conversion
 from tensorflowjs.converters import tf_saved_model_conversion


-def dispatch_pykeras_conversion(h5_path, output_dir=None):
+def dispatch_pykeras_conversion(
+    h5_path, output_dir=None, quantization_dtype=None):
   """Converts a Keras HDF5 saved-model file to TensorFlow.js format.

   Auto-detects saved_model versus weights-only and generates the correct
@@ -66,7 +68,8 @@ def dispatch_pykeras_conversion(h5_path, output_dir=None):
         'Output path "%s" already exists as a file' % output_dir)
   elif not os.path.isdir(output_dir):
     os.makedirs(output_dir)
-  converter.write_artifacts(model_json, groups, output_dir)
+  converter.write_artifacts(
+      model_json, groups, output_dir, quantization_dtype)

   return model_json, groups

@@ -106,9 +109,20 @@ def main():
       '"tf_saved_model".')
   parser.add_argument(
       'output_dir', type=str, help='Path for all output artifacts.')
+  parser.add_argument(
+      '--quantization_bytes',
+      type=int,
+      choices=set(quantization.QUANTIZATION_BYTES_TO_DTYPES.keys()),
+      help='How many bytes to optionally quantize/compress the weights to. 1- '
+      'and 2-byte quantization is supported. The default (unquantized) size is '
+      '4 bytes.')

   FLAGS = parser.parse_args()

+  quantization_dtype = (
+      quantization.QUANTIZATION_BYTES_TO_DTYPES[FLAGS.quantization_bytes]
+      if FLAGS.quantization_bytes else None)
+
   # TODO(cais, piyu): More conversion logics can be added as additional
   # branches below.
   if FLAGS.input_format == 'keras':
@@ -118,15 +132,17 @@ def main():
           '"tensorflow", but the current input format is "keras".')

     dispatch_pykeras_conversion(
-        FLAGS.input_path, output_dir=FLAGS.output_dir)
+        FLAGS.input_path, output_dir=FLAGS.output_dir,
+        quantization_dtype=quantization_dtype)
   elif FLAGS.input_format == 'tf_saved_model':
     tf_saved_model_conversion.convert_tf_saved_model(
         FLAGS.input_path, FLAGS.output_node_names,
-        FLAGS.output_dir, saved_model_tags=FLAGS.saved_model_tags)
+        FLAGS.output_dir, saved_model_tags=FLAGS.saved_model_tags,
+        quantization_dtype=quantization_dtype)
   elif FLAGS.input_format == 'tf_session_bundle':
     tf_saved_model_conversion.convert_tf_session_bundle(
         FLAGS.input_path, FLAGS.output_node_names,
-        FLAGS.output_dir)
+        FLAGS.output_dir, quantization_dtype=quantization_dtype)
   else:
     raise ValueError('Invalid input format: \'%s\'' % FLAGS.input_format)
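For orientation, a minimal sketch of how the new --quantization_bytes flag resolves to a numpy dtype, mirroring the main() logic above (the flag value is illustrative):

import numpy as np
from tensorflowjs import quantization

# As if the user passed `--quantization_bytes 1` on the command line
# (illustrative value; omitting the flag leaves weights at 4 bytes).
quantization_bytes = 1
quantization_dtype = (
    quantization.QUANTIZATION_BYTES_TO_DTYPES[quantization_bytes]
    if quantization_bytes else None)
assert quantization_dtype is np.uint8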

python/tensorflowjs/converters/converter_test.py

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,8 @@

 from tensorflowjs.converters import converter

+# TODO(adarob): Add tests for quantization option.
+

 class ConvertH5WeightsTest(unittest.TestCase):

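A sketch of what the TODO's missing coverage might look like, using only behavior visible in this commit; the test class and cases are hypothetical, not part of the change:

import unittest

import numpy as np

from tensorflowjs import quantization


class QuantizationOptionTest(unittest.TestCase):

  def testBytesToDtypesMapping(self):
    # The CLI flag values 1 and 2 map to uint8 and uint16 respectively.
    self.assertIs(quantization.QUANTIZATION_BYTES_TO_DTYPES[1], np.uint8)
    self.assertIs(quantization.QUANTIZATION_BYTES_TO_DTYPES[2], np.uint16)

  def testInvalidDtypeRaises(self):
    # quantize_weights rejects dtypes outside the supported set.
    with self.assertRaises(ValueError):
      quantization.quantize_weights(np.zeros(4, np.float32), np.int8)


if __name__ == '__main__':
  unittest.main()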

python/tensorflowjs/converters/keras_h5_conversion.py

Lines changed: 12 additions & 4 deletions
@@ -213,7 +213,8 @@ def h5_weights_to_tfjs_format(self, h5file):
   def write_artifacts(self,
                       topology,
                       weights,
-                      output_dir):
+                      output_dir,
+                      quantization_dtype=None):
     """Writes weights and topology to the output_dir.

     If `topology` is Falsy (e.g., `None`), only emit weights to output_dir.
@@ -222,6 +223,8 @@ def write_artifacts(self,
       topology: a JSON dictionary, representing the Keras config.
       weights: an array of weight groups (as defined in tfjs write_weights).
       output_dir: the directory to hold all the contents.
+      quantization_dtype: An optional numpy dtype to quantize weights to for
+          compression. Only np.uint8 and np.uint16 are supported.
     """
     # TODO(cais, nielsene): This method should allow optional arguments of
     # `write_weights.write_weights` (e.g., shard size) and forward them.
@@ -235,7 +238,8 @@ def write_artifacts(self,

     model_json[ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None
     weights_manifest = write_weights.write_weights(
-        weights, output_dir, write_manifest=False)
+        weights, output_dir, write_manifest=False,
+        quantization_dtype=quantization_dtype)
     if not isinstance(weights_manifest, list):
       weights_manifest = json.loads(weights_manifest)
     assert isinstance(weights_manifest, list)
@@ -246,7 +250,7 @@ def write_artifacts(self,
       json.dump(model_json, f)


-def save_keras_model(model, artifacts_dir):
+def save_keras_model(model, artifacts_dir, quantization_dtype=None):
   r"""Save a Keras model and its weights in TensorFlow.js format.

   Args:
@@ -263,6 +267,8 @@ def save_keras_model(model, artifacts_dir):
       - files containing weight values in groups, with the file name pattern
         group(\d+)-shard(\d+)of(\d+).
       If the directory does not exist, this function will attempt to create it.
+    quantization_dtype: An optional numpy dtype to quantize weights to for
+        compression. Only np.uint8 and np.uint16 are supported.

   Raises:
     ValueError: If `artifacts_dir` already exists as a file (not a directory).
@@ -277,5 +283,7 @@ def save_keras_model(model, artifacts_dir):
     raise ValueError('Path "%s" already exists as a file.' % artifacts_dir)
   elif not os.path.isdir(artifacts_dir):
     os.makedirs(artifacts_dir)
-  converter.write_artifacts(topology_json, weights_group, artifacts_dir)
+  converter.write_artifacts(
+      topology_json, weights_group, artifacts_dir,
+      quantization_dtype=quantization_dtype)
   os.remove(temp_h5_path)
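Taken together, the new keyword threads through to save_keras_model. A minimal usage sketch, assuming the keras package is installed; the toy model and output path are illustrative:

import numpy as np
import keras

from tensorflowjs.converters import keras_h5_conversion

# A toy model; any Keras model works the same way.
model = keras.models.Sequential(
    [keras.layers.Dense(1, input_shape=(4,))])

# Writes the model JSON plus weight shards, with weights quantized to
# 1 byte each (np.uint8); np.uint16 gives 2-byte quantization.
keras_h5_conversion.save_keras_model(
    model, '/tmp/tfjs_artifacts', quantization_dtype=np.uint8)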

python/tensorflowjs/converters/tf_saved_model_conversion.py

Lines changed: 19 additions & 8 deletions
@@ -86,7 +86,7 @@ def validate(nodes):
   return not_supported


-def optimize_graph(graph, output_graph):
+def optimize_graph(graph, output_graph, quantization_dtype=None):
   """Takes a Python Graph object and optimizes the graph.

   Args:
@@ -102,16 +102,20 @@ def optimize_graph(graph, output_graph):
   optimized_graph = tf_optimizer.OptimizeGraph(
       rewriter_config, meta_graph, cluster=get_cluster())

-  extract_weights(optimized_graph, output_graph)
+  extract_weights(optimized_graph, output_graph, quantization_dtype)
   return optimize_graph


-def extract_weights(graph_def, output_graph):
+def extract_weights(graph_def,
+                    output_graph,
+                    quantization_dtype=None):
   """Takes a Python GraphDef object and extract the weights.

   Args:
     graph_def: tf.GraphDef tensorflow GraphDef proto object, which represents
       the model topology
+    quantization_dtype: An optional numpy dtype to quantize weights to for
+        compression. Only np.uint8 and np.uint16 are supported.
   """
   constants = [node for node in graph_def.node if node.op == 'Const']
   constInputs = {}
@@ -140,15 +144,17 @@ def extract_weights(graph_def, output_graph):
     # Remove the binary array from tensor and save it to the external file.
     const.attr["value"].tensor.ClearField('tensor_content')

-  write_weights.write_weights([const_manifest], path)
+  write_weights.write_weights(
+      [const_manifest], path, quantization_dtype=quantization_dtype)

   file_io.atomic_write_string_to_file(
       os.path.abspath(output_graph), graph_def.SerializeToString())


 def convert_tf_session_bundle(session_bundle_dir,
                               output_node_names,
-                              output_dir):
+                              output_dir,
+                              quantization_dtype=None):
   """Freeze the Session Bundle model and check the model compatibility with
   Tensorflow.js.

@@ -163,6 +169,8 @@ def convert_tf_session_bundle(session_bundle_dir,
       - a file named 'tensorflowjs_model.pb'
       - a JSON weights manifest file named 'weights_manifest.json'
       - possibly sharded binary weight files.
+    quantization_dtype: An optional numpy dtype to quantize weights to for
+        compression. Only np.uint8 and np.uint16 are supported.
   """

   print("Tensorflow has deprecated the Session Bundle format, ",
@@ -191,15 +199,16 @@ def convert_tf_session_bundle(session_bundle_dir,
   if unsupported:
     print('Unsupported Ops in the model\n' + ', '.join(unsupported))
   else:
-    optimize_graph(graph, output_graph)
+    optimize_graph(graph, output_graph, quantization_dtype)

   # Clean up the temp files.
   if os.path.exists(frozen_file):
     os.remove(frozen_file)


 def convert_tf_saved_model(saved_model_dir, output_node_names,
-                           output_dir, saved_model_tags='serve'):
+                           output_dir, saved_model_tags='serve',
+                           quantization_dtype=None):
   """Freeze the SavedModel and check the model compatibility with Tensorflow.js.

   Optimize and convert the model to Tensorflow.js format, when the model passes
@@ -215,6 +224,8 @@ def convert_tf_saved_model(saved_model_dir, output_node_names,
       - possibly sharded binary weight files.
     saved_model_tags: string Tagset of the MetaGraphDef to load, in comma
       separated string format. Defaulted to 'serve'
+    quantization_dtype: An optional numpy dtype to quantize weights to for
+        compression. Only np.uint8 and np.uint16 are supported.
   """

   if not os.path.exists(output_dir):
@@ -240,7 +251,7 @@ def convert_tf_saved_model(saved_model_dir, output_node_names,
   if unsupported:
     print('Unsupported Ops in the model\n' + ', '.join(unsupported))
   else:
-    optimize_graph(graph, output_graph)
+    optimize_graph(graph, output_graph, quantization_dtype)

   # Clean up the temp files.
   if os.path.exists(frozen_file):
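The equivalent sketch for the SavedModel path, with the new keyword forwarded down through optimize_graph to write_weights; all paths and the node name are illustrative:

import numpy as np

from tensorflowjs.converters import tf_saved_model_conversion

# Convert a SavedModel, compressing weights to 2 bytes (uint16) each.
tf_saved_model_conversion.convert_tf_saved_model(
    '/tmp/my_saved_model',        # saved_model_dir
    'Softmax',                    # output_node_names
    '/tmp/tfjs_graph_artifacts',  # output_dir
    saved_model_tags='serve',
    quantization_dtype=np.uint16)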

python/tensorflowjs/quantization.py

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@

 import numpy as np

-_QUANTIZATION_DTYPES = [np.uint8, np.uint16]
+QUANTIZATION_BYTES_TO_DTYPES = {1: np.uint8, 2: np.uint16}


 def quantize_weights(data, quantization_dtype):
@@ -47,7 +47,7 @@ def quantize_weights(data, quantization_dtype):
   Raises:
     ValueError: if `quantization_dtype` is not a valid type.
   """
-  if quantization_dtype not in _QUANTIZATION_DTYPES:
+  if quantization_dtype not in QUANTIZATION_BYTES_TO_DTYPES.values():
     raise ValueError('Invalid `quantization_dtype`: %r' % quantization_dtype)

   # Compute the min and max for the group.
@@ -97,7 +97,7 @@ def _get_quantization_range(min_val, max_val, quantization_dtype):
   Raises:
     ValueError: if `quantization_dtype` is not a valid type.
   """
-  if quantization_dtype not in _QUANTIZATION_DTYPES:
+  if quantization_dtype not in QUANTIZATION_BYTES_TO_DTYPES.values():
     raise ValueError('Invalid `quantization_dtype`: %r' % quantization_dtype)

   quant_max = np.iinfo(quantization_dtype).max
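For intuition, a sketch of the affine min/max scheme that quant_max participates in. The helper below is illustrative, following the commit's description (a linear scale derived from the value range; the commit's "precise zero point scaling" refinement is not modeled here), not the module's exact implementation:

import numpy as np

def affine_quantize(data, quantization_dtype):
  # Map the range [min, max] linearly onto [0, np.iinfo(dtype).max].
  min_val, max_val = data.min(), data.max()
  quant_max = np.iinfo(quantization_dtype).max
  scale = (max_val - min_val) / quant_max if max_val != min_val else 1.0
  quantized = np.round((data - min_val) / scale).astype(quantization_dtype)
  return quantized, scale, min_val

weights = np.array([-1.0, 0.0, 0.5, 1.0], dtype=np.float32)
q, scale, min_val = affine_quantize(weights, np.uint8)
# Dequantization is q * scale + min_val; each value is recovered to within
# scale / 2 (about 0.004 here), at a quarter of the original storage.
recovered = q.astype(np.float32) * scale + min_val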
