QAT int8 MKL-DNN transformation pass #17819
Conversation
We should have a check somewhere for whether MKL-DNN is enabled (i.e. FLAGS_use_mkldnn==1). Currently this transformation happens even when FLAGS_use_mkldnn=false, as I observed while running the test written by @wojtuss in #17814.
(This transformation sets the attribute use_mkldnn=true on some operators, which causes execution of MKL-DNN operators.)
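The guard being asked for could be sketched in plain Python along these lines. This is a minimal illustration, not Paddle code: the helper names and the environment-variable read are assumptions (the real FLAGS_use_mkldnn lives in Paddle's gflags registry, not necessarily in `os.environ`).

```python
import os

def mkldnn_enabled():
    # Hypothetical helper (not a Paddle API): read FLAGS_use_mkldnn from
    # the environment. The real flag is a Paddle gflag; this only mimics it.
    return os.environ.get("FLAGS_use_mkldnn", "0").lower() in ("1", "true")

def maybe_apply_mkldnn_pass(graph, mkldnn_pass):
    # Run the MKL-DNN transformation only when the flag is enabled, so the
    # pass does not force use_mkldnn=true onto operators of a non-MKL-DNN run.
    if mkldnn_enabled():
        return mkldnn_pass.apply(graph)
    return graph
```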
python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py
sfraczek left a comment
LGTM. The control over whether to run this pass based on use_mkldnn can be added somewhere else.
Force-pushed dc5f69e to cab855c
```python
from paddle.fluid import core

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CPU_NUM"] = "1"
```
In this test, we need to train a model for several iterations to get the weight values, then convert them to int8 and check whether the conversion is correct. When we train this model, we need to set "CPU_NUM". This training approach follows the example below:
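The flow described above (train to obtain weights, convert to int8, verify the conversion) can be sketched in plain Python. All names here and the symmetric max-abs/127 quantization scheme are assumptions for illustration, not the actual Paddle test code:

```python
def quantize_int8(weights):
    # Symmetric per-tensor quantization sketch (assumed scheme): map fp32
    # weights into the int8 range [-127, 127] using the max-abs scale.
    scale = max(abs(w) for w in weights) or 1.0
    return [int(round(w / scale * 127)) for w in weights], scale

def roundtrip_ok(weights, slack=0.01):
    # Verify the conversion: dequantized values should match the originals
    # to within half a quantization step (plus a small slack).
    quantized, scale = quantize_int8(weights)
    dequantized = [q * scale / 127.0 for q in quantized]
    step = scale / 127.0
    return all(abs(a - b) <= step / 2 + slack
               for a, b in zip(weights, dequantized))
```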
```python
class TransformForMkldnnPass(object):
    def __init__(self, scope=None, place=None):
        """
```
These comments are not for the __init__ function; they are for the class TransformForMkldnnPass. Thus, move them to after line 23?
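Applying that suggestion, the description would become a class-level docstring rather than sitting under __init__. A minimal sketch; the instance attributes shown are assumptions, not the pass's actual fields:

```python
class TransformForMkldnnPass(object):
    """Transform a QuantizationFreezePass-generated IrGraph into an
    MKL-DNN INT8 runnable IrGraph.

    The class-level description lives here, not in __init__'s docstring.
    """

    def __init__(self, scope=None, place=None):
        # Per-instance setup only; what the pass does is documented in the
        # class docstring above. Attribute names here are hypothetical.
        self._scope = scope
        self._place = place
```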
This PR converts a QuantizationFreezePass-generated IrGraph into an MKL-DNN-supported, INT8-runnable IrGraph.
The following transformations are done in this pass:
1. Convert the int8-range weights with float type (generated by the QuantizationFreezePass) to fp32-range weights with float dtype, using the corresponding scales.
2. Create a new conv2d op with the converted weights, link its output to the fake_dequantize_abs_max output, and set the conv2d attribute "force_fp32_output" to true.
3. Transform fake_quantize_xx ops into quantize ops.
4. Remove the fake_dequantize_abs_max ops.

The unit test runs on a Graph to which QuantizationTransformPass and QuantizationFreezePass have already been applied, and checks that:
1. conv2d's output is correctly linked to the fake_dequantize op's output;
2. conv2d's weights have been converted to the fp32 range and are no longer integers;
3. inspecting the graph locally, the op types have been transformed as expected.

The latest accuracy of full ImageNet validation set on Cascade Lake (new generation Xeon) is as below:
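Transformation step 1 and unit-test check 2 can be sketched in plain Python. The scale/127 mapping, function names, and tolerance are assumptions for illustration, not the actual pass code:

```python
def weights_to_fp32_range(int8_range_weights, scale, max_range=127.0):
    # Step 1 sketch: weights stored as floats holding int8-range values
    # are mapped back to the fp32 range via the recorded per-tensor scale
    # (assumed mapping: w * scale / 127).
    return [w * scale / max_range for w in int8_range_weights]

def no_longer_integers(weights, eps=1e-6):
    # Unit-test check 2 sketch: after conversion the weights should not
    # all be (near-)integer values any more.
    return any(abs(w - round(w)) > eps for w in weights)
```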