-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathTensorFlowModel.cs
More file actions
103 lines (91 loc) · 4.78 KB
/
TensorFlowModel.cs
File metadata and controls
103 lines (91 loc) · 4.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using Microsoft.ML.Runtime;
using Microsoft.ML.TensorFlow;
using Tensorflow;
namespace Microsoft.ML.Transforms
{
/// <summary>
/// This class holds the information related to a TensorFlow model and the TensorFlow session it was loaded into.
/// It provides convenient methods to query the model schema as well as to create
/// <see cref="TensorFlowEstimator"/> objects that score data with the model.
/// </summary>
/// <remarks>
/// The model owns its TensorFlow session: calling <see cref="Dispose"/> disposes the session, after which
/// the schema-query and scoring methods throw <see cref="ObjectDisposedException"/>.
/// </remarks>
public sealed class TensorFlowModel : IDisposable
{
    internal Session Session { get; }
    internal string ModelPath { get; }
    private readonly IHostEnvironment _env;

    /// <summary>
    /// Instantiates <see cref="TensorFlowModel"/>.
    /// </summary>
    /// <param name="env">An <see cref="IHostEnvironment"/> object.</param>
    /// <param name="session">TensorFlow session object. Ownership passes to this instance; it is disposed by <see cref="Dispose"/>.</param>
    /// <param name="modelLocation">Location of the model from where <paramref name="session"/> was loaded.</param>
    internal TensorFlowModel(IHostEnvironment env, Session session, string modelLocation)
    {
        Session = session;
        ModelPath = modelLocation;
        _env = env;
    }

    /// <summary>
    /// Get <see cref="DataViewSchema"/> for complete model. Every node in the TensorFlow model will be included in the <see cref="DataViewSchema"/> object.
    /// </summary>
    /// <returns>A schema with one column per node of the model graph.</returns>
    /// <exception cref="ObjectDisposedException">This model has already been disposed.</exception>
    public DataViewSchema GetModelSchema()
    {
        ThrowIfDisposed();
        return TensorFlowUtils.GetModelSchema(_env, Session.graph);
    }

    /// <summary>
    /// Get <see cref="DataViewSchema"/> for only those nodes which are marked "Placeholder" in the TensorFlow model.
    /// This method is convenient for exploring the model input(s) in case TensorFlow graph is very large.
    /// </summary>
    /// <returns>A schema restricted to the "Placeholder" nodes of the model graph.</returns>
    /// <exception cref="ObjectDisposedException">This model has already been disposed.</exception>
    public DataViewSchema GetInputSchema()
    {
        ThrowIfDisposed();
        return TensorFlowUtils.GetModelSchema(_env, Session.graph, "Placeholder");
    }

    /// <summary>
    /// Scores a dataset using a pre-trained <a href="https://www.tensorflow.org/">TensorFlow</a> model.
    /// </summary>
    /// <param name="outputColumnName">The name of the requested model output. The data type is a vector of <see cref="System.Single"/>.</param>
    /// <param name="inputColumnName">The name of the model input. The data type is a vector of <see cref="System.Single"/>.</param>
    /// <param name="addBatchDimensionInput">Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].
    /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well.</param>
    /// <returns>A <see cref="TensorFlowEstimator"/> bound to this model.</returns>
    /// <exception cref="ObjectDisposedException">This model has already been disposed.</exception>
    /// <example>
    /// <format type="text/markdown">
    /// <![CDATA[
    /// ```csharp
    /// // "output" and "input" are placeholder node names; use GetModelSchema() / GetInputSchema() to discover the real ones.
    /// var estimator = tensorFlowModel.ScoreTensorFlowModel("output", "input");
    /// ```
    /// ]]>
    /// </format>
    /// </example>
    public TensorFlowEstimator ScoreTensorFlowModel(string outputColumnName, string inputColumnName, bool addBatchDimensionInput = false)
    {
        ThrowIfDisposed();
        return new TensorFlowEstimator(_env, new[] { outputColumnName }, new[] { inputColumnName }, this, addBatchDimensionInput);
    }

    /// <summary>
    /// Scores a dataset using a pre-trained TensorFlow model.
    /// </summary>
    /// <param name="outputColumnNames">The names of the requested model outputs.</param>
    /// <param name="inputColumnNames">The names of the model inputs.</param>
    /// <param name="addBatchDimensionInput">Add a batch dimension to the input e.g. input = [224, 224, 3] => [-1, 224, 224, 3].
    /// This parameter is used to deal with models that have unknown shape but the internal operators in the model require data to have batch dimension as well.</param>
    /// <returns>A <see cref="TensorFlowEstimator"/> bound to this model.</returns>
    /// <exception cref="ObjectDisposedException">This model has already been disposed.</exception>
    /// <example>
    /// <format type="text/markdown">
    /// <![CDATA[
    /// ```csharp
    /// // Node names below are placeholders; use GetModelSchema() / GetInputSchema() to discover the real ones.
    /// var estimator = tensorFlowModel.ScoreTensorFlowModel(new[] { "output" }, new[] { "input1", "input2" });
    /// ```
    /// ]]>
    /// </format>
    /// </example>
    public TensorFlowEstimator ScoreTensorFlowModel(string[] outputColumnNames, string[] inputColumnNames, bool addBatchDimensionInput = false)
    {
        ThrowIfDisposed();
        return new TensorFlowEstimator(_env, outputColumnNames, inputColumnNames, this, addBatchDimensionInput);
    }

    // Every public member dereferences Session; once Dispose() has released it, fail with a
    // clear managed exception instead of touching a disposed native session.
    private void ThrowIfDisposed()
    {
        if (_disposed)
            throw new ObjectDisposedException(nameof(TensorFlowModel));
    }

    #region IDisposable Support
    // Defaults to false; set once Dispose() has released the session.
    private bool _disposed;

    /// <summary>
    /// Disposes the underlying TensorFlow session. Calling this method more than once has no further effect.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
            return;
        Session.Dispose();
        _disposed = true;
    }
    #endregion
}
}