diff --git a/.gitmodules b/.gitmodules index fd892b2a58718..accbde3a0c93d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -40,12 +40,12 @@ [submodule "cmake/external/cub"] path = cmake/external/cub url = https://github.com/NVlabs/cub.git -[submodule "cmake/external/onnx-tensorrt"] - path = cmake/external/onnx-tensorrt - url = https://github.com/onnx/onnx-tensorrt.git [submodule "cmake/external/wil"] path = cmake/external/wil url = https://github.com/microsoft/wil +[submodule "cmake/external/onnx-tensorrt"] + path = cmake/external/onnx-tensorrt + url = https://github.com/onnx/onnx-tensorrt.git [submodule "cmake/external/json"] path = cmake/external/json url = https://github.com/nlohmann/json diff --git a/BUILD.md b/BUILD.md index 5bef7adbfbdac..9a360c3104010 100644 --- a/BUILD.md +++ b/BUILD.md @@ -189,7 +189,7 @@ See more information on the TensorRT Execution Provider [here](./docs/execution_ * The path to the CUDA `bin` directory must be added to the PATH environment variable so that `nvcc` is found. * The path to the cuDNN installation (path to folder that contains libcudnn.so) must be provided via the cuDNN_PATH environment variable, or `--cudnn_home parameter`. * Install [TensorRT](https://developer.nvidia.com/nvidia-tensorrt-download) - * The TensorRT execution provider for ONNX Runtime is built and tested with TensorRT 6.0.1.5 but validated with the feature set equivalent to TensorRT 5. Some TensorRT 6 new features such as dynamic shape is not available at this time. + * The TensorRT execution provider for ONNX Runtime is built and tested with TensorRT 6.0.1.5. * The path to TensorRT installation must be provided via the `--tensorrt_home parameter`. #### Build Instructions diff --git a/cmake/external/onnx-tensorrt b/cmake/external/onnx-tensorrt index bba4dee184cc0..8716c9b32dcc9 160000 --- a/cmake/external/onnx-tensorrt +++ b/cmake/external/onnx-tensorrt @@ -1 +1 @@ -Subproject commit bba4dee184cc03d6fd5086c90d974537e72eba23 +Subproject commit 8716c9b32dcc947287f2ede9ef7d563601bb2ee0 diff --git a/docs/execution_providers/TensorRT-ExecutionProvider.md b/docs/execution_providers/TensorRT-ExecutionProvider.md index 2ff12ad75764b..d516188d6c8c0 100644 --- a/docs/execution_providers/TensorRT-ExecutionProvider.md +++ b/docs/execution_providers/TensorRT-ExecutionProvider.md @@ -7,11 +7,11 @@ With the TensorRT execution provider, the ONNX Runtime delivers better inferenci ## Build For build instructions, please see the [BUILD page](../../BUILD.md#tensorrt). -The TensorRT execution provider for ONNX Runtime is built and tested with TensorRT 6.0.1.5 but validated with the feature set equivalent to TensorRT 5. Some TensorRT 6 new features such as dynamic shape is not available as this time. +The TensorRT execution provider for ONNX Runtime is built and tested with TensorRT 6.0.1.5. ## Using the TensorRT execution provider ### C/C++ -The TensortRT execution provider needs to be registered with ONNX Runtime to enable in the inference session. +The TensorRT execution provider needs to be registered with ONNX Runtime to enable in the inference session. ``` InferenceSession session_object{so}; session_object.RegisterExecutionProvider(std::make_unique<::onnxruntime::TensorrtExecutionProvider>()); @@ -19,8 +19,23 @@ status = session_object.Load(model_file_name); ``` The C API details are [here](../C_API.md#c-api). 
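For readers following the C/C++ registration snippet above, a fuller sketch of how that registration fits into a session's lifetime is shown below. This is illustrative only: it assumes the internal onnxruntime headers are on the include path, the `RunModelWithTensorrt` name and the empty feeds/output names are placeholders, and `Initialize()`/`Run()` are the usual `InferenceSession` calls rather than anything added by this change.
```cpp
#include "core/session/inference_session.h"
#include "core/providers/tensorrt/tensorrt_execution_provider.h"

using namespace onnxruntime;

common::Status RunModelWithTensorrt(const std::string& model_file_name) {
  SessionOptions so;
  InferenceSession session_object{so};

  // Register the TensorRT execution provider before loading the model,
  // exactly as in the snippet above.
  ORT_RETURN_IF_ERROR(session_object.RegisterExecutionProvider(
      std::make_unique<::onnxruntime::TensorrtExecutionProvider>()));

  ORT_RETURN_IF_ERROR(session_object.Load(model_file_name));  // parse the ONNX model
  ORT_RETURN_IF_ERROR(session_object.Initialize());           // partition the graph and build TRT engines

  NameMLValMap feeds;                      // placeholder: fill with the model's input OrtValues
  std::vector<std::string> output_names;   // placeholder: names of the outputs to fetch
  std::vector<OrtValue> fetches;
  return session_object.Run(RunOptions{}, feeds, output_names, &fetches);
}
```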
+#### Sample +To run the Faster R-CNN model with the TensorRT execution provider: + +First, download the Faster R-CNN ONNX model from the ONNX model zoo [here](https://github.com/onnx/models/tree/master/vision/object_detection_segmentation/faster-rcnn). + +Second, infer shapes in the model by running the symbolic shape inference script [here](https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/nuphar/scripts/symbolic_shape_infer.py): +``` +python symbolic_shape_infer.py --input /path/to/onnx/model/model.onnx --output /path/to/onnx/model/new_model.onnx --auto_merge +``` + +Third, replace the original model with the new model and run the onnx_test_runner tool from the ONNX Runtime build directory: +``` +./onnx_test_runner -e tensorrt /path/to/onnx/model/ +``` + ### Python -When using the Python wheel from the ONNX Runtime build with TensorRT execution provider, it will be automatically prioritized over the default GPU or CPU execution providers. There is no need to separately register the execution provider. Python APIs details are [here](https://microsoft.github.io/onnxruntime/api_summary.html). +When using the Python wheel from the ONNX Runtime build with the TensorRT execution provider, it will be automatically prioritized over the default GPU or CPU execution providers. There is no need to separately register the execution provider. Python API details are [here](https://microsoft.github.io/onnxruntime/api_summary.html). #### Sample Please see [this Notebook](../python/notebooks/onnx-inference-byoc-gpu-cpu-aks.ipynb) for an example of running a model on GPU using ONNX Runtime through Azure Machine Learning Services. @@ -30,14 +45,25 @@ For performance tuning, please see guidance on this page: [ONNX Runtime Perf Tun When/if using [onnxruntime_perf_test](../../onnxruntime/test/perftest#onnxruntime-performance-test), use the flag `-e tensorrt` -## Configuring Engine Max Batch Size and Workspace Size -By default TensorRT execution provider builds an ICudaEngine with max batch size = 1 and max workspace size = 1 GB -One can override these defaults by setting environment variables ORT_TENSORRT_MAX_BATCH_SIZE and ORT_TENSORRT_MAX_WORKSPACE_SIZE. -e.g. on Linux +## Configuring environment variables +There are three environment variables for the TensorRT execution provider. + +ORT_TENSORRT_MAX_WORKSPACE_SIZE: maximum workspace size for the TensorRT engine. + +ORT_TENSORRT_MAX_PARTITION_ITERATIONS: maximum number of iterations allowed in model partitioning for TensorRT. If the target model cannot be successfully partitioned when the maximum number of iterations is reached, the whole model falls back to other execution providers such as CUDA or CPU. -### override default batch size to 10 -export ORT_TENSORRT_MAX_BATCH_SIZE=10 +ORT_TENSORRT_MIN_SUBGRAPH_SIZE: minimum number of nodes in a subgraph after partitioning. Subgraphs smaller than this size fall back to other execution providers. + +By default, the TensorRT execution provider builds an ICudaEngine with max workspace size = 1 GB, max partition iterations = 1000, and min subgraph size = 1. + +One can override these defaults by setting the environment variables ORT_TENSORRT_MAX_WORKSPACE_SIZE, ORT_TENSORRT_MAX_PARTITION_ITERATIONS, and ORT_TENSORRT_MIN_SUBGRAPH_SIZE. +e.g.
on Linux ### override default max workspace size to 2GB export ORT_TENSORRT_MAX_WORKSPACE_SIZE=2147483648 +### override default maximum number of iterations to 10 +export ORT_TENSORRT_MAX_PARTITION_ITERATIONS=10 + +### override default minimum subgraph node size to 5 +export ORT_TENSORRT_MIN_SUBGRAPH_SIZE=5 diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 9e5eaaed02d44..028a34e09d2e0 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -89,9 +89,21 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv allocator_ = CreateAllocator(default_memory_info, device_id_); InsertAllocator(allocator_); - DeviceAllocatorRegistrationInfo pinned_memory_info( + DeviceAllocatorRegistrationInfo pinned_allocator_info( {OrtMemTypeCPUOutput, [](int) { return onnxruntime::make_unique(0, TRT_PINNED); }, std::numeric_limits::max()}); - InsertAllocator(CreateAllocator(pinned_memory_info, device_id_)); + InsertAllocator(CreateAllocator(pinned_allocator_info, device_id_)); + + const char* batch_env = getenv("ORT_TENSORRT_MAX_PARTITION_ITERATIONS"); + if (batch_env) + max_partition_iterations_ = atoi(batch_env); + + const char* subgraph_size_env = getenv("ORT_TENSORRT_MIN_SUBGRAPH_SIZE"); + if (subgraph_size_env) + min_subgraph_size_ = atoi(subgraph_size_env); + + const char* workspace_env = getenv("ORT_TENSORRT_MAX_WORKSPACE_SIZE"); + if (workspace_env) + max_workspace_size_ = atoi(workspace_env); } TensorrtExecutionProvider::~TensorrtExecutionProvider() {} @@ -108,6 +120,33 @@ std::unique_ptr TensorrtExecutionProvider::GetDataTr return onnxruntime::make_unique(); } +void ToGraphProtoInternal(const onnxruntime::GraphViewer& graph, ONNX_NAMESPACE::GraphProto& graph_proto) { //const + for (const auto* input_arg : graph.GetInputs()) { + *(graph_proto.mutable_input()->Add()) = input_arg->ToProto(); + } + + // Add all graph's initializers to the subgraph + const auto& init_tensors = graph.GetAllInitializedTensors(); + for (const auto& tensor : init_tensors) { + *(graph_proto.mutable_initializer()->Add()) = *(tensor.second); + } + + for (const auto* output_arg : graph.GetOutputs()) { + *(graph_proto.mutable_output()->Add()) = output_arg->ToProto(); + } + + for (const auto* value_info : graph.GetValueInfo()) { + *(graph_proto.mutable_value_info()->Add()) = value_info->ToProto(); + } + + // Nodes must be sorted in Topological Order in the GraphProto per ONNX spec. 
+ for (auto& node_idx : graph.GetNodesInTopologicalOrder()) { + const gsl::not_null node_proto{graph_proto.add_node()}; + const gsl::not_null p_node{graph.GetNode(node_idx)}; + p_node->ToProto(*node_proto); + } +} + std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph_t graph_nodes_index, int& kernels_index, const onnxruntime::GraphViewer& graph) const { const std::vector& node_index = graph.GetNodesInTopologicalOrder(); std::unordered_set node_set; @@ -115,10 +154,16 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph for (const auto& index : graph_nodes_index.first) { node_set.insert(node_index[index]); } - std::unique_ptr sub_graph = onnxruntime::make_unique(); + + // Get parent graph output names + std::unordered_set graph_output_names; + for (const auto* output_arg : graph.GetOutputs()) { + graph_output_names.insert(output_arg->Name()); + } // Find inputs and outputs of the subgraph - std::unordered_map fused_inputs, fused_outputs, fused_outputs_to_add; + std::unique_ptr sub_graph = onnxruntime::make_unique(); + std::unordered_map fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; std::unordered_set erased; int input_order = 0; int output_order = 0; @@ -132,15 +177,17 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph fused_outputs.erase(it); erased.insert(input); } else if (erased.find(input) == erased.end()) { - //only when input is neither in output list nor erased list, add the input to input list + // Only when input is neither in output list nor erased list, add the input to input list fused_inputs[input] = input_order++; } } - // For output searching, there is a special case: - // If node's OutputEdges are more than its outputs, meaning certain output is used more than once, + // For output searching, there is two special cases, + // One is, if node's OutputEdges are more than its outputs, meaning certain output is used more than once, // if the output is connected to nodes that don't belong to the subgraph, the output need to be added // to the output list + // The other one is, if subgraph's node output is parent graph's output. 
the node output should + // be also added to the subgraph's output list if (node->GetOutputEdgesCount() > node->OutputDefs().size()) { for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { const auto& node_idx = it->GetNode().Index(); @@ -151,6 +198,9 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph fused_inputs.erase(iter); erased.insert(output); } else if (erased.find(output) == erased.end()) { + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } fused_outputs[output] = output_order++; } } else { @@ -164,8 +214,11 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph fused_inputs.erase(it); erased.insert(output); } - // only when output is neither in input list nor erased list, add the output to output list + // Only when output is neither in input list nor erased list, add the output to output list else if (erased.find(output) == erased.end()) { + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } fused_outputs[output] = output_order++; } } @@ -173,9 +226,10 @@ std::unique_ptr TensorrtExecutionProvider::GetSubGraph(SubGraph } fused_outputs.insert(fused_outputs_to_add.begin(), fused_outputs_to_add.end()); + fused_outputs.insert(graph_outputs_to_add.begin(), graph_outputs_to_add.end()); // Sort inputs and outputs by the order they were added - std::multimap inputs, outputs; + std::multimap inputs, outputs; for (auto it = fused_inputs.begin(), end = fused_inputs.end(); it != end; ++it) { inputs.insert(std::pair(it->second, it->first)); } @@ -212,24 +266,30 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect return nodes_list_output; } + // Get parent graph output names + std::unordered_set graph_output_names; + for (const auto* output_arg : graph.GetOutputs()) { + graph_output_names.insert(output_arg->Name()); + } + iterations++; const std::vector& node_index = graph.GetNodesInTopologicalOrder(); - int counter = 0; for (const auto& group : nodes_vector_input) { - //construct subgraph + // Construct subgraph if (!group.first.empty()) { - std::unique_ptr sub_graph = GetSubGraph(group, counter, graph); - if (group.second) { nodes_list_output.push_back(group); } else { onnxruntime::Model model_build(graph.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap(), std::vector(), *GetLogger()); onnxruntime::Graph& graph_build = model_build.MainGraph(); - //Add node and node args + // Add node and node args + // If node output is also parent graph output, the output will be added to the + // subgraph's output list + std::vector subgraph_output_names; for (const auto& index : group.first) { const auto& node = graph.GetNode(node_index[index]); - std::vector inputs, outputs; + std::vector inputs, outputs; for (auto input : node->InputDefs()) { auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); inputs.push_back(&n_input); @@ -237,21 +297,46 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect for (auto output : node->OutputDefs()) { auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); outputs.push_back(&n_output); + const auto name = output->Name(); + if (graph_output_names.find(name) != graph_output_names.end()) { + subgraph_output_names.push_back(name); + } } graph_build.AddNode(node->Name(), node->OpType(), 
node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain()); } ORT_ENFORCE(graph_build.Resolve().IsOK()); - for (const auto& input : sub_graph->GetMetaDef()->inputs) { - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph.GetInitializedTensor(input, initializer)) { - graph_build.AddInitializedTensor(*initializer); - } + // Add parent graph output to the subgraph + int i = 0; + std::vector subgraph_outputs; + subgraph_outputs.resize(subgraph_output_names.size()); + for (auto& name : subgraph_output_names) { + auto output_arg = graph.GetNodeArg(name); + auto& subgraph_output_arg = graph_build.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + subgraph_outputs[i] = &subgraph_output_arg; + ++i; + } + auto& graph_build_outputs = graph_build.GetOutputs(); + subgraph_outputs.insert(subgraph_outputs.begin(), graph_build_outputs.begin(), graph_build_outputs.end()); + graph_build.SetOutputs(graph_build_outputs); + + // Add initializers to the subgraph + const auto& init_tensors = graph.GetAllInitializedTensors(); + for (const auto& tensor : init_tensors) { + graph_build.AddInitializedTensor(*(tensor.second)); } + ORT_ENFORCE(graph_build.Resolve().IsOK()); + // Serialize modelproto to string - ONNX_NAMESPACE::ModelProto model_proto = model_build.ToProto(); + const onnxruntime::GraphViewer graph_viewer(graph_build); + + onnxruntime::Model model(graph_viewer.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph_viewer.DomainToVersionMap(), std::vector(), *GetLogger()); + ONNX_NAMESPACE::ModelProto model_proto = model.ToProto(); + ToGraphProtoInternal(graph_viewer, *(model_proto.mutable_graph())); + model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + std::string string_buf; model_proto.SerializeToString(&string_buf); @@ -266,7 +351,6 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list); SubGraphCollection_t next_nodes_list; - const onnxruntime::GraphViewer graph_viewer(graph_build); const std::vector& subgraph_node_index = graph_viewer.GetNodesInTopologicalOrder(); next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, graph_viewer, early_termination); for (int i = 0, end = next_nodes_list.size(); i < end; ++i) { @@ -287,27 +371,30 @@ TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // Construct modelproto from graph onnxruntime::Model model(graph.Name(), true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), graph.DomainToVersionMap(), std::vector(), *GetLogger()); onnxruntime::Graph& graph_build = model.MainGraph(); - for (const auto& node : graph.Nodes()) { - std::vector inputs, outputs; - for (auto input : node.InputDefs()) { + const std::vector& node_index = graph.GetNodesInTopologicalOrder(); + + for (const auto& index : node_index) { + const auto& node = graph.GetNode(index); + std::vector inputs, outputs; + for (auto input : node->InputDefs()) { auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); inputs.push_back(&n_input); } - for (auto output : node.OutputDefs()) { + for (auto output : node->OutputDefs()) { auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); outputs.push_back(&n_output); } - graph_build.AddNode(node.Name(), node.OpType(), node.Description(), inputs, outputs, &node.GetAttributes(), node.Domain()); + graph_build.AddNode(node->Name(), node->OpType(), 
node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain()); } + graph_build.SetOutputs(graph.GetOutputs()); - auto status = graph_build.Resolve(); - - //Add initializer to graph + // Add initializer to graph const auto& init_tensors = graph.GetAllInitializedTensors(); for (const auto& tensor : init_tensors) { graph_build.AddInitializedTensor(*(tensor.second)); } + auto status = graph_build.Resolve(); ORT_ENFORCE(status.IsOK(), status); ONNX_NAMESPACE::ModelProto model_proto = model.ToProto(); model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); @@ -317,27 +404,64 @@ TensorrtExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, model_proto.SerializeToString(&string_buf); // Get supported node list - SubGraphCollection_t parser_nodes_vector; - TensorrtLogger& trt_logger = GetTensorrtLogger(); - auto trt_builder = unique_pointer(nvinfer1::createInferBuilder(trt_logger)); - const auto explicitBatch = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); - auto trt_network = unique_pointer(trt_builder->createNetworkV2(explicitBatch)); - auto trt_parser = unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); - trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_vector); - + std::vector nodes_vector(node_index.size()); + std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); + SubGraphCollection_t parser_nodes_vector = {{nodes_vector, false}}; SubGraphCollection_t supported_nodes_vector; - const char* batch_env = getenv("ORT_TENSORRT_MAX_PARSER_ITERATIONS"); - const int max_iterations = batch_env ? atoi(batch_env) : max_parser_iterations_; bool early_termination = false; - supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_iterations, graph, &early_termination); + supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); if (early_termination) { supported_nodes_vector.clear(); } + // Remove nodes with empty shape (for example [1, 0]) because TensorRT 6 doens't support empty shape + SubGraphCollection_t post_processed_supported_nodes_vector; + for (auto& group : supported_nodes_vector) { + // Right now only NonZero and NonMaxSuppression related empty shape nodes are removed. 
+ // The typical cases are Faster-rcnn and Mask-rcnn + // TODO: Remove the code if TensorRT fixed the issue in future release, or find a better generic way here to work around + post_processed_supported_nodes_vector.push_back({}); + for (const auto& index : group.first) { + const auto& node = graph.GetNode(node_index[index]); + bool exclude_node = false; + for (auto input : node->InputDefs()) { + auto input_shape = input->Shape(); + if (input_shape) { + for (auto dim : input_shape->dim()) { + std::string dim_name = dim.dim_param(); + std::string exclude_dim_name1 = "NonZero"; + std::string exclude_dim_name2 = "NonMaxSuppression"; + if (!dim_name.empty()) { + if ((dim_name.find(exclude_dim_name1) != std::string::npos) || (dim_name.find(exclude_dim_name2) != std::string::npos)) { + exclude_node = true; + break; + } + } + } + } + + if (exclude_node) { + break; + } + } + if (!exclude_node) { + post_processed_supported_nodes_vector.back().first.push_back(index); + } + } + + // Remove subgraph if it is empty or its size is smaller than the predefined minimal subgraph size + const int subgraph_size = post_processed_supported_nodes_vector.back().first.size(); + if (subgraph_size == 0 || subgraph_size < min_subgraph_size_) { + post_processed_supported_nodes_vector.pop_back(); + } else { + post_processed_supported_nodes_vector.back().second = group.second; + } + } + // Construct subgraph capability from node list std::vector> result; int counter = 0; - for (const auto& group : supported_nodes_vector) { + for (const auto& group : post_processed_supported_nodes_vector) { if (!group.first.empty()) { std::unique_ptr sub_graph = GetSubGraph(group, counter, graph); result.push_back(onnxruntime::make_unique(std::move(sub_graph))); @@ -351,9 +475,8 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& node_compute_funcs) { for (const auto* fused_node : fused_nodes) { std::vector input_indexes; - std::vector input_dim_sizes; std::vector output_indexes; - std::vector output_dim_sizes; + std::unordered_map>> input_shape_ranges; std::vector> output_shapes; std::vector output_types; @@ -391,50 +514,49 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(nvinfer1::createInferBuilder(trt_logger)); const auto explicitBatch = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); auto trt_network = unique_pointer(trt_builder->createNetworkV2(explicitBatch)); - + auto trt_config = unique_pointer(trt_builder->createBuilderConfig()); auto trt_parser = unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); trt_parser->parse(string_buf.data(), string_buf.size()); - - const char* batch_env = getenv("ORT_TENSORRT_MAX_BATCH_SIZE"); - if (batch_env) { - const int max_batch_size = atoi(batch_env); - SetMaxBatchSize(max_batch_size); - } - - const char* workspace_env = getenv("ORT_TENSORRT_MAX_WORKSPACE_SIZE"); - if (workspace_env) { - const size_t max_workspace_size = atoi(workspace_env); - SetMaxWorkspaceSize(max_workspace_size); - } - - trt_builder->setMaxBatchSize(max_batch_size_); - auto trt_config = unique_pointer(trt_builder->createBuilderConfig()); trt_config->setMaxWorkspaceSize(max_workspace_size_); - //Set optimization profile for dynamic shapes - //Only support dynamic batch size on the first dimension for now - //TODO: add full dynamic shape support + // Set optimization profile for dynamic shapes auto trt_profile = trt_builder->createOptimizationProfile(); - bool dynamic_shape = false; - for (unsigned int i = 0, n = 
trt_network->getNbInputs(); i < n; i++) { + for (unsigned int i = 0, end = trt_network->getNbInputs(); i < end; ++i) { auto input = trt_network->getInput(i); nvinfer1::Dims dims = input->getDimensions(); nvinfer1::Dims dims_min = dims; nvinfer1::Dims dims_opt = dims; nvinfer1::Dims dims_max = dims; - if (dims.d[0] == -1) { - dims_min.d[0] = 1; - dims_opt.d[0] = max_batch_size_; - dims_max.d[0] = max_batch_size_; + + int nb_dims = dims.nbDims; + if (input->isShapeTensor()) { // Shape tensor + std::vector shapes_min(nb_dims), shapes_opt(nb_dims), shapes_max(nb_dims); + for (int j = 0, end = nb_dims; j < end; ++j) { + shapes_min[j] = 1; + shapes_opt[j] = 1; + shapes_max[j] = 1; + } + trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], nb_dims); + trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], nb_dims); + trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], nb_dims); + } else { // Execution tensor + for (int j = 0, end = nb_dims; j < end; ++j) { + // For dynamic shape subgraph, a dummy engine is created at compile phase. + // Real engine will be created at compute phase based on input data + if (dims.d[j] == -1) { // Dynamic shape + dims_min.d[j] = 1; + dims_opt.d[j] = 1; + dims_max.d[j] = 1; + } + } + // TRT6: Optimization profile need to be provided for all inputs if any of them has dynamic shape trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, dims_min); trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, dims_opt); trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, dims_max); - dynamic_shape = true; } } - if (dynamic_shape) { - trt_config->addOptimizationProfile(trt_profile); - } + + trt_config->addOptimizationProfile(trt_profile); auto trt_engine = unique_pointer(trt_builder->buildEngineWithConfig(*trt_network, *trt_config)); if (trt_engine == nullptr) { @@ -452,26 +574,31 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorgetNbInputs(); input_indexes.resize(num_inputs); - input_dim_sizes.resize(num_inputs); for (int i = 0; i < num_inputs; ++i) { - const std::string& name = trt_network->getInput(i)->getName(); + auto input = trt_network->getInput(i); + const std::string& name = input->getName(); size_t bindingIndex = trt_engine->getBindingIndex(name.c_str()); nvinfer1::Dims dimensions = trt_engine->getBindingDimensions(static_cast(bindingIndex)); auto iter = input_map.find(name); if (iter != input_map.end()) { input_indexes[bindingIndex] = iter->second; } - size_t dim_size = 1; - for (int j = 0, end = dimensions.nbDims; j < end; ++j) { - dim_size *= dimensions.d[j]; + if (input->isShapeTensor()) { // Shape tensor + for (int j = 0, end = dimensions.nbDims; j < end; ++j) { + input_shape_ranges[bindingIndex][j] = std::make_pair(INT_MAX, INT_MIN); + } + } else { + for (int j = 0, end = dimensions.nbDims; j < end; ++j) { + if (dimensions.d[j] == -1) { + input_shape_ranges[bindingIndex][j] = std::make_pair(INT_MAX, INT_MIN); + } + } } - input_dim_sizes[bindingIndex] = dim_size; } // Get output shape and binding index int num_outputs = trt_network->getNbOutputs(); output_indexes.resize(num_outputs); - output_dim_sizes.resize(num_outputs); output_shapes.resize(num_outputs); output_types.resize(num_outputs); for (int i = 0; i < num_outputs; ++i) { @@ -483,12 +610,9 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsecond; } - 
size_t dim_size = 1; for (int j = 0, end = dimensions.nbDims; j < end; ++j) { output_shapes[bindingIndex].push_back(dimensions.d[j]); - dim_size *= dimensions.d[j]; } - output_dim_sizes[bindingIndex] = dim_size; const auto& graph_output = model_proto.graph().output(); const auto& tensor_type = graph_output[i].type().tensor_type(); @@ -501,11 +625,12 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorName(), std::move(trt_parser)); engines_.emplace(fused_node->Name(), std::move(trt_engine)); contexts_.emplace(fused_node->Name(), std::move(trt_context)); + builders_.emplace(fused_node->Name(), std::move(trt_builder)); + networks_.emplace(fused_node->Name(), std::move(trt_network)); input_info_[fused_node->Name()].push_back(input_indexes); - input_info_[fused_node->Name()].push_back(input_dim_sizes); output_info_[fused_node->Name()].push_back(output_indexes); - output_info_[fused_node->Name()].push_back(output_dim_sizes); output_info_[fused_node->Name()].push_back(output_types); + input_shape_ranges_[fused_node->Name()] = input_shape_ranges; output_shapes_[fused_node->Name()] = output_shapes; // Create function state @@ -513,8 +638,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector p = onnxruntime::make_unique(); - *p = {context->allocate_func, context->release_func, context->allocator_handle, parsers_[context->node_name].get(), engines_[context->node_name].get(), contexts_[context->node_name].get(), - input_info_[context->node_name], output_info_[context->node_name], output_shapes_[context->node_name], &tensorrt_mu_}; + *p = {context->allocate_func, context->release_func, context->allocator_handle, parsers_[context->node_name].get(), + engines_[context->node_name].get(), contexts_[context->node_name].get(), builders_[context->node_name].get(), + networks_[context->node_name].get(), input_info_[context->node_name], output_info_[context->node_name], + input_shape_ranges_[context->node_name], output_shapes_[context->node_name], &tensorrt_mu_}; *state = p.release(); return 0; }; @@ -529,39 +656,124 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(state); + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); const std::vector& input_indexes = (trt_state->input_info)[0]; const std::vector& output_indexes = (trt_state->output_info)[0]; - const std::vector& output_types = (trt_state->output_info)[2]; + const std::vector& output_types = (trt_state->output_info)[1]; int num_binding_inputs = input_indexes.size(); int num_binding_outputs = output_indexes.size(); int total_bindings = num_binding_inputs + num_binding_outputs; std::vector buffers(total_bindings); + //TODO: check shape tensor inputs by allInutShapesSpecified() bool dynamic_shape = false; - if (!trt_state->context->allInputDimensionsSpecified()) { + auto trt_context = trt_state->context; + if (!trt_context->allInputDimensionsSpecified()) { dynamic_shape = true; } - // Get batch size and allocate cuda memory for inputs + // Update shape ranges + bool dimension_update = false; + auto trt_builder = trt_state->builder; + auto trt_profile = trt_builder->createOptimizationProfile(); + for (int i = 0, end = num_binding_inputs; i < end; ++i) { + // TODO: check if getInput indexing is same with binding index + auto input = trt_state->network->getInput(i); + nvinfer1::Dims dims = input->getDimensions(); + nvinfer1::Dims dims_min = dims; + nvinfer1::Dims dims_opt = dims; + nvinfer1::Dims dims_max = dims; + + // Check and update shape ranges for dynamic shape inputs + auto& shape_ranges = 
trt_state->input_shape_ranges; + if (shape_ranges.find(i) != shape_ranges.end()) { + const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_indexes[i]); + auto tensor_info = ort.GetTensorTypeAndShape(input_tensor); + const auto& tensor_shape = ort.GetTensorShape(tensor_info); + auto& engine = trt_context->getEngine(); + nvinfer1::Dims dimensions = engine.getBindingDimensions(static_cast(i)); + int nb_dims = dimensions.nbDims; + for (int j = 0, end = nb_dims; j < end; ++j) { + auto& shape_range = shape_ranges[i]; + if (shape_range.find(j) != shape_range.end()) { + // Update minimum dimension + if (tensor_shape[j] < shape_range[j].first) { + shape_range[j].first = tensor_shape[j]; + dims_min.d[j] = tensor_shape[j]; + dimension_update = true; + } + // Update maximum dimension + if (tensor_shape[j] > shape_range[j].second) { + shape_range[j].second = tensor_shape[j]; + dims_max.d[j] = tensor_shape[j]; + dims_opt.d[j] = tensor_shape[j]; + dimension_update = true; + } + } + } + + if (dimension_update) { + if (engine.isShapeBinding(i)) { + std::vector shapes_min(nb_dims), shapes_opt(nb_dims), shapes_max(nb_dims); + for (int j = 0, end = nb_dims; j < end; ++j) { + shapes_min[j] = dims_min.d[j]; + shapes_opt[j] = dims_opt.d[j]; + shapes_max[j] = dims_max.d[j]; + } + trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], nb_dims); + trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], nb_dims); + trt_profile->setShapeValues(input->getName(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], nb_dims); + } else { + trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, dims_min); + trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, dims_opt); + trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, dims_max); + } + } + } + + // TensorRT6 requires optimization profile to be defined for all inputs if any input dimension is symbolic + if (dynamic_shape) { + trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, dims_min); + trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, dims_opt); + trt_profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, dims_max); + } + } + + // Regenerate engine and context + // Only one profile is generated, so no need to explicitly set optimization profile + if (dimension_update) { + auto trt_config = unique_pointer(trt_builder->createBuilderConfig()); + trt_config->addOptimizationProfile(trt_profile); + trt_state->engine = trt_builder->buildEngineWithConfig(*trt_state->network, *trt_config); + ORT_ENFORCE(trt_state->engine != nullptr); + + trt_state->context = trt_state->engine->createExecutionContext(); + ORT_ENFORCE(trt_state->context != nullptr); + trt_context = trt_state->context; + } + + // Set input shapes and assign input buffers for (int i = 0, end = num_binding_inputs; i < end; ++i) { const OrtValue* input_tensor = ort.KernelContext_GetInput(context, input_indexes[i]); auto tensor_info = ort.GetTensorTypeAndShape(input_tensor); const auto& tensor_shape = ort.GetTensorShape(tensor_info); - //Set dynamic shapes - nvinfer1::Dims dimensions = trt_state->context->getEngine().getBindingDimensions(static_cast(i)); - if (dynamic_shape) { - for (int j = 0, end = tensor_shape.size(); j < end; ++j) + // Set dynamic shapes + nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast(i)); + int nb_dims 
= dimensions.nbDims; + if (dimension_update) { + for (int j = 0, end = nb_dims; j < end; ++j) dimensions.d[j] = tensor_shape[j]; - trt_state->context->setBindingDimensions(i, dimensions); + trt_context->setBindingDimensions(i, dimensions); } auto tensor_type = ort.GetTensorElementType(tensor_info); ort.ReleaseTensorTypeAndShapeInfo(tensor_info); - if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { buffers[i] = const_cast(ort.GetTensorData(input_tensor)); + } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { + buffers[i] = const_cast(ort.GetTensorData(input_tensor)); } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) { buffers[i] = const_cast(ort.GetTensorData(input_tensor)); } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { @@ -569,7 +781,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector output_dim_size(num_binding_outputs, 1); + // Set output shapes and assign output buffers + std::vector output_dim_sizes(num_binding_outputs, 1); std::vector output_tensor(num_binding_outputs, nullptr); for (int i = 0, end = num_binding_outputs; i < end; ++i) { // Set dynamic shapes - nvinfer1::Dims dimensions = trt_state->context->getBindingDimensions(static_cast(i + num_binding_inputs)); - for (int j = 0, end = trt_state->output_shapes[i].size(); j < end; ++j) { + nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast(i + num_binding_inputs)); + int nb_dims = dimensions.nbDims; + for (int j = 0, end = nb_dims; j < end; ++j) { trt_state->output_shapes[i][j] = dimensions.d[j]; } @@ -595,17 +808,18 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(output_tensor[i]); + } else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { + buffers[i + num_binding_inputs] = ort.GetTensorMutableData(output_tensor[i]); } else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) { buffers[i + num_binding_inputs] = ort.GetTensorMutableData(output_tensor[i]); } else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { buffers[i + num_binding_inputs] = ort.GetTensorMutableData(output_tensor[i]); } else if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { // Allocate INT32 CUDA memory for INT64 output type because TensorRT doesn't fully support INT64 - for (int j = 0, end = dimensions.nbDims; j < end; ++j) { - output_dim_size[i] *= dimensions.d[j]; + for (int j = 0, end = nb_dims; j < end; ++j) { + output_dim_sizes[i] *= dimensions.d[j]; } - CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[i + num_binding_inputs], output_dim_size[i] * sizeof(int32_t))); - + CUDA_RETURN_IF_ERROR(cudaMalloc(&buffers[i + num_binding_inputs], output_dim_sizes[i] * sizeof(int32_t))); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP output onnx tensor data type: " + std::to_string(output_types[i]) + " not supported."); @@ -613,13 +827,14 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector lock(*(trt_state->tensorrt_mu_ptr)); - trt_state->context->enqueueV2(&buffers[0], nullptr, nullptr); + if (!trt_context->enqueueV2(&buffers[0], nullptr, nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP Execution Context Enqueue Failed."); + } // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64 for (int i = 0, end = num_binding_outputs; i < end; ++i) { if (output_types[i] == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { - cuda::Impl_Cast(reinterpret_cast(buffers[i + num_binding_inputs]), ort.GetTensorMutableData(output_tensor[i]), output_dim_size[i]); + 
cuda::Impl_Cast(reinterpret_cast(buffers[i + num_binding_inputs]), ort.GetTensorMutableData(output_tensor[i]), output_dim_sizes[i]); } } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h old mode 100755 new mode 100644 index 7eecc29892cf8..99b73a5ee01b6 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -43,8 +43,11 @@ struct TensorrtFuncState { nvonnxparser::IParser* parser = nullptr; nvinfer1::ICudaEngine* engine = nullptr; nvinfer1::IExecutionContext* context = nullptr; + nvinfer1::IBuilder* builder = nullptr; + nvinfer1::INetworkDefinition* network = nullptr; std::vector> input_info; std::vector> output_info; + std::unordered_map>> input_shape_ranges; std::vector> output_shapes; OrtMutex* tensorrt_mu_ptr = nullptr; }; @@ -69,18 +72,10 @@ class TensorrtExecutionProvider : public IExecutionProvider { AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const override; - void SetMaxBatchSize(const int batch_size) { - max_batch_size_ = batch_size; - } - - void SetMaxWorkspaceSize(const size_t workspace_size) { - max_workspace_size_ = workspace_size; - } - private: - int max_batch_size_ = 1; size_t max_workspace_size_ = 1 << 30; // 1GB - int max_parser_iterations_ = 6; + int max_partition_iterations_ = 1000; + int min_subgraph_size_ = 1; struct InferDeleter { template @@ -99,8 +94,11 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unordered_map> parsers_; std::unordered_map> engines_; std::unordered_map> contexts_; + std::unordered_map> builders_; + std::unordered_map> networks_; std::unordered_map>> input_info_; std::unordered_map>> output_info_; + std::unordered_map>>> input_shape_ranges_; std::unordered_map>> output_shapes_; /**Get IndexedSubGraph based on node list of the subgraph*/ diff --git a/onnxruntime/test/providers/cpu/tensor/tensor_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tensor_op_test.cc index 001083b1af864..57619960dd025 100644 --- a/onnxruntime/test/providers/cpu/tensor/tensor_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/tensor_op_test.cc @@ -30,7 +30,7 @@ TEST(TensorOpTest, ReshapeWithEmptyDim) { test.AddInput("data", {1, 1, 1}, std::vector(1, 1.0f)); test.AddInput("shape", {0}, {}, true); test.AddOutput("reshaped", {}, std::vector(1, 1.0f)); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT doesn't support empty dimension } TEST(TensorOpTest, ReshapeWithInitializer) { diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 1d51175cb0aa3..14e3fb283af8d 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -56,7 +56,7 @@ TEST(TensorrtExecutionProviderTest, FunctionTest) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()); - std::string model_file_name = "trt_execution_provider_test_graph.onnx"; + std::string model_file_name = "trt_execution_provider_function_test.onnx"; status = onnxruntime::Model::Save(model, model_file_name); std::vector dims_mul_x = {1, 3, 2}; @@ -165,7 +165,7 @@ TEST(TensorrtExecutionProviderTest, NodeIndexMappingTest) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()); - std::string model_file_name = "trt_execution_provider_NodeIndexMappingTest.onnx"; + std::string 
model_file_name = "trt_execution_provider_nodeindexmapping_test.onnx"; status = onnxruntime::Model::Save(model, model_file_name); std::vector dims_mul_x = {1, 3, 2}; diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py old mode 100755 new mode 100644 index 3fc03f13a73ac..d5be9f21297c7 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -15,6 +15,7 @@ import subprocess import sys import hashlib +import itertools from os.path import expanduser logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG) @@ -491,15 +492,15 @@ def setup_tensorrt_vars(args): "tensorrt_home='{}' valid={}." .format(tensorrt_home, tensorrt_home_valid)) - # Set maximum batch size for TensorRT. The number needs to be no less than maximum batch size in all unit tests - os.environ["ORT_TENSORRT_MAX_BATCH_SIZE"] = "13" - # Set maximum workspace size in byte for TensorRT (1GB = 1073741824 bytes) os.environ["ORT_TENSORRT_MAX_WORKSPACE_SIZE"] = "1073741824" # Set maximum number of iterations to detect unsupported nodes and partition the models for TensorRT - os.environ["ORT_TENSORRT_MAX_PARSER_ITERATIONS"] = "6" - + os.environ["ORT_TENSORRT_MAX_PARTITION_ITERATIONS"] = "1000" + + # Set minimum subgraph node size in graph partitioning for TensorRT + os.environ["ORT_TENSORRT_MIN_SUBGRAPH_SIZE"] = "1" + return tensorrt_home def setup_dml_build(args, cmake_path, build_dir, configs): @@ -605,9 +606,6 @@ def run_onnx_tests(build_dir, configs, onnx_test_data_dir, provider, enable_mult cmd += ['-d', '1'] if config != 'Debug' and os.path.exists(model_dir): - # some models in opset9 and above are not supported by TensorRT yet - if provider == 'tensorrt': - model_dir = os.path.join(model_dir, "opset8") cmd.append(model_dir) if os.path.exists(onnx_test_data_dir): cmd.append(onnx_test_data_dir) @@ -619,6 +617,33 @@ def run_onnx_tests(build_dir, configs, onnx_test_data_dir, provider, enable_mult if enable_parallel_executor_test: run_subprocess([exe,'-x'] + cmd, cwd=cwd) +# tensorrt function to run onnx test and model test. +def tensorrt_run_onnx_tests(build_dir, configs, onnx_test_data_dir): + for config in configs: + cwd = get_config_build_dir(build_dir, config) + if is_windows(): + exe = os.path.join(cwd, config, 'onnx_test_runner') + model_dir = os.path.join(cwd, "models") + else: + exe = os.path.join(cwd, 'onnx_test_runner') + model_dir = os.path.join(build_dir, "models") + cmd_base = ['-e', 'tensorrt', '-j', '1'] + + #onnx test + if os.path.exists(onnx_test_data_dir): + onnx_test_cmd = cmd_base + [onnx_test_data_dir] + run_subprocess([exe] + onnx_test_cmd, cwd=cwd) + + # model test + # TensorRT can run most of the model tests, but only part of them is enabled here to save CI build time. + if config != 'Debug' and os.path.exists(model_dir): + model_dir_opset8 = os.path.join(model_dir, "opset8") + model_dir_opset8 = glob.glob(os.path.join(model_dir_opset8, "test_*")) + model_dir_opset10 = os.path.join(model_dir, "opset10") + model_dir_opset10 = glob.glob(os.path.join(model_dir_opset10, "tf_inception_v1")) + for dir_path in itertools.chain(model_dir_opset8, model_dir_opset10): + model_test_cmd = cmd_base + [dir_path] + run_subprocess([exe] + model_test_cmd, cwd=cwd) # dnnl temporary function for running onnx tests and model tests separately. 
def dnnl_run_onnx_tests(build_dir, configs, onnx_test_data_dir): @@ -956,7 +981,9 @@ def main(): # Disable some onnx unit tests that TensorRT doesn't supported yet if not is_windows(): onnx_test_data_dir = os.path.join(source_dir, "cmake", "external", "onnx", "onnx", "backend", "test", "data", "simple") - run_onnx_tests(build_dir, configs, onnx_test_data_dir, 'tensorrt', args.enable_multi_device_test, False, 1) + tensorrt_run_onnx_tests(build_dir, configs, onnx_test_data_dir) + else: + tensorrt_run_onnx_tests(build_dir, configs, "") if args.use_cuda: run_onnx_tests(build_dir, configs, onnx_test_data_dir, 'cuda', args.enable_multi_device_test, False, 2) diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml index c2b9ec1f2b13b..636215952890f 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml @@ -4,13 +4,13 @@ jobs: AgentPool : 'Win-GPU-CUDA10' DoDebugBuild: 'true' DoCompliance: 'false' - BuildCommand: '$(Build.SourcesDirectory)\tools\ci_build\build.py --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --cmake_path $(Build.BinariesDirectory)\cmake\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake\bin\ctest.exe --enable_pybind --use_dnnl --use_tensorrt --tensorrt_home="C:\local\TensorRT-6.0.1.5" --build_shared_lib --build_csharp --enable_onnx_tests --use_cuda --cuda_version=10.0 --cuda_home="C:\local\cuda_10.0.130_win10_trt6015dll" --cudnn_home="C:\local\cudnn-10.0-windows10-x64-v7.3.1.20\cuda" --gen_doc' + BuildCommand: '$(Build.SourcesDirectory)\tools\ci_build\build.py --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --cmake_path $(Build.BinariesDirectory)\cmake\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake\bin\ctest.exe --enable_pybind --use_dnnl --use_tensorrt --tensorrt_home="C:\local\TensorRT-6.0.1.5" --build_shared_lib --build_csharp --enable_onnx_tests --use_cuda --cuda_version=10.0 --cuda_home="C:\local\cuda_10.0.130_win10" --cudnn_home="C:\local\cudnn-10.0-windows10-x64-v7.3.1.20\cuda" --gen_doc' JobName: 'Windows_CI_GPU_Dev' DoNugetPack: 'false' NuPackScript : '' DoTestCoverage: 'false' BuildArch: 'amd64' SetVcvars: 'true' - MsbuildArguments: '/m /p:CudaToolkitDir=C:\local\cuda_10.0.130_win10_trt6015dll\' + MsbuildArguments: '/m /p:CudaToolkitDir=C:\local\cuda_10.0.130_win10\' EnvSetupScript: 'setup_env_cuda.bat' CudaVersion: '10.0'
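The three ORT_TENSORRT_* variables documented above are read once in the provider constructor with the plain getenv/atoi pattern added in tensorrt_execution_provider.cc, falling back to the defaults of 1 GB workspace, 1000 partition iterations, and minimum subgraph size 1. A standalone sketch of that pattern follows; the `ReadEnvOr` helper name is made up for illustration, and the provider itself uses `atoi` rather than `atoll`.
```cpp
#include <cstdlib>
#include <iostream>

// Illustrative helper (not from the codebase): return a default when the
// environment variable is unset, mirroring the getenv/atoi pattern in the provider constructor.
static long long ReadEnvOr(const char* name, long long default_value) {
  const char* value = std::getenv(name);
  return value ? std::atoll(value) : default_value;
}

int main() {
  const long long max_workspace_size = ReadEnvOr("ORT_TENSORRT_MAX_WORKSPACE_SIZE", 1LL << 30);  // 1 GB
  const long long max_partition_iterations = ReadEnvOr("ORT_TENSORRT_MAX_PARTITION_ITERATIONS", 1000);
  const long long min_subgraph_size = ReadEnvOr("ORT_TENSORRT_MIN_SUBGRAPH_SIZE", 1);

  std::cout << "workspace=" << max_workspace_size
            << " iterations=" << max_partition_iterations
            << " min_subgraph=" << min_subgraph_size << std::endl;
  return 0;
}
```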
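The core of the Compile() changes is the TensorRT 6 optimization-profile handling for dynamic shapes: every dynamic (-1) dimension gets a placeholder min/opt/max of 1 when the dummy engine is built at compile time, and the profile has to cover all network inputs once any of them is dynamic. Below is a standalone sketch of that flow using the same TensorRT 6 calls that appear in the diff (createBuilderConfig, createOptimizationProfile, setDimensions, addOptimizationProfile, buildEngineWithConfig); the function name and the raw-pointer ownership here are illustrative, not the provider's code. Shape-tensor inputs additionally need setShapeValues(), which the provider handles as a separate branch.
```cpp
#include "NvInfer.h"

// Sketch: build a TensorRT 6 engine with one optimization profile that covers
// every input, substituting a placeholder range of 1 for dynamic (-1) dimensions.
nvinfer1::ICudaEngine* BuildEngineWithProfile(nvinfer1::IBuilder* builder,
                                              nvinfer1::INetworkDefinition* network,
                                              size_t max_workspace_size) {
  auto* config = builder->createBuilderConfig();
  config->setMaxWorkspaceSize(max_workspace_size);

  auto* profile = builder->createOptimizationProfile();  // owned by the builder
  for (int i = 0, n = network->getNbInputs(); i < n; ++i) {
    auto* input = network->getInput(i);
    nvinfer1::Dims dims_min = input->getDimensions();
    nvinfer1::Dims dims_opt = dims_min;
    nvinfer1::Dims dims_max = dims_min;
    for (int j = 0; j < dims_min.nbDims; ++j) {
      if (dims_min.d[j] == -1) {   // dynamic dimension: placeholder range,
        dims_min.d[j] = 1;         // widened later when real input shapes are seen
        dims_opt.d[j] = 1;
        dims_max.d[j] = 1;
      }
    }
    profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, dims_min);
    profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
    profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, dims_max);
  }
  config->addOptimizationProfile(profile);

  nvinfer1::ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
  config->destroy();  // builder and network stay owned by the caller
  return engine;
}
```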
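At compute time the provider keeps, for each dynamic input dimension, the [min, max] range of shapes seen so far (seeded with {INT_MAX, INT_MIN} as in Compile()), widens the range when an incoming tensor falls outside it, and rebuilds the engine with an updated profile whenever any range changed. The snippet below is a simplified, self-contained version of that bookkeeping; the flat map type and the `UpdateShapeRanges` name are illustrative — the provider keys its ranges per fused node and binding index.
```cpp
#include <climits>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

// Per input binding: dimension index -> [min, max] observed so far.
using ShapeRanges = std::unordered_map<int, std::unordered_map<int, std::pair<int, int>>>;

// Returns true if any recorded range had to be widened, which is the signal
// used to rebuild the TensorRT engine with an updated optimization profile.
bool UpdateShapeRanges(ShapeRanges& ranges, int binding_index,
                       const std::vector<int64_t>& tensor_shape) {
  bool widened = false;
  auto input_it = ranges.find(binding_index);
  if (input_it == ranges.end()) return false;  // no dynamic dimensions for this input
  for (auto& dim_and_range : input_it->second) {
    const int dim = dim_and_range.first;
    auto& range = dim_and_range.second;        // pair<min, max>, seeded with {INT_MAX, INT_MIN}
    const int value = static_cast<int>(tensor_shape[dim]);
    if (value < range.first)  { range.first = value;  widened = true; }
    if (value > range.second) { range.second = value; widened = true; }
  }
  return widened;
}
```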