diff --git a/tensorflow/lite/micro/kernels/expand_dims.cc b/tensorflow/lite/micro/kernels/expand_dims.cc index 6bae37b9049..d47b42cbe0c 100644 --- a/tensorflow/lite/micro/kernels/expand_dims.cc +++ b/tensorflow/lite/micro/kernels/expand_dims.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -128,13 +130,18 @@ TfLiteStatus ExpandDimsEval(TfLiteContext* context, TfLiteNode* node) { memCopyN(tflite::micro::GetTensorData(output), tflite::micro::GetTensorData(input), flat_size); } break; + case kTfLiteInt16: { + memCopyN(tflite::micro::GetTensorData(output), + tflite::micro::GetTensorData(input), flat_size); + } break; case kTfLiteInt8: { memCopyN(tflite::micro::GetTensorData(output), tflite::micro::GetTensorData(input), flat_size); } break; default: MicroPrintf( - "Expand_Dims only currently supports int8 and float32, got %d.", + "Expand_Dims only currently supports int8, int16 and float32, got " + "%d.", input->type); return kTfLiteError; } diff --git a/tensorflow/lite/micro/kernels/expand_dims_test.cc b/tensorflow/lite/micro/kernels/expand_dims_test.cc index d8e217e588b..39a83b57471 100644 --- a/tensorflow/lite/micro/kernels/expand_dims_test.cc +++ b/tensorflow/lite/micro/kernels/expand_dims_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" @@ -138,6 +140,20 @@ TF_LITE_MICRO_TEST(ExpandDimsPositiveAxisTest2) { golden_data, output_data); } +TF_LITE_MICRO_TEST(ExpandDimsPositiveAxisTest3) { + int16_t output_data[6]; + int input_dims[] = {3, 3, 1, 2}; + const int16_t input_data[] = {-1, 1, 2, -2, 0, 3}; + const int16_t golden_data[] = {-1, 1, 2, -2, 0, 3}; + int axis_dims[] = {1, 1}; + const int32_t axis_data[] = {3}; + int golden_dims[] = {1, 3, 1, 2}; + int output_dims[] = {4, 3, 1, 2, 1}; + tflite::testing::TestExpandDims(input_dims, input_data, axis_dims, + axis_data, golden_dims, output_dims, + golden_data, output_data); +} + TF_LITE_MICRO_TEST(ExpandDimsNegativeAxisTest4) { int8_t output_data[6]; int input_dims[] = {3, 3, 1, 2};