Skip to content

Commit 9d41814

Browse files
committed
Ngram with uint16 input fix
1 parent d6e64cb commit 9d41814

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

src/Microsoft.ML.Transforms/Text/NgramTransform.cs

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -768,7 +768,7 @@ public void SaveAsOnnx(OnnxContext ctx)
768768
}
769769
}
770770

771-
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName )
771+
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
772772
{
773773
VBuffer<ReadOnlyMemory<char>> slotNames = default;
774774
GetSlotNames(iinfo, 0, ref slotNames);
@@ -777,13 +777,15 @@ private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName,
777777

778778
// TfIdfVectorizer accepts strings, int32 and int64 tensors.
779779
// But in the ML.NET implementation of the NGramTransform, it only accepts keys as inputs
780-
// That are the result of ValueToKeyMapping transformer, which outputs UInt32 values
781-
// So, if it is UInt32 or UInt64, cast the output here to Int32/Int64
780+
// That are the result of ValueToKeyMapping transformer, which outputs UInt32 values,
781+
// Or TokenizingByCharacters, which outputs UInt16 values
782+
// So, if it is UInt32, UInt64, or UInt16, cast the output here to Int32/Int64
782783
string opType;
783784
var vectorType = _srcTypes[iinfo] as VectorDataViewType;
784785

785786
if ((vectorType != null) &&
786-
((vectorType.RawType == typeof(VBuffer<UInt32>)) || (vectorType.RawType == typeof(VBuffer<UInt64>))))
787+
((vectorType.RawType == typeof(VBuffer<UInt32>)) || (vectorType.RawType == typeof(VBuffer<UInt64>)) ||
788+
(vectorType.RawType == typeof(VBuffer<UInt16>))))
787789
{
788790
var dataKind = _srcTypes[iinfo] == NumberDataViewType.UInt32 ? DataKind.Int32 : DataKind.Int64;
789791

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1205,7 +1205,7 @@ public void WordTokenizerOnnxConversionTest()
12051205

12061206
[Theory]
12071207
[CombinatorialData]
1208-
public void NgramOnnxConnversionTest(
1208+
public void NgramOnnxConversionTest(
12091209
[CombinatorialValues(1, 2, 3)] int ngramLength,
12101210
bool useAllLength,
12111211
NgramExtractingEstimator.WeightingCriteria weighting)
@@ -1231,6 +1231,12 @@ public void NgramOnnxConnversionTest(
12311231
useAllLengths: useAllLength,
12321232
weighting: weighting)),
12331233

1234+
mlContext.Transforms.Text.TokenizeIntoCharactersAsKeys("Tokens", "Text")
1235+
.Append(mlContext.Transforms.Text.ProduceNgrams("NGrams", "Tokens",
1236+
ngramLength: ngramLength,
1237+
useAllLengths: useAllLength,
1238+
weighting: weighting)),
1239+
12341240
mlContext.Transforms.Text.ProduceWordBags("Tokens", "Text",
12351241
ngramLength: ngramLength,
12361242
useAllLengths: useAllLength,
@@ -1255,10 +1261,9 @@ public void NgramOnnxConnversionTest(
12551261
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxFilePath);
12561262
var onnxTransformer = onnxEstimator.Fit(dataView);
12571263
var onnxResult = onnxTransformer.Transform(dataView);
1258-
CompareSelectedR4VectorColumns(transformedData.Schema[3].Name, outputNames[outputNames.Length-1], transformedData, onnxResult, 3);
1264+
CompareSelectedR4VectorColumns(transformedData.Schema[transformedData.Schema.Count-1].Name, outputNames[outputNames.Length-1], transformedData, onnxResult, 3); //comparing Ngrams
12591265
}
12601266
}
1261-
12621267
Done();
12631268
}
12641269

0 commit comments

Comments
 (0)