
Commit d045354

Fixing saving/loading with new parameter.
1 parent d35e1ce commit d045354

File tree

3 files changed: +63 −38 lines changed

src/Microsoft.ML.TensorFlow/TensorflowTransform.cs

Lines changed: 17 additions & 6 deletions
@@ -45,6 +45,7 @@ public sealed class TensorFlowTransformer : RowToRowTransformerBase, IDisposable
         private readonly string _savedModelPath;
         private readonly bool _isTemporarySavedModel;
         private readonly bool _addBatchDimensionInput;
+        private readonly bool _treatOutputAsBatched;
         internal readonly Session Session;
         internal readonly Runner Runner;
         internal readonly DataViewType[] OutputTypes;
@@ -71,8 +72,9 @@ private static VersionInfo GetVersionInfo()
                 modelSignature: "TENSFLOW",
                 //verWrittenCur: 0x00010001, // Initial
                 //verWrittenCur: 0x00010002, // Added Support for Multiple Outputs and SavedModel.
-                verWrittenCur: 0x00010003, // Added Support for adding batch dimension in inputs.
-                verReadableCur: 0x00010003,
+                //verWrittenCur: 0x00010003, // Added Support for adding batch dimension in inputs.
+                verWrittenCur: 0x00010004, // Added Support for treating batch as output or not.
+                verReadableCur: 0x00010004,
                 verWeCanReadBack: 0x00010001,
                 loaderSignature: LoaderSignature,
                 loaderAssemblyName: typeof(TensorFlowTransformer).Assembly.FullName);
@@ -123,20 +125,21 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte
             // *** Binary format ***
             // byte: indicator for frozen models
             // byte: indicator for adding batch dimension in input
+            // byte: indicator for treating output as batched
             // stream: tensorFlow model.
             // int: number of input columns
             // for each input column
             //   int: id of int column name
             // int: number of output columns
             // for each output column
             //   int: id of output column name
-            GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput);
+            GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput, out bool treatOutputAsBatched);
             if (isFrozen)
             {
                 byte[] modelBytes = null;
                 if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
                     throw env.ExceptDecode();
-                return new TensorFlowTransformer(env, LoadTFSession(env, modelBytes), outputs, inputs, null, false, addBatchDimensionInput);
+                return new TensorFlowTransformer(env, LoadTFSession(env, modelBytes), outputs, inputs, null, false, addBatchDimensionInput, treatOutputAsBatched: treatOutputAsBatched);
             }
 
             var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), nameof(TensorFlowTransformer) + "_" + Guid.NewGuid()));
@@ -165,7 +168,7 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte
                 }
             });
 
-            return new TensorFlowTransformer(env, GetSession(env, tempDirPath), outputs, inputs, tempDirPath, true, addBatchDimensionInput);
+            return new TensorFlowTransformer(env, GetSession(env, tempDirPath), outputs, inputs, tempDirPath, true, addBatchDimensionInput, treatOutputAsBatched: treatOutputAsBatched);
         }
         catch (Exception)
         {
@@ -237,7 +240,7 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
         private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
             => Create(env, ctx).MakeRowMapper(inputSchema);
 
-        private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput)
+        private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, out string[] outputs, out bool isFrozen, out bool addBatchDimensionInput, out bool treatOutputAsBatched)
         {
             isFrozen = true;
             bool isNonFrozenModelSupported = ctx.Header.ModelVerReadable >= 0x00010002;
@@ -249,6 +252,11 @@ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out
             if (isAddingBatchDimensionSupported)
                 addBatchDimensionInput = ctx.Reader.ReadBoolByte();
 
+            treatOutputAsBatched = true;
+            bool isTreatingOutputAsBatchedSupported = ctx.Header.ModelVerReadable >= 0x00010004;
+            if (isTreatingOutputAsBatchedSupported)
+                treatOutputAsBatched = ctx.Reader.ReadBoolByte();
+
             var numInputs = ctx.Reader.ReadInt32();
             env.CheckDecode(numInputs > 0);
             inputs = new string[numInputs];
@@ -280,6 +288,7 @@ internal TensorFlowTransformer(IHostEnvironment env, Session session, string[] o
             _isTemporarySavedModel = isTemporarySavedModel;
             Session = session;
             _addBatchDimensionInput = addBatchDimensionInput;
+            _treatOutputAsBatched = treatOutputAsBatched;
             Inputs = inputColumnNames;
             Outputs = outputColumnNames;
             tf.compat.v1.disable_eager_execution();
@@ -421,6 +430,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
             // *** Binary format ***
             // byte: indicator for frozen models
             // byte: indicator for adding batch dimension in input
+            // byte: indicator for treating output as batched
             // stream: tensorFlow model.
             // int: number of input columns
             // for each input column
@@ -431,6 +441,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
             var isFrozen = string.IsNullOrEmpty(_savedModelPath);
             ctx.Writer.WriteBoolByte(isFrozen);
             ctx.Writer.WriteBoolByte(_addBatchDimensionInput);
+            ctx.Writer.WriteBoolByte(_treatOutputAsBatched);
             if (isFrozen)
             {
                 using (var status = new Status())
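
The serialization change above follows the usual versioned-format pattern: SaveModel now always writes the new boolean, while GetModelInfo only reads it when the stored model version is at least 0x00010004 and otherwise defaults to true, so models saved before this change still load. A minimal standalone sketch of that pattern, using plain BinaryWriter/BinaryReader rather than ML.NET's ModelSaveContext/ModelLoadContext (the version constant and defaults mirror the diff; the inline version prefix and everything else here are illustrative):

using System;
using System.IO;
using System.Text;

class VersionedFlagRoundTrip
{
    // Mirrors verWrittenCur in the diff; the rest of this file is a simplified stand-in.
    const uint VerTreatOutputAsBatched = 0x00010004;

    static void Save(Stream s, bool addBatchDimensionInput, bool treatOutputAsBatched)
    {
        using var w = new BinaryWriter(s, Encoding.UTF8, leaveOpen: true);
        w.Write(VerTreatOutputAsBatched);   // version prefix (ML.NET keeps this in the model header instead)
        w.Write(addBatchDimensionInput);    // byte: indicator for adding batch dimension in input
        w.Write(treatOutputAsBatched);      // byte: indicator for treating output as batched (new)
    }

    static (bool AddBatch, bool OutputBatched) Load(Stream s)
    {
        using var r = new BinaryReader(s, Encoding.UTF8, leaveOpen: true);
        uint version = r.ReadUInt32();
        bool addBatch = r.ReadBoolean();
        // Models written before 0x00010004 never stored the flag, so fall back to the
        // old behavior (output treated as batched) instead of reading past the format.
        bool outputBatched = version >= VerTreatOutputAsBatched ? r.ReadBoolean() : true;
        return (addBatch, outputBatched);
    }

    static void Main()
    {
        using var ms = new MemoryStream();
        Save(ms, addBatchDimensionInput: true, treatOutputAsBatched: false);
        ms.Position = 0;
        Console.WriteLine(Load(ms));        // prints (True, False)
    }
}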

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

Lines changed: 0 additions & 32 deletions
@@ -1152,38 +1152,6 @@ public void TensorFlowGettingSchemaMultipleTimes()
             }
         }
 
-        // This test has been created as result of https://github.com/dotnet/machinelearning/issues/5364.
-        [TensorFlowFact]
-        public void TreatOutputAsBatched()
-        {
-            MLContext mlContext = new MLContext();
-
-            var dataView = mlContext.Data.CreateTextLoader<TextInput>().Load(new MultiFileSource(null));
-
-            string modelLocation = @"model_string_test";
-
-            // When treatOutputAsBatched is defaulted to true, make sure the output is correct.
-            using var model = mlContext.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel(new[] { "Original_A", "Joined_Splited_Text" }, new[] { "A", "B" })
-                .Fit(dataView);
-
-            var modelSchema = model.GetOutputSchema(dataView.Schema);
-
-            Assert.Equal(4, modelSchema.Count);
-            Assert.Equal(new VectorDataViewType(TextDataViewType.Instance, 2), modelSchema[2].Type);
-            Assert.Equal(new VectorDataViewType(TextDataViewType.Instance, 1, 1), modelSchema[3].Type);
-
-            using var modelNonBatched = mlContext.Model.LoadTensorFlowModel(modelLocation, false).ScoreTensorFlowModel(new[] { "Original_A", "Joined_Splited_Text" }, new[] { "A", "B" })
-                .Fit(dataView);
-
-            modelSchema = modelNonBatched.GetOutputSchema(dataView.Schema);
-
-            Assert.Equal(4, modelSchema.Count);
-            Assert.Equal(new VectorDataViewType(TextDataViewType.Instance, 0, 2), modelSchema[2].Type);
-            Assert.Equal(new VectorDataViewType(TextDataViewType.Instance, 1, 1), modelSchema[3].Type);
-
-        }
-
-
         [TensorFlowFact]
         public void TensorFlowTransformCifarInvalidShape()
         {

test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs

Lines changed: 46 additions & 0 deletions
@@ -183,6 +183,52 @@ public void TestTensorFlow()
                 Assert.Equal(4, numRows);
             }
         }
+
+        [TensorFlowFact]
+        public void TreatOutputAsBatched()
+        {
+            var modelLocation = "cifar_model/frozen_model.pb";
+
+            var mlContext = new MLContext(seed: 1);
+            var imageHeight = 32;
+            var imageWidth = 32;
+            var dataFile = GetDataPath("images/images.tsv");
+            var imageFolder = Path.GetDirectoryName(dataFile);
+
+            var data = ML.Data.LoadFromTextFile(dataFile, new[] {
+                new TextLoader.Column("imagePath", DataKind.String, 0),
+                new TextLoader.Column("name", DataKind.String, 1)
+            });
+
+            // Note that CamelCase column names are there to match the TF graph node names.
+            // Check and make sure save/load work correctly for the new TreatOutputAsBatched value.
+            var pipe = ML.Transforms.LoadImages("Input", imageFolder, "imagePath")
+                .Append(ML.Transforms.ResizeImages("Input", imageHeight, imageWidth))
+                .Append(ML.Transforms.ExtractPixels("Input", interleavePixelColors: true))
+                .Append(ML.Model.LoadTensorFlowModel(modelLocation, false).ScoreTensorFlowModel("Output", "Input"));
+
+            TestEstimatorCore(pipe, data);
+            var schema = pipe.Fit(data).Transform(data).Schema;
+
+            // The dimensions of the output with treatOutputAsBatched set to false should be * 10,
+            // as the first dimension of -1 is treated as an unknown dimension.
+            Assert.Equal(new VectorDataViewType(NumberDataViewType.Single, 0, 10), schema["Output"].Type);
+
+            // Note that CamelCase column names are there to match the TF graph node names.
+            // Test with TreatOutputAsBatched set to its default value of true.
+            pipe = ML.Transforms.LoadImages("Input", imageFolder, "imagePath")
+                .Append(ML.Transforms.ResizeImages("Input", imageHeight, imageWidth))
+                .Append(ML.Transforms.ExtractPixels("Input", interleavePixelColors: true))
+                .Append(ML.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel("Output", "Input"));
+
+            TestEstimatorCore(pipe, data);
+            schema = pipe.Fit(data).Transform(data).Schema;
+
+            // The dimensions of the output with treatOutputAsBatched set to true should be 10,
+            // as the first dimension of -1 is treated as the batch dimension.
+            Assert.Equal(new VectorDataViewType(NumberDataViewType.Single, 10), schema["Output"].Type);
+
+        }
 
         [TensorFlowFact]
         public void TestTensorFlowWithSchema()