-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathCompositeRowToRowMapper.cs
More file actions
138 lines (117 loc) · 5.92 KB
/
CompositeRowToRowMapper.cs
File metadata and controls
138 lines (117 loc) · 5.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
namespace Microsoft.ML.Data
{
/// <summary>
/// A row-to-row mapper that is the result of a chained application of multiple mappers.
/// The mappers are applied in array order: the first inner mapper consumes the input row,
/// and the output schema of the last inner mapper is the output schema of the composite.
/// </summary>
[BestFriend]
internal sealed class CompositeRowToRowMapper : IRowToRowMapper, IDisposable
{
    /// <summary>
    /// The wrapped mappers, in application order. Never <c>null</c>; empty when the composite is a no-op.
    /// </summary>
    [BestFriend]
    internal IRowToRowMapper[] InnerMappers { get; }

    /// <summary>
    /// The schema that rows passed to <see cref="IRowToRowMapper.GetRow"/> must have.
    /// </summary>
    public DataViewSchema InputSchema { get; }

    /// <summary>
    /// The output schema of the last inner mapper, or <see cref="InputSchema"/> when there are no inner mappers.
    /// </summary>
    public DataViewSchema OutputSchema { get; }

    /// <summary>
    /// Out of a series of mappers, construct a seemingly unitary mapper that is able to apply them in sequence.
    /// </summary>
    /// <param name="inputSchema">The input schema.</param>
    /// <param name="mappers">The sequence of mappers to wrap. An empty or <c>null</c> argument
    /// is legal, and counts as being a no-op application.</param>
    public CompositeRowToRowMapper(DataViewSchema inputSchema, IRowToRowMapper[] mappers)
    {
        Contracts.CheckValue(inputSchema, nameof(inputSchema));
        Contracts.CheckValueOrNull(mappers);

        // Normalize null/empty to a shared empty array so the rest of the class never has to null-check.
        InnerMappers = Utils.Size(mappers) > 0 ? mappers : Array.Empty<IRowToRowMapper>();
        InputSchema = inputSchema;
        OutputSchema = InnerMappers.Length > 0
            ? InnerMappers[InnerMappers.Length - 1].OutputSchema
            : inputSchema;
    }

    /// <summary>
    /// Given a set of columns, return the input columns that are needed to generate those output columns.
    /// </summary>
    /// <param name="columnsNeeded">The columns of the output that are required.</param>
    /// <returns>The result of propagating <paramref name="columnsNeeded"/> backwards through every
    /// inner mapper, from the last one to the first. With no inner mappers the argument is returned unchanged.</returns>
    IEnumerable<DataViewSchema.Column> IRowToRowMapper.GetDependencies(IEnumerable<DataViewSchema.Column> columnsNeeded)
    {
        // Walk the chain back to front: what the last mapper needs is what the one before it must produce.
        for (int i = InnerMappers.Length - 1; i >= 0; --i)
            columnsNeeded = InnerMappers[i].GetDependencies(columnsNeeded);
        return columnsNeeded;
    }

    /// <summary>
    /// Produces an output row by feeding <paramref name="input"/> through each inner mapper in order.
    /// </summary>
    /// <param name="input">The input row. Its schema must be the same instance as <see cref="InputSchema"/>.</param>
    /// <param name="activeColumns">The output columns the caller wants to be active.</param>
    /// <returns>The row returned by the last inner mapper, or <paramref name="input"/> itself
    /// when there are no inner mappers.</returns>
    /// <remarks>
    /// With no inner mappers, an exception is raised if a requested column is not active in
    /// <paramref name="input"/>; columns that are active but not requested are left as they are.
    /// </remarks>
    DataViewRow IRowToRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
    {
        Contracts.CheckValue(input, nameof(input));
        Contracts.CheckValue(activeColumns, nameof(activeColumns));
        Contracts.CheckParam(input.Schema == InputSchema, nameof(input), "Schema did not match original schema");

        // A set gives constant-time membership tests; it is probed once per schema column below.
        var activeIndices = new HashSet<int>(activeColumns.Select(c => c.Index));

        if (InnerMappers.Length == 0)
        {
            // No-op composite: the input row is handed back as-is, so every requested column
            // must already be active on it.
            for (int c = 0; c < input.Schema.Count; ++c)
            {
                if (activeIndices.Contains(c) && !input.IsColumnActive(input.Schema[c]))
                    throw Contracts.ExceptParam(nameof(input), $"Mapper required column '{input.Schema[c].Name}' active but it was not.");
            }
            return input;
        }

        // For each of the inner mappers, we will be calling their GetRow method, but to do so we need to know
        // what we need from them. The last one is asked for the caller's active columns; each earlier one is
        // asked for the dependencies of the next one in the chain.
        var deps = new IEnumerable<DataViewSchema.Column>[InnerMappers.Length];
        deps[deps.Length - 1] = OutputSchema.Where(c => activeIndices.Contains(c.Index));
        for (int i = deps.Length - 1; i >= 1; --i)
            deps[i - 1] = InnerMappers[i].GetDependencies(deps[i]);

        DataViewRow result = input;
        for (int i = 0; i < InnerMappers.Length; ++i)
            result = InnerMappers[i].GetRow(result, deps[i]);
        return result;
    }

    /// <summary>
    /// A row wrapper that forwards everything to an underlying row, except that column activity
    /// is decided by a supplied predicate over column indices.
    /// </summary>
    private sealed class SubsetActive : DataViewRow
    {
        private readonly DataViewRow _row;
        private readonly Func<int, bool> _pred;

        /// <summary>
        /// Wraps <paramref name="row"/>, reporting a column as active when <paramref name="pred"/>
        /// returns <c>true</c> for its index.
        /// </summary>
        public SubsetActive(DataViewRow row, Func<int, bool> pred)
        {
            Contracts.AssertValue(row);
            Contracts.AssertValue(pred);
            _row = row;
            _pred = pred;
        }

        public override DataViewSchema Schema => _row.Schema;

        public override long Position => _row.Position;

        public override long Batch => _row.Batch;

        /// <summary>
        /// Returns the value getter delegate of the wrapped row for the given column.
        /// The request is forwarded without consulting the activity predicate.
        /// </summary>
        /// <typeparam name="TValue"> is the column's content type.</typeparam>
        /// <param name="column"> is the output column whose getter should be returned.</param>
        public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column) => _row.GetGetter<TValue>(column);

        public override ValueGetter<DataViewRowId> GetIdGetter() => _row.GetIdGetter();

        /// <summary>
        /// Returns whether the given column is active in this row, as decided by the predicate
        /// supplied at construction.
        /// </summary>
        public override bool IsColumnActive(DataViewSchema.Column column) => _pred(column.Index);
    }

    #region IDisposable Support
    private bool _disposed;

    /// <summary>
    /// Disposes every inner mapper that implements <see cref="IDisposable"/>.
    /// Calls after the first completed one are no-ops.
    /// </summary>
    void IDisposable.Dispose()
    {
        if (_disposed)
            return;

        foreach (var mapper in InnerMappers)
            (mapper as IDisposable)?.Dispose();

        // Set only after all inner mappers were disposed, so a failed attempt can be retried.
        _disposed = true;
    }
    #endregion
}
}