@@ -151,6 +151,12 @@ struct rpc_msg_buffer_clear_req {
151151 uint8_t value;
152152};
153153
154+ struct rpc_msg_set_tensor_hash_req {
155+ rpc_tensor tensor;
156+ uint64_t offset;
157+ uint64_t hash;
158+ };
159+
154160struct rpc_msg_set_tensor_hash_rsp {
155161 uint8_t result;
156162};
@@ -548,15 +554,12 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
548554 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
549555 rpc_tensor rpc_tensor = serialize_tensor (tensor);
550556 if (size > HASH_THRESHOLD) {
551- // input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
552- size_t input_size = sizeof (rpc_tensor) + sizeof (uint64_t ) + sizeof (uint64_t );
553- std::vector<uint8_t > input (input_size, 0 );
554- uint64_t hash = fnv_hash ((const uint8_t *)data, size);
555- memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
556- memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
557- memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), &hash, sizeof (hash));
557+ rpc_msg_set_tensor_hash_req request;
558+ request.tensor = rpc_tensor;
559+ request.offset = offset;
560+ request.hash = fnv_hash ((const uint8_t *)data, size);
558561 rpc_msg_set_tensor_hash_rsp response;
559- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR_HASH, input. data (), input. size ( ), &response, sizeof (response));
562+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR_HASH, &request, sizeof (request ), &response, sizeof (response));
560563 GGML_ASSERT (status);
561564 if (response.result ) {
562565 // the server has the same data, no need to send it
@@ -864,7 +867,7 @@ class rpc_server {
864867 bool free_buffer (const rpc_msg_free_buffer_req & request);
865868 bool buffer_clear (const rpc_msg_buffer_clear_req & request);
866869 bool set_tensor (const std::vector<uint8_t > & input);
867- bool set_tensor_hash (const std::vector< uint8_t > & input , rpc_msg_set_tensor_hash_rsp & response);
870+ bool set_tensor_hash (const rpc_msg_set_tensor_hash_req & request , rpc_msg_set_tensor_hash_rsp & response);
868871 bool get_tensor (const rpc_msg_get_tensor_req & request, std::vector<uint8_t > & response);
869872 bool copy_tensor (const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
870873 bool graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response);
@@ -1101,18 +1104,10 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
11011104 return true ;
11021105}
11031106
1104- bool rpc_server::set_tensor_hash (const std::vector< uint8_t > & input , rpc_msg_set_tensor_hash_rsp & response)
1107+ bool rpc_server::set_tensor_hash (const rpc_msg_set_tensor_hash_req & request , rpc_msg_set_tensor_hash_rsp & response)
11051108{
1106- // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
1107- if (input.size () != sizeof (rpc_tensor) + 16 ) {
1108- return false ;
1109- }
1110- const rpc_tensor * in_tensor = (const rpc_tensor *)input.data ();
1111- uint64_t offset;
1112- memcpy (&offset, input.data () + sizeof (rpc_tensor), sizeof (offset));
1113- const uint64_t * hash = (const uint64_t *)(input.data () + sizeof (rpc_tensor) + sizeof (offset));
11141109 std::vector<uint8_t > cached_file;
1115- if (!get_cached_file (* hash, cached_file)) {
1110+ if (!get_cached_file (request. hash , cached_file)) {
11161111 response.result = 0 ;
11171112 return true ;
11181113 }
@@ -1125,25 +1120,28 @@ bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set
11251120 ggml_context_ptr ctx_ptr { ggml_init (params) };
11261121 GGML_ASSERT (ctx_ptr != nullptr );
11271122 ggml_context * ctx = ctx_ptr.get ();
1128- ggml_tensor * tensor = deserialize_tensor (ctx, in_tensor );
1123+ ggml_tensor * tensor = deserialize_tensor (ctx, &request. tensor );
11291124 if (tensor == nullptr ) {
11301125 GGML_LOG_ERROR (" [%s] error deserializing tensor\n " , __func__);
11311126 return false ;
11321127 }
1133- GGML_PRINT_DEBUG (" [%s] buffer: %p, data: %p, offset: %" PRIu64 " , size: %zu, hash: %" PRIx64 " \n " , __func__, (void *)tensor->buffer , tensor->data , offset, size, *hash);
1128+ GGML_PRINT_DEBUG (" [%s] buffer: %p, data: %p, offset: %" PRIu64 " , size: %zu, hash: %" PRIx64 " \n " ,
1129+ __func__, (void *)tensor->buffer , tensor->data , request.offset , size, request.hash );
11341130
11351131 // sanitize tensor->data
11361132 {
11371133 const size_t p0 = (size_t ) ggml_backend_buffer_get_base (tensor->buffer );
11381134 const size_t p1 = p0 + ggml_backend_buffer_get_size (tensor->buffer );
11391135
1140- if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
1136+ if (request.tensor .data + request.offset < p0
1137+ || request.tensor .data + request.offset >= p1
1138+ || size > (p1 - request.tensor .data - request.offset )) {
11411139 GGML_LOG_ERROR (" [%s] tensor data region (data=0x%" PRIx64 " , offset=%" PRIu64 " , size=%zu, hash=0x%" PRIx64 " ) out of buffer bounds [0x%zx, 0x%zx)\n " ,
1142- __func__, in_tensor-> data , offset, size, * hash, p0, p1);
1140+ __func__, request. tensor . data , request. offset , size, request. hash , p0, p1);
11431141 return false ;
11441142 }
11451143 }
1146- ggml_backend_tensor_set (tensor, cached_file.data (), offset, size);
1144+ ggml_backend_tensor_set (tensor, cached_file.data (), request. offset , size);
11471145 response.result = 1 ;
11481146 return true ;
11491147}
@@ -1503,12 +1501,12 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
15031501 break ;
15041502 }
15051503 case RPC_CMD_SET_TENSOR_HASH: {
1506- std::vector< uint8_t > input ;
1507- if (!recv_msg (sockfd, input )) {
1504+ rpc_msg_set_tensor_hash_req request ;
1505+ if (!recv_msg (sockfd, &request, sizeof (request) )) {
15081506 return ;
15091507 }
15101508 rpc_msg_set_tensor_hash_rsp response;
1511- if (!server.set_tensor_hash (input , response)) {
1509+ if (!server.set_tensor_hash (request , response)) {
15121510 return ;
15131511 }
15141512 if (!send_msg (sockfd, &response, sizeof (response))) {
0 commit comments