@@ -15,6 +15,7 @@ limitations under the License.
1515#if GOOGLE_CUDA
1616
1717#include " tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.h"
18+
1819#include " tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_gpu.h"
1920#include " tensorflow_recommenders_addons/dynamic_embedding/core/utils/utils.h"
2021
@@ -214,10 +215,12 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
214215 if (cur_size > 0 ) {
215216 CUDA_CHECK (cudaMallocManaged ((void **)&d_dump_counter, sizeof (size_t )));
216217 CUDA_CHECK (cudaMallocManaged ((void **)&d_keys, sizeof (K) * cur_size));
217- CUDA_CHECK (cudaMallocManaged ((void **)&d_values, sizeof (V) * runtime_dim_ * cur_size));
218+ CUDA_CHECK (cudaMallocManaged ((void **)&d_values,
219+ sizeof (V) * runtime_dim_ * cur_size));
218220 table_->dump (d_keys, (gpu::ValueArrayBase<V>*)d_values, 0 , capacity,
219- d_dump_counter, stream);
220- cudaMemcpyAsync (&h_dump_counter, d_dump_counter, sizeof (size_t ), cudaMemcpyDeviceToHost, stream);
221+ d_dump_counter, stream);
222+ cudaMemcpyAsync (&h_dump_counter, d_dump_counter, sizeof (size_t ),
223+ cudaMemcpyDeviceToHost, stream);
221224 CUDA_CHECK (cudaStreamSynchronize (stream));
222225 }
223226
@@ -226,8 +229,9 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
226229 CreateTable (new_max_size, &table_);
227230
228231 if (cur_size > 0 ) {
229- table_->upsert ((const K*)d_keys, (const gpu::ValueArrayBase<V>*)d_values,
230- h_dump_counter, stream);
232+ table_->upsert ((const K*)d_keys,
233+ (const gpu::ValueArrayBase<V>*)d_values, h_dump_counter,
234+ stream);
231235 cudaStreamSynchronize (stream);
232236 cudaFree (d_keys);
233237 cudaFree (d_values);
@@ -387,6 +391,54 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
387391 return Status::OK ();
388392 }
389393
394+ Status ExportValuesToFile (OpKernelContext* ctx, const string filepath,
395+ const size_t buffer_size) {
396+ cudaStream_t _stream;
397+ CUDA_CHECK (cudaStreamCreate (&_stream));
398+
399+ {
400+ tf_shared_lock l (mu_);
401+ table_->dump_to_file (ctx, filepath, runtime_dim_, _stream, buffer_size);
402+ CUDA_CHECK (cudaStreamSynchronize (_stream));
403+ }
404+ CUDA_CHECK (cudaStreamDestroy (_stream));
405+ return Status::OK ();
406+ }
407+
408+ Status ImportValuesFromFile (OpKernelContext* ctx, const string filepath,
409+ const size_t buffer_size) {
410+ cudaStream_t _stream;
411+ CUDA_CHECK (cudaStreamCreate (&_stream));
412+
413+ {
414+ tf_shared_lock l (mu_);
415+
416+ string keyfile = filepath + " .keys" ;
417+ FILE* tmpfd = fopen (keyfile.c_str (), " rb" );
418+ if (tmpfd == nullptr ) {
419+ return errors::NotFound (" Failed to read key file" , keyfile);
420+ }
421+ fseek (tmpfd, 0 , SEEK_END);
422+ long int filesize = ftell (tmpfd);
423+ if (filesize <= 0 ) {
424+ fclose (tmpfd);
425+ return errors::NotFound (" Empty key file." );
426+ }
427+ size_t size = static_cast <size_t >(filesize) / sizeof (K);
428+ fseek (tmpfd, 0 , SEEK_SET);
429+ fclose (tmpfd);
430+
431+ table_->clear (_stream);
432+ CUDA_CHECK (cudaStreamSynchronize (_stream));
433+ RehashIfNeeded (_stream, size);
434+ table_->load_from_file (ctx, filepath, size, runtime_dim_, _stream,
435+ buffer_size);
436+ CUDA_CHECK (cudaStreamSynchronize (_stream));
437+ }
438+ CUDA_CHECK (cudaStreamDestroy (_stream));
439+ return Status::OK ();
440+ }
441+
390442 DataType key_dtype () const override { return DataTypeToEnum<K>::v (); }
391443 DataType value_dtype () const override { return DataTypeToEnum<V>::v (); }
392444 TensorShape key_shape () const final { return TensorShape (); }
@@ -625,6 +677,36 @@ REGISTER_KERNEL_BUILDER(
625677 Name (PREFIX_OP_NAME(CuckooHashTableExport)).Device(DEVICE_GPU),
626678 HashTableExportGpuOp);
627679
680+ // Op that export all keys and values to file.
681+ template <class K , class V >
682+ class HashTableExportToFileGpuOp : public OpKernel {
683+ public:
684+ explicit HashTableExportToFileGpuOp (OpKernelConstruction* ctx)
685+ : OpKernel(ctx) {
686+ int64 signed_buffer_size = 0 ;
687+ ctx->GetAttr (" buffer_size" , &signed_buffer_size);
688+ buffer_size_ = static_cast <size_t >(signed_buffer_size);
689+ }
690+
691+ void Compute (OpKernelContext* ctx) override {
692+ lookup::LookupInterface* table;
693+ OP_REQUIRES_OK (ctx, GetLookupTable (" table_handle" , ctx, &table));
694+ core::ScopedUnref unref_me (table);
695+
696+ const Tensor& ftensor = ctx->input (1 );
697+ OP_REQUIRES (ctx, TensorShapeUtils::IsScalar (ftensor.shape ()),
698+ errors::InvalidArgument (" filepath must be scalar." ));
699+ string filepath = string (ftensor.scalar <tstring>()().data ());
700+ lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
701+ (lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
702+ OP_REQUIRES_OK (
703+ ctx, table_cuckoo->ExportValuesToFile (ctx, filepath, buffer_size_));
704+ }
705+
706+ private:
707+ size_t buffer_size_;
708+ };
709+
628710// Clear the table and insert data.
629711class HashTableImportGpuOp : public OpKernel {
630712 public:
@@ -651,33 +733,76 @@ REGISTER_KERNEL_BUILDER(
651733 Name (PREFIX_OP_NAME(CuckooHashTableImport)).Device(DEVICE_GPU),
652734 HashTableImportGpuOp);
653735
736+ // Op that import from file.
737+ template <class K , class V >
738+ class HashTableImportFromFileGpuOp : public OpKernel {
739+ public:
740+ explicit HashTableImportFromFileGpuOp (OpKernelConstruction* ctx)
741+ : OpKernel(ctx) {
742+ int64 signed_buffer_size = 0 ;
743+ ctx->GetAttr (" buffer_size" , &signed_buffer_size);
744+ buffer_size_ = static_cast <size_t >(signed_buffer_size);
745+ }
746+
747+ void Compute (OpKernelContext* ctx) override {
748+ lookup::LookupInterface* table;
749+ OP_REQUIRES_OK (ctx, GetLookupTable (" table_handle" , ctx, &table));
750+ core::ScopedUnref unref_me (table);
751+
752+ const Tensor& ftensor = ctx->input (1 );
753+ OP_REQUIRES (ctx, TensorShapeUtils::IsScalar (ftensor.shape ()),
754+ errors::InvalidArgument (" filepath must be scalar." ));
755+ string filepath = string (ftensor.scalar <tstring>()().data ());
756+ lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
757+ (lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
758+ OP_REQUIRES_OK (
759+ ctx, table_cuckoo->ImportValuesFromFile (ctx, filepath, buffer_size_));
760+ }
761+
762+ private:
763+ size_t buffer_size_;
764+ };
765+
654766// Register the CuckooHashTableOfTensors op.
655767
// Registers, for one (key_dtype, value_dtype) pair, the GPU kernels of the
// table-creation, Clear, Accum, ExportToFile, ImportFromFile and
// FindWithExists ops. The two file ops keep their "filepath" input in host
// memory. There must be no whitespace between the macro name and its
// parameter list, otherwise the macro becomes object-like.
#define REGISTER_KERNEL(key_dtype, value_dtype)                         \
  REGISTER_KERNEL_BUILDER(                                              \
      Name(PREFIX_OP_NAME(CuckooHashTableOfTensors))                    \
          .Device(DEVICE_GPU)                                           \
          .TypeConstraint<key_dtype>("key_dtype")                       \
          .TypeConstraint<value_dtype>("value_dtype"),                  \
      HashTableGpuOp<                                                   \
          lookup::CuckooHashTableOfTensorsGpu<key_dtype, value_dtype>,  \
          key_dtype, value_dtype>);                                     \
  REGISTER_KERNEL_BUILDER(                                              \
      Name(PREFIX_OP_NAME(CuckooHashTableClear))                        \
          .Device(DEVICE_GPU)                                           \
          .TypeConstraint<key_dtype>("key_dtype")                       \
          .TypeConstraint<value_dtype>("value_dtype"),                  \
      HashTableClearGpuOp<key_dtype, value_dtype>);                     \
  REGISTER_KERNEL_BUILDER(                                              \
      Name(PREFIX_OP_NAME(CuckooHashTableAccum))                        \
          .Device(DEVICE_GPU)                                           \
          .TypeConstraint<key_dtype>("key_dtype")                       \
          .TypeConstraint<value_dtype>("value_dtype"),                  \
      HashTableAccumGpuOp<key_dtype, value_dtype>);                     \
  REGISTER_KERNEL_BUILDER(                                              \
      Name(PREFIX_OP_NAME(CuckooHashTableExportToFile))                 \
          .Device(DEVICE_GPU)                                           \
          .HostMemory("filepath")                                       \
          .TypeConstraint<key_dtype>("key_dtype")                       \
          .TypeConstraint<value_dtype>("value_dtype"),                  \
      HashTableExportToFileGpuOp<key_dtype, value_dtype>);              \
  REGISTER_KERNEL_BUILDER(                                              \
      Name(PREFIX_OP_NAME(CuckooHashTableImportFromFile))               \
          .Device(DEVICE_GPU)                                           \
          .HostMemory("filepath")                                       \
          .TypeConstraint<key_dtype>("key_dtype")                       \
          .TypeConstraint<value_dtype>("value_dtype"),                  \
      HashTableImportFromFileGpuOp<key_dtype, value_dtype>);            \
  REGISTER_KERNEL_BUILDER(                                              \
      Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists))               \
          .Device(DEVICE_GPU)                                           \
          .TypeConstraint<key_dtype>("Tin")                             \
          .TypeConstraint<value_dtype>("Tout"),                         \
      HashTableFindWithExistsGpuOp<key_dtype, value_dtype>);
681806
682807REGISTER_KERNEL (int64, float );
683808REGISTER_KERNEL (int64, Eigen::half);
0 commit comments