@@ -82,24 +82,27 @@ void create_mask_matrix(const framework::ExecutionContext& context,
8282 const bool & is_reverse) {
8383 const auto & seq_len_vec = GetDataFromTensor<int >(sequence_length);
8484 const int & table_width = mask_matrix->dims ()[0 ];
85+ VLOG (2 ) << " INPUT MASK TENSOR SHAPE:" << mask_matrix->dims ();
8586 Tensor temp;
8687 temp.Resize (
8788 framework::make_ddim ({mask_matrix->dims ()[1 ], mask_matrix->dims ()[0 ]}));
8889 T* data_temp = temp.mutable_data <T>(context.GetPlace ());
89- std::memset (data_temp, 1 , mask_matrix->numel () * sizeof (T ));
90+ std::fill (data_temp, data_temp + mask_matrix->numel (), static_cast <T>( 1.0 ));
9091 for (unsigned int i = 0 ; i < seq_len_vec.size (); i++) {
9192 // reset the mask matrix
9293 if (seq_len_vec[i] == table_width) {
9394 continue ;
9495 }
9596 if (is_reverse) {
96- std::memset (data_temp + i * table_width * sizeof (T), 0 ,
97- (table_width - seq_len_vec[i]) * sizeof (T));
97+ std::fill (data_temp + i * table_width,
98+ data_temp + i * table_width + seq_len_vec[i],
99+ static_cast <T>(0 ));
98100 } else {
99- std::memset (data_temp + ( i * table_width + seq_len_vec[i]) * sizeof (T), 0 ,
100- (table_width - seq_len_vec[i] ) * sizeof (T ));
101+ std::fill (data_temp + i * table_width + seq_len_vec[i],
102+ data_temp + (i + 1 ) * table_width, static_cast <T>( 0 ));
101103 }
102104 }
105+ Print2DTensor<T>(&temp, " Original mask Tensor" );
103106 // transpose the result for the mask
104107 mask_matrix->mutable_data <T>(context.GetPlace ());
105108 std::vector<int > trans_vec;
@@ -125,8 +128,8 @@ void dropout_cpu_function_inplace(const framework::ExecutionContext& context,
125128 auto mask_data = mask->mutable_data <uint8_t >(context.GetPlace ());
126129 // Special case when dropout_prob is 1.0
127130 if (dropout_prob == 1 .0f ) {
128- std::memset (x_data, 0 , size * sizeof (*x_data ));
129- std::memset (mask_data, 0 , size * sizeof (*mask_data )); // NOLINT
131+ std::fill (x_data, x_data + size, static_cast <T>( 0 ));
132+ std::fill (mask_data, mask_data + size, static_cast <T>( 0 ));
130133 return ;
131134 }
132135 auto engine = framework::GetCPURandomEngine (seed_number);
@@ -145,7 +148,7 @@ void dropout_cpu_function_inplace(const framework::ExecutionContext& context,
145148 }
146149 auto mask_data = mask->data <uint8_t >();
147150 if (dropout_prob == 1 .0f ) {
148- std::memset (x_data, 0 , size * sizeof (*x_data ));
151+ std::fill (x_data, x_data + size, static_cast <T>( 0 ));
149152 return ;
150153 }
151154 for (size_t i = 0 ; i < size; ++i) {
@@ -300,15 +303,27 @@ struct LSTMCell : Cell<T> {
300303 cell_act, cand_act);
301304 framework::TensorCopy (*output, device_ctx->GetPlace (), *device_ctx, last_h);
302305 Print3DTensor<T>(last_h, " last_h" );
303- // auto eigen_output =
304- // framework::EigenMatrix<T>::Reshape(*output, output->dims().size() -
305- // 1);
306- // auto eigen_mask = framework::EigenMatrix<T>::From(
307- // mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
308- // // eigen_output.device(device_ctx->eigen_device()) =
309- // eigen_output =
310- // eigen_output *
311- // eigen_mask.broadcast(Eigen::DSizes<int, 2>(1, output->dims()[1]));
306+
307+ auto eigen_init_h =
308+ framework::EigenMatrix<T>::Reshape (*init_h, init_h->dims ().size () - 1 );
309+ auto eigen_last_h =
310+ framework::EigenMatrix<T>::Reshape (*last_h, last_h->dims ().size () - 1 );
311+
312+ auto eigen_mask = framework::EigenMatrix<T>::From (
313+ mask_tensor, framework::make_ddim ({mask_tensor.dims ()[1 ], 1 }));
314+ // eigen_output.device(device_ctx->eigen_device()) =
315+ auto eigen_mask_broadcast =
316+ eigen_mask.broadcast (Eigen::DSizes<int , 2 >(1 , output->dims ()[1 ]));
317+ auto & place = *device_ctx->eigen_device ();
318+ eigen_last_h.device (place) = eigen_last_h * eigen_mask_broadcast +
319+ eigen_init_h * (1 - eigen_mask_broadcast);
320+
321+ auto eigen_init_c =
322+ framework::EigenMatrix<T>::Reshape (*init_c, init_c->dims ().size () - 1 );
323+ auto eigen_last_c =
324+ framework::EigenMatrix<T>::Reshape (*last_c, last_c->dims ().size () - 1 );
325+ eigen_last_c.device (place) = eigen_last_c * eigen_mask_broadcast +
326+ eigen_init_c * (1 - eigen_mask_broadcast);
312327 }
313328};
314329
@@ -367,7 +382,9 @@ struct Layer {
367382 framework::EigenMatrix<T>::Reshape (*output, output->dims ().size () - 1 );
368383 auto eigen_mask = framework::EigenMatrix<T>::From (
369384 mask_tensor, framework::make_ddim ({mask_tensor.dims ()[1 ], 1 }));
370- eigen_output =
385+ auto & place = *context.template device_context <platform::CPUDeviceContext>()
386+ .eigen_device ();
387+ eigen_output.device (place) =
371388 eigen_output *
372389 eigen_mask.broadcast (Eigen::DSizes<int , 2 >(1 , output->dims ()[1 ]));
373390 }
@@ -412,6 +429,7 @@ struct SingleLayer : Layer<T> {
412429 mask_matrix.Resize (framework::make_ddim ({time_step, input->dims ()[1 ]}));
413430 if (has_sequence_length) {
414431 create_mask_matrix<T>(context, sequence_length, &mask_matrix, false );
432+ Print2DTensor<T>(&mask_matrix, " Mask Matrix" );
415433 mask_tensor_list = Unbind (mask_matrix);
416434 }
417435
@@ -447,9 +465,9 @@ struct SingleLayer : Layer<T> {
447465 init_c_holder, last_h_holder, last_c_holder, &output_tensors[i],
448466 mask_tensor_list[i]);
449467 }
450- // if (has_sequence_length) {
451- // this->postprocess(context, &output_tensors[i], mask_tensor_list[i]);
452- // }
468+ if (has_sequence_length) {
469+ this ->postprocess (context, &output_tensors[i], mask_tensor_list[i]);
470+ }
453471 }
454472 if (time_step % 2 == 0 ) {
455473 framework::TensorCopy (*last_h_holder, context.GetPlace (), dev_ctx,
@@ -717,7 +735,13 @@ class CudnnLSTMCPUKernel : public framework::OpKernel<T> {
717735 auto * weight = ctx.Input <Tensor>(" W" );
718736 auto * init_h = ctx.Input <Tensor>(" InitH" );
719737 auto * init_c = ctx.Input <Tensor>(" InitC" );
720- auto * sequence_length = ctx.Input <Tensor>(" SequenceLength" );
738+
739+ bool has_seq_length = ctx.HasInput (" SequenceLength" );
740+ const Tensor* sequence_length = nullptr ;
741+ if (has_seq_length) {
742+ sequence_length = ctx.Input <Tensor>(" SequenceLength" );
743+ }
744+ // auto* sequence_length = ctx.Input<Tensor>("SequenceLength");
721745 auto * last_h = ctx.Output <Tensor>(" LastH" );
722746 auto * last_c = ctx.Output <Tensor>(" LastC" );
723747 auto * output = ctx.Output <Tensor>(" Out" );
0 commit comments