-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathImageClassification.cs
More file actions
140 lines (121 loc) · 4.98 KB
/
ImageClassification.cs
File metadata and controls
140 lines (121 loc) · 4.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
using System;
using System.IO;
using System.Linq;
using System.Net;
using ICSharpCode.SharpZipLib.GZip;
using ICSharpCode.SharpZipLib.Tar;
using Microsoft.ML;
using Microsoft.ML.Data;
namespace Samples.Dynamic
{
    public static class ImageClassification
    {
        // Location of the ResNet v2 101 TensorFlow model archive.
        private const string ModelArchiveUrl =
            "https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz";

        // Local name the downloaded archive is saved under (in the current directory).
        private const string ModelArchiveFileName = "resnet_v2_101_299_frozen.tgz";

        // Frozen graph file expected inside the archive once it is extracted.
        private const string ModelFileName = "resnet_v2_101_299_frozen.pb";

        /// <summary>
        /// Example use of the TensorFlow image model in a ML.NET pipeline.
        /// </summary>
        /// <exception cref="FileNotFoundException">
        /// The model archive was downloaded and extracted but did not contain
        /// the expected frozen model file.
        /// </exception>
        public static void Example()
        {
            // The model is downloaded and extracted only when the .pb file is
            // not already present (relative to the current directory), so
            // subsequent runs reuse the local copy.
            string modelLocation = ModelFileName;
            if (!File.Exists(modelLocation))
            {
                string workingDirectory = Directory.GetCurrentDirectory();
                string archiveFile = Download(ModelArchiveUrl, ModelArchiveFileName);
                Unzip(Path.Join(workingDirectory, archiveFile), workingDirectory);

                // Fail with a clear message instead of an opaque TensorFlow
                // load error if the archive layout is not what we expect.
                if (!File.Exists(modelLocation))
                {
                    throw new FileNotFoundException(
                        $"The downloaded archive '{archiveFile}' did not contain '{ModelFileName}'.",
                        modelLocation);
                }
            }

            var mlContext = new MLContext();
            var data = GetTensorData();
            var idv = mlContext.Data.LoadFromEnumerable(data);

            // Create a ML pipeline.
            using var model = mlContext.Model.LoadTensorFlowModel(modelLocation);
            var pipeline = model.ScoreTensorFlowModel(
                new[] { nameof(OutputScores.output) },
                new[] { nameof(TensorData.input) }, addBatchDimensionInput: true);

            // Run the pipeline and get the transformed values.
            var estimator = pipeline.Fit(idv);
            var transformedValues = estimator.Transform(idv);

            // Retrieve model scores.
            var outScores = mlContext.Data.CreateEnumerable<OutputScores>(
                transformedValues, reuseRowObject: false);

            // Display scores. (for the sake of brevity we display scores of the
            // first 3 classes)
            foreach (var prediction in outScores)
            {
                int numClasses = 0;
                foreach (var classScore in prediction.output.Take(3))
                {
                    Console.WriteLine(
                        $"Class #{numClasses++} score = {classScore}");
                }
                Console.WriteLine(new string('-', 10));
            }

            // Results look like below...
            //Class #0 score = -0.8092947
            //Class #1 score = -0.3310375
            //Class #2 score = 0.1119193
            //----------
            //Class #0 score = -0.7807726
            //Class #1 score = -0.2158062
            //Class #2 score = 0.1153686
            //----------
        }

        // Shape of each synthetic input tensor (height x width x channels).
        // NOTE(review): the model file name mentions 299, while these
        // dimensions are 224x224 — confirm this is the intended input size.
        private const int ImageHeight = 224;
        private const int ImageWidth = 224;
        private const int NumChannels = 3;
        private const int InputSize = ImageHeight * ImageWidth * NumChannels;

        /// <summary>
        /// A class to hold sample tensor data.
        /// Member name should match the inputs that the model expects (in this
        /// case, input).
        /// </summary>
        public class TensorData
        {
            [VectorType(ImageHeight, ImageWidth, NumChannels)]
            public float[] input { get; set; }
        }

        /// <summary>
        /// Method to generate sample test data. Returns 2 sample rows.
        /// </summary>
        /// <returns>
        /// Two <see cref="TensorData"/> rows, each holding
        /// ImageHeight * ImageWidth * NumChannels synthetic float values.
        /// </returns>
        public static TensorData[] GetTensorData()
        {
            // This can be any numerical data. Assume image pixel values.
            var image1 = Enumerable.Range(0, InputSize).Select(
                x => (float)x / InputSize).ToArray();
            var image2 = Enumerable.Range(0, InputSize).Select(
                x => (float)(x + 10000) / InputSize).ToArray();
            return new TensorData[] { new TensorData() { input = image1 },
                new TensorData() { input = image2 } };
        }

        /// <summary>
        /// Class to contain the output values from the transformation.
        /// Member name should match the model output being scored (output).
        /// </summary>
        private class OutputScores
        {
            public float[] output { get; set; }
        }

        /// <summary>
        /// Downloads <paramref name="baseGitPath"/> to the local file
        /// <paramref name="dataFile"/>, overwriting any existing file.
        /// </summary>
        /// <param name="baseGitPath">Absolute URL to download.</param>
        /// <param name="dataFile">Destination file path.</param>
        /// <returns>The destination path, <paramref name="dataFile"/>.</returns>
        private static string Download(string baseGitPath, string dataFile)
        {
            using var client = new WebClient();
            client.DownloadFile(new Uri(baseGitPath), dataFile);
            return dataFile;
        }

        /// <summary>
        /// Extracts a .tar.gz archive into <paramref name="targetDir"/>.
        /// Taken from
        /// https://github.com/icsharpcode/SharpZipLib/wiki/GZip-and-Tar-Samples.
        /// </summary>
        /// <param name="path">Path of the .tgz archive to read.</param>
        /// <param name="targetDir">Directory to extract the contents into.</param>
        private static void Unzip(string path, string targetDir)
        {
            // using declarations guarantee the file handle is released even if
            // extraction throws; disposal runs in reverse order (archive,
            // gzip stream, file stream), matching the original Close() order.
            using Stream inStream = File.OpenRead(path);
            using Stream gzipStream = new GZipInputStream(inStream);
            using TarArchive tarArchive = TarArchive.CreateInputTarArchive(gzipStream);
            tarArchive.ExtractContents(targetDir);
        }
    }
}