@@ -45,6 +45,7 @@ public sealed class TensorFlowTransformer : RowToRowTransformerBase, IDisposable
4545 private readonly string _savedModelPath ;
4646 private readonly bool _isTemporarySavedModel ;
4747 private readonly bool _addBatchDimensionInput ;
48+ private readonly bool _treatOutputAsBatched ;
4849 internal readonly Session Session ;
4950 internal readonly Runner Runner ;
5051 internal readonly DataViewType [ ] OutputTypes ;
@@ -71,8 +72,9 @@ private static VersionInfo GetVersionInfo()
7172 modelSignature : "TENSFLOW" ,
7273 //verWrittenCur: 0x00010001, // Initial
7374 //verWrittenCur: 0x00010002, // Added Support for Multiple Outputs and SavedModel.
74- verWrittenCur : 0x00010003 , // Added Support for adding batch dimension in inputs.
75- verReadableCur : 0x00010003 ,
75+ //verWrittenCur: 0x00010003, // Added Support for adding batch dimension in inputs.
76+ verWrittenCur : 0x00010004 , // Added Support for treating batch as output or not.
77+ verReadableCur : 0x00010004 ,
7678 verWeCanReadBack : 0x00010001 ,
7779 loaderSignature : LoaderSignature ,
7880 loaderAssemblyName : typeof ( TensorFlowTransformer ) . Assembly . FullName ) ;
@@ -123,20 +125,21 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte
123125 // *** Binary format ***
124126 // byte: indicator for frozen models
125127 // byte: indicator for adding batch dimension in input
128+ // byte: indicator for treating output as batched
126129 // stream: tensorFlow model.
127130 // int: number of input columns
128131 // for each input column
129132 // int: id of int column name
130133 // int: number of output columns
131134 // for each output column
132135 // int: id of output column name
133- GetModelInfo ( env , ctx , out string [ ] inputs , out string [ ] outputs , out bool isFrozen , out bool addBatchDimensionInput ) ;
136+ GetModelInfo ( env , ctx , out string [ ] inputs , out string [ ] outputs , out bool isFrozen , out bool addBatchDimensionInput , out bool treatOutputAsBatched ) ;
134137 if ( isFrozen )
135138 {
136139 byte [ ] modelBytes = null ;
137140 if ( ! ctx . TryLoadBinaryStream ( "TFModel" , r => modelBytes = r . ReadByteArray ( ) ) )
138141 throw env . ExceptDecode ( ) ;
139- return new TensorFlowTransformer ( env , LoadTFSession ( env , modelBytes ) , outputs , inputs , null , false , addBatchDimensionInput ) ;
142+ return new TensorFlowTransformer ( env , LoadTFSession ( env , modelBytes ) , outputs , inputs , null , false , addBatchDimensionInput , treatOutputAsBatched : treatOutputAsBatched ) ;
140143 }
141144
142145 var tempDirPath = Path . GetFullPath ( Path . Combine ( Path . GetTempPath ( ) , nameof ( TensorFlowTransformer ) + "_" + Guid . NewGuid ( ) ) ) ;
@@ -165,7 +168,7 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte
165168 }
166169 } ) ;
167170
168- return new TensorFlowTransformer ( env , GetSession ( env , tempDirPath ) , outputs , inputs , tempDirPath , true , addBatchDimensionInput ) ;
171+ return new TensorFlowTransformer ( env , GetSession ( env , tempDirPath ) , outputs , inputs , tempDirPath , true , addBatchDimensionInput , treatOutputAsBatched : treatOutputAsBatched ) ;
169172 }
170173 catch ( Exception )
171174 {
@@ -237,7 +240,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
237240 private static IRowMapper Create ( IHostEnvironment env , ModelLoadContext ctx , DataViewSchema inputSchema )
238241 => Create ( env , ctx ) . MakeRowMapper ( inputSchema ) ;
239242
240- private static void GetModelInfo ( IHostEnvironment env , ModelLoadContext ctx , out string [ ] inputs , out string [ ] outputs , out bool isFrozen , out bool addBatchDimensionInput )
243+ private static void GetModelInfo ( IHostEnvironment env , ModelLoadContext ctx , out string [ ] inputs , out string [ ] outputs , out bool isFrozen , out bool addBatchDimensionInput , out bool treatOutputAsBatched )
241244 {
242245 isFrozen = true ;
243246 bool isNonFrozenModelSupported = ctx . Header . ModelVerReadable >= 0x00010002 ;
@@ -249,6 +252,11 @@ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out
249252 if ( isAddingBatchDimensionSupported )
250253 addBatchDimensionInput = ctx . Reader . ReadBoolByte ( ) ;
251254
255+ treatOutputAsBatched = true ;
256+ bool isTreatingOutputAsBatchedSupported = ctx . Header . ModelVerReadable >= 0x00010004 ;
257+ if ( isTreatingOutputAsBatchedSupported )
258+ treatOutputAsBatched = ctx . Reader . ReadBoolByte ( ) ;
259+
252260 var numInputs = ctx . Reader . ReadInt32 ( ) ;
253261 env . CheckDecode ( numInputs > 0 ) ;
254262 inputs = new string [ numInputs ] ;
@@ -280,6 +288,7 @@ internal TensorFlowTransformer(IHostEnvironment env, Session session, string[] o
280288 _isTemporarySavedModel = isTemporarySavedModel ;
281289 Session = session ;
282290 _addBatchDimensionInput = addBatchDimensionInput ;
291+ _treatOutputAsBatched = treatOutputAsBatched ;
283292 Inputs = inputColumnNames ;
284293 Outputs = outputColumnNames ;
285294 tf . compat . v1 . disable_eager_execution ( ) ;
@@ -421,6 +430,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
421430 // *** Binary format ***
422431 // byte: indicator for frozen models
423432 // byte: indicator for adding batch dimension in input
433+ // byte: indicator for treating output as batched
424434 // stream: tensorFlow model.
425435 // int: number of input columns
426436 // for each input column
@@ -431,6 +441,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
431441 var isFrozen = string . IsNullOrEmpty ( _savedModelPath ) ;
432442 ctx . Writer . WriteBoolByte ( isFrozen ) ;
433443 ctx . Writer . WriteBoolByte ( _addBatchDimensionInput ) ;
444+ ctx . Writer . WriteBoolByte ( _treatOutputAsBatched ) ;
434445 if ( isFrozen )
435446 {
436447 using ( var status = new Status ( ) )