
Commit 82d35de

Ryanunderhill/rel 1.1.0 (#2651)
* Add missing env variables for mac pipeline test (#2595)
* Java API for onnxruntime (#2215)
* Rename automl python tools folder to featurizer_ops (#2593)
* Make sure fenced tensors cannot reuse other tensors' buffers (#2561)
* Add support for opset 11 in reshape fusion (#2592)
* Support opset 11 subgraph of SQuAD model in Embed Layer Normalization (#2605)
* Allow providers to be set for InferenceSession at construction (#2606)
* EmbedLayerNormalization fusion for dynamic SQuAD model, opset 10 (#2613)
* Improve Embed Layer Norm fusion for SQuAD with static input shape (#2621)
* Improve CUDA Expand() operator's performance (#2624)
* Shortcut CUDA Pad() when no padding is needed (#2625)
* Improve performance of Resize() in Nearest mode (#2626)
* Optimize CUDA Scatter() for the 2D-compatible case (#2628)
* Fix float16 comparison in initializer (#2629)
* Add epsilon attribute in LayerNormalization fusion (#2639)
* Fix memory exception in Layer Norm fusion (#2644)
1 parent 6049de8 commit 82d35de

22 files changed: +1369 −372 lines

onnxruntime/core/framework/allocation_planner.cc

Lines changed: 1 addition & 0 deletions
@@ -325,6 +325,7 @@ class PlannerImpl {
     auto p_required_buffer_shape = context_.GetShape(output_arg);
     if (nullptr == p_required_buffer_shape) return false;
     auto& required_memory_info = AllocPlan(output_arg.Name()).location;
+    if (HasFence(&output_arg)) return false;

     for (auto it = freelist_.begin(); it != freelist_.end(); ++it) {
       size_t reusable = static_cast<size_t>(it->ml_value);
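
The one-line guard above is the substance of #2561: an output that carries a fence is still subject to cross-device synchronization, so the planner must not hand it a previously freed buffer from the freelist. A minimal sketch of that intent, with hypothetical names that are not the planner's real interface:

#include <cstddef>

// Hypothetical stand-in for the planner's per-value bookkeeping; the real
// check lives inside PlannerImpl and uses HasFence(&output_arg) as shown above.
struct PlannedValue {
  bool has_fence;         // value participates in fence-based synchronization
  size_t required_bytes;  // size the output needs
};

// An output guarded by a fence never takes over a freed buffer, regardless of fit.
bool MayReuseFreedBuffer(const PlannedValue& output, size_t free_buffer_bytes) {
  if (output.has_fence) return false;
  return free_buffer_bytes >= output.required_bytes;
}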

onnxruntime/core/optimizer/embed_layer_norm_fusion.cc

Lines changed: 483 additions & 124 deletions
(Large diff not rendered.)

onnxruntime/core/optimizer/layer_norm_fusion.cc

Lines changed: 9 additions & 1 deletion
@@ -153,7 +153,6 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
       continue;
     }
     nodes_to_remove.push_back(add2_node);
-
     // Traceback the add node to find reduceMean --> add
     const Node* p_reduce_mean2 = nullptr;

@@ -255,6 +254,15 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
                                           layer_norm_input_defs,
                                           {}, {}, kOnnxDomain);

+    // Get constant "epsilon" from "Add2" node if available. Else, default value will be used.
+    const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, add2_node.MutableInputDefs()[1]->Name());
+    if (tensor_proto != nullptr) {
+      if (tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+        auto initializer = onnxruntime::make_unique<Initializer>(*tensor_proto);
+        layer_norm_node.AddAttribute("epsilon", initializer->data<float>()[0]);
+      }
+    }
+
     // Assign provider to this new node. Provider should be same as the provider for old node.
     layer_norm_node.SetExecutionProviderType(reduce_mean_node.GetExecutionProviderType());
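
The new block copies a constant epsilon from the matched Add node onto the fused LayerNormalization node; when no constant initializer is found, no attribute is set and the kernel falls back to its default. A rough sketch of the consuming side, assuming ORT's OpKernelInfo::GetAttrOrDefault helper and a hypothetical 1e-5f default (neither value nor class is taken from this diff):

#include "core/framework/op_kernel.h"  // OpKernel, OpKernelInfo (ORT internal)

namespace onnxruntime {
// Sketch only: reads the fused node's "epsilon" attribute, or falls back to a
// default when the fusion pass could not attach one.
class LayerNormSketch final : public OpKernel {
 public:
  explicit LayerNormSketch(const OpKernelInfo& info) : OpKernel(info) {
    epsilon_ = info.GetAttrOrDefault<float>("epsilon", 1e-5f);
  }
  Status Compute(OpKernelContext* /*ctx*/) const override { return Status::OK(); }

 private:
  float epsilon_;
};
}  // namespace onnxruntime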

onnxruntime/core/optimizer/reshape_fusion.cc

Lines changed: 5 additions & 5 deletions
@@ -72,7 +72,7 @@ bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::L
   }
   const Node& concat = *p_concat;

-  if (!graph_utils::IsSupportedOptypeVersionAndDomain(concat, "Concat", {1, 4})) {
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(concat, "Concat", {1, 4, 11})) {
     return false;
   }

@@ -83,8 +83,8 @@ bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::L

   // path 1: [Root] --> Shape --> Gather(indices=0) --> Unsqueeze (axes=0) --> Concat [input 0]
   std::vector<graph_utils::EdgeEndToMatch> parent_path{
-      {0, 0, "Unsqueeze", {1}, kOnnxDomain},
-      {0, 0, "Gather", {1}, kOnnxDomain},
+      {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain},
+      {0, 0, "Gather", {1, 11}, kOnnxDomain},
       {0, 0, "Shape", {1}, kOnnxDomain}};

   std::vector<const Node::EdgeEnd*> edges;

@@ -114,8 +114,8 @@ bool ReshapeFusion::Fuse_Subgraph1(Node& reshape, Graph& graph, const logging::L

   // path 2: [Root] --> Shape --> Gather(indices=1) --> Unsqueeze (axes=0) --> Concat [input 1]
   std::vector<graph_utils::EdgeEndToMatch> parent_path2 {
-      {0, 1, "Unsqueeze", {1}, kOnnxDomain},
-      {0, 0, "Gather", {1}, kOnnxDomain},
+      {0, 1, "Unsqueeze", {1, 11}, kOnnxDomain},
+      {0, 0, "Gather", {1, 11}, kOnnxDomain},
       {0, 0, "Shape", {1}, kOnnxDomain}};

   if (!graph_utils::FindPath(concat, true, parent_path2, edges, logger)) {

onnxruntime/core/optimizer/utils.cc

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ bool IsInitializerWithExpectedValue(const Graph& graph, const NodeArg& input_arg
       }
     } else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
       const MLFloat16* val = init_const->data<MLFloat16>();
-      float diff = std::abs(math::halfToFloat(val[0].val) - static_cast<float>(expected_value));
+      float diff = std::abs(math::halfToFloat(val[0].val) - math::halfToFloat(math::floatToHalf(expected_value)));
      if (diff > FLT_EPSILON) {
        return false;
      }
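
The fix works because a float constant generally cannot be represented exactly in half precision, so comparing the stored half value against the raw float with an FLT_EPSILON tolerance spuriously fails; rounding the expected value through half first puts both sides on the same quantization grid. A minimal sketch of the corrected comparison, assuming the math::floatToHalf / math::halfToFloat helpers used in the diff (the include path is an assumption):

#include <cfloat>
#include <cmath>
#include <cstdint>

#include "core/util/math.h"  // math::floatToHalf, math::halfToFloat (ORT internal)

namespace onnxruntime {
// Compares a stored FLOAT16 initializer value against an expected float constant.
inline bool HalfMatchesExpected(uint16_t stored_half_bits, float expected_value) {
  const float stored = math::halfToFloat(stored_half_bits);
  // Quantize the expectation to half precision before comparing.
  const float expected_as_half = math::halfToFloat(math::floatToHalf(expected_value));
  return std::abs(stored - expected_as_half) <= FLT_EPSILON;
}
}  // namespace onnxruntime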

onnxruntime/core/providers/cuda/tensor/expand.cc

Lines changed: 80 additions & 44 deletions
@@ -5,69 +5,105 @@
 #include "expand_impl.h"
 #include "core/providers/cpu/tensor/utils.h"

+using std::vector;
+
 namespace onnxruntime {
 namespace cuda {

+// Logically expanded y could just be a view of x.
+static void CalcEffectiveDims(vector<int64_t>& x_dims, vector<int64_t>& y_dims) {
+  vector<int64_t> x_reverse;
+  vector<int64_t> y_reverse;
+
+  int xi = gsl::narrow_cast<int>(x_dims.size()) - 1;
+  for (int yi = gsl::narrow_cast<int>(y_dims.size()) - 1; yi >= 0; --yi, --xi) {
+    int64_t xdim = (xi >= 0) ? x_dims[xi] : 1;
+    int64_t ydim = y_dims[yi];
+    if (xdim == ydim || xdim == 1) {
+      x_reverse.push_back(xdim);
+      y_reverse.push_back(ydim);
+    }
+    else {  // xdim < ydim && xdim > 1, split
+      ydim /= xdim;
+      x_reverse.push_back(xdim);
+      y_reverse.push_back(xdim);
+      x_reverse.push_back(1);
+      y_reverse.push_back(ydim);
+    }
+  }
+
+  x_dims.clear();
+  y_dims.clear();
+  x_dims.push_back(1);
+  y_dims.push_back(1);
+  // compact the dims, remove (x=1, y=1), merge (x=1, y1*y2...)
+  for (int i = gsl::narrow_cast<int>(y_reverse.size()) - 1; i >= 0; --i) {
+    if (x_reverse[i] == 1) {
+      if (y_reverse[i] == 1) {
+        continue;
+      }
+      if (x_dims.back() == 1) {
+        y_dims.back() *= y_reverse[i];
+      }
+      else {
+        x_dims.push_back(1);
+        y_dims.push_back(y_reverse[i]);
+      }
+    }
+    else {  // x_reverse[i] == y_reverse[i]
+      if (x_dims.back() == y_dims.back()) {
+        x_dims.back() *= x_reverse[i];
+        y_dims.back() *= y_reverse[i];
+      }
+      else {
+        x_dims.push_back(x_reverse[i]);
+        y_dims.push_back(y_reverse[i]);
+      }
+    }
+  }
+}
+
 Status Expand::ComputeInternal(OpKernelContext* ctx) const {
-  const auto& input0 = *ctx->Input<Tensor>(0);
-  const auto& input1 = *ctx->Input<Tensor>(1);
+  const auto& input_data_tensor = *ctx->Input<Tensor>(0);
+  const auto& input_shape_tensor = *ctx->Input<Tensor>(1);

   // new shape to be expanded to
-  const auto* p_shape = input1.template Data<int64_t>();
-  std::vector<int64_t> output_dims{p_shape, p_shape + input1.Shape().Size()};
+  const auto* p_shape = input_shape_tensor.template Data<int64_t>();
+  std::vector<int64_t> output_dims{p_shape, p_shape + input_shape_tensor.Shape().Size()};
   TensorShape output_shape(output_dims);

-  ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input0.Shape(), output_dims, output_shape));
-  auto rank = output_shape.NumDimensions();
+  ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input_data_tensor.Shape(), output_dims, output_shape));
   auto& output_tensor = *ctx->Output(0, output_shape);
-
   if (0 == output_shape.Size()) {
     return Status::OK();
   }

-  auto input_shape = input0.Shape().GetDims();
+  output_dims = output_shape.GetDims();
+  auto input_dims = input_data_tensor.Shape().GetDims();

-  // pad input_dims with 1 to make ranks match
-  for (size_t i = 0; i < rank - input_shape.size(); i++) {
-    input_shape.insert(input_shape.begin(), 1);
-  }
+  CalcEffectiveDims(input_dims, output_dims);
+  int rank = gsl::narrow_cast<int>(output_dims.size());

-  // create fast_divmod using dimension values
-  CudaAsyncBuffer<fast_divmod> fdm_input_dims(this, rank);
-  CudaAsyncBuffer<fast_divmod> fdm_output_dims(this, rank);
-  CudaAsyncBuffer<fast_divmod> fdm_output_subdim_size(this, rank);
-  {
-    auto in_span = fdm_input_dims.CpuSpan();
-    auto out_span = fdm_output_dims.CpuSpan();
-    auto sdm_span = fdm_output_subdim_size.CpuSpan();
-    auto subdim_size = output_shape.Size();
-    for (size_t i = 0; i < rank; i++) {
-      in_span[i] = fast_divmod(static_cast<int>(input_shape[i]));
-      out_span[i] = fast_divmod(static_cast<int>(output_shape[i]));
-      // output_shape[i] won't be 0 here, it's covered in (0 == output_shape.Size())
-      // a null output will be returned for that case
-      subdim_size /= output_shape[i];
-      sdm_span[i] = static_cast<int>(subdim_size);
-    }
+  CudaAsyncBuffer<fast_divmod> fdm_output_strides(this, rank);
+  ORT_ENFORCE(CalculateFdmStrides(fdm_output_strides.CpuSpan(), output_dims));
+
+  CudaAsyncBuffer<int64_t> input_view_strides(this, rank);
+  TensorPitches::Calculate(input_view_strides.CpuSpan(), input_dims);
+  for (int i = 0; i < rank; ++i) {
+    if (input_dims[i] == 1) input_view_strides.CpuSpan()[i] = 0;
   }
-  ORT_RETURN_IF_ERROR(fdm_input_dims.CopyToGpu());
-  ORT_RETURN_IF_ERROR(fdm_output_dims.CopyToGpu());
-  ORT_RETURN_IF_ERROR(fdm_output_subdim_size.CopyToGpu());
-
-  ExpandImpl(
-      input0.DataType()->Size(),
-      output_shape.NumDimensions(),
-      output_shape.Size(),
-      input0.Shape().Size(),
-      input0.DataRaw(),
-      output_tensor.MutableDataRaw(),
-      fdm_input_dims.GpuPtr(),
-      fdm_output_dims.GpuPtr(),
-      fdm_output_subdim_size.GpuPtr());

-  return Status::OK();
+  return ExpandImpl(
+      input_data_tensor.DataType()->Size(),
+      gsl::narrow_cast<int>(output_shape.Size()),
+      gsl::narrow_cast<int>(input_data_tensor.Shape().Size()),
+      input_data_tensor.DataRaw(),
+      output_tensor.MutableDataRaw(),
+      fdm_output_strides,
+      input_view_strides);
 }

+
 ONNX_OPERATOR_KERNEL_EX(
     Expand,
     kOnnxDomain,
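
The rewrite treats Expand as a strided view: after CalcEffectiveDims collapses compatible dimensions, every broadcast (size-1) input dimension simply gets stride 0, so each output element maps straight back to an input element without per-dimension copies. A CPU-side sketch of that indexing idea (illustrative only; not the CUDA kernel, and the function name is made up):

#include <cstddef>
#include <cstdint>
#include <vector>

// Expands `input` of shape `input_dims` to `output_dims` (ranks assumed
// already aligned, with broadcast dims of size 1). Broadcast dims get stride
// 0, mirroring the input_view_strides computed in the kernel above.
std::vector<float> ExpandAsView(const std::vector<float>& input,
                                const std::vector<int64_t>& input_dims,
                                const std::vector<int64_t>& output_dims) {
  const size_t rank = output_dims.size();

  // Row-major strides of the input, with broadcast (size-1) dims zeroed.
  std::vector<int64_t> in_strides(rank, 0);
  int64_t stride = 1;
  for (size_t i = rank; i-- > 0;) {
    in_strides[i] = (input_dims[i] == 1) ? 0 : stride;
    stride *= input_dims[i];
  }

  int64_t out_size = 1;
  for (int64_t d : output_dims) out_size *= d;

  std::vector<float> output(static_cast<size_t>(out_size));
  for (int64_t linear = 0; linear < out_size; ++linear) {
    // Decompose the flat output index dimension by dimension (the CUDA kernel
    // does this with fast_divmod) and accumulate the matching input offset.
    int64_t remaining = linear;
    int64_t in_offset = 0;
    for (size_t i = 0; i < rank; ++i) {
      int64_t inner = 1;
      for (size_t j = i + 1; j < rank; ++j) inner *= output_dims[j];
      in_offset += (remaining / inner) * in_strides[i];
      remaining %= inner;
    }
    output[static_cast<size_t>(linear)] = input[static_cast<size_t>(in_offset)];
  }
  return output;
}

For example, expanding input_dims {1, 3, 1} to output_dims {2, 3, 4} reads each of the three input values eight times, with no intermediate buffers.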
