From 06d13926c86f8aaf6e869bfe54ca9a53f37bed33 Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 25 Jul 2024 03:45:52 +0000 Subject: [PATCH 1/9] add session_ids arg for multithread use of pipeline.stream_infer --- lmdeploy/serve/async_engine.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 0c00adcc89..2fdff0d49c 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -353,6 +353,7 @@ async def get_generator(self, stop: bool, session_id: int): def batch_infer( self, prompts: Union[List[str], str, List[Dict], List[List[Dict]]], + session_ids: Union[List[int], int] = None, gen_config: Optional[Union[GenerationConfig, List[GenerationConfig], EngineGenerationConfig, @@ -367,6 +368,8 @@ def batch_infer( prompts (List[str] | str | List[Dict] | List[Dict]): a batch of prompts. It accepts: string prompt, a list of string prompts, a chat history in OpenAI format or a list of chat history. + session_ids (List[int] | int): a batch of session_ids. If not + provided, it will be [0, number of prompts] gen_config (GenerationConfig | None): a instance of or a list of GenerationConfig. Default to None. do_preprocess (bool): whether pre-process the messages. Default to @@ -389,7 +392,15 @@ def batch_infer( assert len(prompts) == len(gen_config),\ 'input gen_confg length differs from the length of prompts' # noqa prompt_num = len(prompts) - outputs = [Response('', 0, 0, i) for i in range(prompt_num)] + if session_ids is None: + session_ids = range(prompt_num) + elif isinstance(session_ids, int): + session_ids = [session_ids] + assert len(prompts) == len(session_ids), \ + 'input session_ids length differs from the length of prompts' + outputs = [ + Response('', 0, 0, session_ids[i]) for i in range(prompt_num) + ] generators = [] if use_tqdm: import tqdm @@ -397,7 +408,7 @@ def batch_infer( for i, prompt in enumerate(prompts): generators.append( self.generate(prompt, - i, + session_ids[i], gen_config=gen_config[i], stream_response=True, sequence_start=True, @@ -432,6 +443,7 @@ async def gather(): def stream_infer( self, prompts: Union[List[str], str, List[Dict], List[List[Dict]]], + session_ids: Union[List[int], int] = None, gen_config: Optional[Union[GenerationConfig, List[GenerationConfig], EngineGenerationConfig, @@ -445,6 +457,8 @@ def stream_infer( prompts (List[str] | str | List[Dict] | List[Dict]): a batch of prompts. It accepts: string prompt, a list of string prompts, a chat history in OpenAI format or a list of chat history. + session_ids (List[int] | int): a batch of session_ids. If not + provided, it will be [0, number of prompts] gen_config (GenerationConfig | None): a instance of or a list of GenerationConfig. Default to None. do_preprocess (bool): whether pre-process the messages. 
Default to @@ -465,12 +479,18 @@ def stream_infer( gen_config = [gen_config] * len(prompts) assert len(prompts) == len(gen_config),\ 'input gen_confg length differs from the length of prompts' # noqa + if session_ids is None: + session_ids = range(len(prompts)) + elif isinstance(session_ids, int): + session_ids = [session_ids] + assert len(prompts) == len(session_ids), \ + 'input session_ids length differs from the length of prompts' # noqa outputs = Queue() generators = [] for i, prompt in enumerate(prompts): generators.append( self.generate(prompt, - i, + session_ids[i], gen_config=gen_config[i], stream_response=True, sequence_start=True, @@ -487,8 +507,10 @@ async def _inner_call(i, generator): out.token_ids, out.logprobs)) async def gather(): - await asyncio.gather( - *[_inner_call(i, generators[i]) for i in range(len(prompts))]) + await asyncio.gather(*[ + _inner_call(session_ids[i], generators[i]) + for i in range(len(prompts)) + ]) outputs.put(None) loop = _get_event_loop() From 2b74d4623d7ddcc13af8fc2a231098c50cfc42cd Mon Sep 17 00:00:00 2001 From: irexyc Date: Tue, 30 Jul 2024 21:54:58 +0800 Subject: [PATCH 2/9] Revert "disable peer access code (#2082)" This reverts commit 263e8cfbced7d8261a1f66223ade9427af795eba. --- src/turbomind/utils/allocator.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/turbomind/utils/allocator.h b/src/turbomind/utils/allocator.h index d995e2a9bc..2a5d01cd06 100644 --- a/src/turbomind/utils/allocator.h +++ b/src/turbomind/utils/allocator.h @@ -167,7 +167,6 @@ class Allocator: public IAllocator { check_cuda_error(cudaGetDeviceCount(&device_count)); cudaMemPool_t mempool; check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id)); -#if TM_ENABLE_CUSTOM_ALL_REDUCE cudaMemAccessDesc desc = {}; int peer_access_available = 0; for (int i = 0; i < device_count; i++) { @@ -185,7 +184,6 @@ class Allocator: public IAllocator { desc.flags = cudaMemAccessFlagsProtReadWrite; check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1)); } -#endif // set memory pool threshold to avoid shrinking the pool uint64_t setVal = UINT64_MAX; check_cuda_error(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &setVal)); From ba2fe36f796c4e53fe0b5ad423021715896cc028 Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 1 Aug 2024 11:57:59 +0000 Subject: [PATCH 3/9] Revert "Revert "disable peer access code (#2082)"" This reverts commit 2b74d4623d7ddcc13af8fc2a231098c50cfc42cd. 
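A minimal usage sketch of the session_ids argument introduced in PATCH 1/9, which lets several threads drive pipeline.stream_infer without reusing the same engine session. The model path, thread layout and id scheme below are illustrative only, and note that PATCH 4/9 and 8/9 later in this series rework the mechanism into an internally allocated session id plus a per-request "index" field on Response:

# Sketch only: assumes the API as of PATCH 1/9 (explicit session_ids argument).
# The model path is a placeholder.
from threading import Thread

from lmdeploy import GenerationConfig, pipeline

pipe = pipeline('internlm/internlm2-chat-7b')


def worker(thread_idx, prompts):
    # Give each thread a disjoint range of session ids so concurrent
    # stream_infer calls do not collide on the same engine session.
    session_ids = [thread_idx * 1000 + i for i in range(len(prompts))]
    gen_config = GenerationConfig(max_new_tokens=128)
    for out in pipe.stream_infer(prompts,
                                 session_ids=session_ids,
                                 gen_config=gen_config):
        print(thread_idx, out.session_id, out.text)


threads = [Thread(target=worker, args=(i, ['hi', 'how are you'])) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()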
--- src/turbomind/utils/allocator.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/turbomind/utils/allocator.h b/src/turbomind/utils/allocator.h index 2a5d01cd06..d995e2a9bc 100644 --- a/src/turbomind/utils/allocator.h +++ b/src/turbomind/utils/allocator.h @@ -167,6 +167,7 @@ class Allocator: public IAllocator { check_cuda_error(cudaGetDeviceCount(&device_count)); cudaMemPool_t mempool; check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id)); +#if TM_ENABLE_CUSTOM_ALL_REDUCE cudaMemAccessDesc desc = {}; int peer_access_available = 0; for (int i = 0; i < device_count; i++) { @@ -184,6 +185,7 @@ class Allocator: public IAllocator { desc.flags = cudaMemAccessFlagsProtReadWrite; check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1)); } +#endif // set memory pool threshold to avoid shrinking the pool uint64_t setVal = UINT64_MAX; check_cuda_error(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &setVal)); From 76c9fb94cd765acc869683ee5c50ca3b55013ae8 Mon Sep 17 00:00:00 2001 From: irexyc Date: Thu, 1 Aug 2024 12:25:11 +0000 Subject: [PATCH 4/9] update --- lmdeploy/messages.py | 6 +++-- lmdeploy/serve/async_engine.py | 41 +++++++++++++--------------------- 2 files changed, 19 insertions(+), 28 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 46e569a372..35753ed1b9 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -253,8 +253,7 @@ class Response: generate_token_len (int): the response token length. input_token_len (int): the input prompt token length. Note that it may contains chat template part. - session_id (int): the id for running the session. Basically, it refers - to the position index of the input request batch. + session_id (int): the id for running the session. finish_reason ('stop' | 'length' | None): the reason the model stopped generating tokens. This will be 'stop' if the model hit a natural stop point or a provided stop sequence, 'length' if the maximum @@ -262,6 +261,8 @@ class Response: token_ids: (List[int]): the output token ids. logprobs: (List[Dict[int, float]]): the top logprobs for each output position. + index_id (int): it refers to the position index of the input request + batch """ text: str generate_token_len: int @@ -270,6 +271,7 @@ class Response: finish_reason: Optional[Literal['stop', 'length']] = None token_ids: List[int] = field(default_factory=list) logprobs: List[Dict[int, float]] = None + index_id: int = 0 @dataclass diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 2fdff0d49c..cf4759cf09 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -215,6 +215,7 @@ def __init__(self, self.gens_set = set() for i in range(self.instance_num): self.gens_set.add(self.engine.create_instance()) + self._session_ids = count(0) def _build_turbomind( self, @@ -353,7 +354,6 @@ async def get_generator(self, stop: bool, session_id: int): def batch_infer( self, prompts: Union[List[str], str, List[Dict], List[List[Dict]]], - session_ids: Union[List[int], int] = None, gen_config: Optional[Union[GenerationConfig, List[GenerationConfig], EngineGenerationConfig, @@ -368,8 +368,6 @@ def batch_infer( prompts (List[str] | str | List[Dict] | List[Dict]): a batch of prompts. It accepts: string prompt, a list of string prompts, a chat history in OpenAI format or a list of chat history. - session_ids (List[int] | int): a batch of session_ids. 
If not - provided, it will be [0, number of prompts] gen_config (GenerationConfig | None): a instance of or a list of GenerationConfig. Default to None. do_preprocess (bool): whether pre-process the messages. Default to @@ -392,14 +390,10 @@ def batch_infer( assert len(prompts) == len(gen_config),\ 'input gen_confg length differs from the length of prompts' # noqa prompt_num = len(prompts) - if session_ids is None: - session_ids = range(prompt_num) - elif isinstance(session_ids, int): - session_ids = [session_ids] - assert len(prompts) == len(session_ids), \ - 'input session_ids length differs from the length of prompts' + session_ids = [next(self._session_ids) for _ in range(prompt_num)] outputs = [ - Response('', 0, 0, session_ids[i]) for i in range(prompt_num) + Response('', 0, 0, session_ids[i], index_id=i) + for i in range(prompt_num) ] generators = [] if use_tqdm: @@ -443,7 +437,6 @@ async def gather(): def stream_infer( self, prompts: Union[List[str], str, List[Dict], List[List[Dict]]], - session_ids: Union[List[int], int] = None, gen_config: Optional[Union[GenerationConfig, List[GenerationConfig], EngineGenerationConfig, @@ -457,8 +450,6 @@ def stream_infer( prompts (List[str] | str | List[Dict] | List[Dict]): a batch of prompts. It accepts: string prompt, a list of string prompts, a chat history in OpenAI format or a list of chat history. - session_ids (List[int] | int): a batch of session_ids. If not - provided, it will be [0, number of prompts] gen_config (GenerationConfig | None): a instance of or a list of GenerationConfig. Default to None. do_preprocess (bool): whether pre-process the messages. Default to @@ -479,12 +470,7 @@ def stream_infer( gen_config = [gen_config] * len(prompts) assert len(prompts) == len(gen_config),\ 'input gen_confg length differs from the length of prompts' # noqa - if session_ids is None: - session_ids = range(len(prompts)) - elif isinstance(session_ids, int): - session_ids = [session_ids] - assert len(prompts) == len(session_ids), \ - 'input session_ids length differs from the length of prompts' # noqa + session_ids = [next(self._session_ids) for _ in range(len(prompts))] outputs = Queue() generators = [] for i, prompt in enumerate(prompts): @@ -502,15 +488,18 @@ def stream_infer( async def _inner_call(i, generator): async for out in generator: outputs.put( - Response(out.response, out.generate_token_len, - out.input_token_len, i, out.finish_reason, - out.token_ids, out.logprobs)) + Response(out.response, + out.generate_token_len, + out.input_token_len, + session_ids[i], + out.finish_reason, + out.token_ids, + out.logprobs, + index_id=i)) async def gather(): - await asyncio.gather(*[ - _inner_call(session_ids[i], generators[i]) - for i in range(len(prompts)) - ]) + await asyncio.gather( + *[_inner_call(i, generators[i]) for i in range(len(prompts))]) outputs.put(None) loop = _get_event_loop() From bba878fc48817c5beab3b088e0713f91aaa3232e Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Thu, 1 Aug 2024 13:39:52 +0000 Subject: [PATCH 5/9] add peer allocator --- src/turbomind/models/llama/LlamaBatch.cc | 9 +- src/turbomind/models/llama/LlamaBatch.h | 1 + src/turbomind/models/llama/LlamaV2.cc | 3 + src/turbomind/models/llama/LlamaV2.h | 2 + .../triton_backend/llama/LlamaTritonModel.cc | 20 ++-- .../llama/LlamaTritonModelInstance.cc | 4 - .../llama/LlamaTritonModelInstance.h | 1 + src/turbomind/utils/allocator.h | 104 ++++++++++-------- 8 files changed, 80 insertions(+), 64 deletions(-) diff --git a/src/turbomind/models/llama/LlamaBatch.cc 
b/src/turbomind/models/llama/LlamaBatch.cc index 58649d3a21..91d8bd4218 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -715,7 +715,7 @@ void LlamaBatch::AllocateBuffer(size_t batch_size, size_t session_len, int ca context_decoder_output_buf_ = (T*)allocator_->reMalloc(context_decoder_output_buf_, sz, false); } else { - context_decoder_output_buf_ = (T*)allocator_->reMalloc( + context_decoder_output_buf_ = (T*)peer_allocator_->reMalloc( context_decoder_output_buf_, sizeof(T) * max_forward_token_num_ * hidden_units, false); } @@ -850,7 +850,7 @@ void LlamaBatch::FreeBuffer() TM_LOG_DEBUG(__PRETTY_FUNCTION__); if (is_allocate_buffer_) { allocator_->free((void**)&context_decoder_input_buf_); - allocator_->free((void**)&context_decoder_output_buf_); + peer_allocator_->free((void**)&context_decoder_output_buf_); allocator_->free((void**)&context_decoder_ids_buf_); allocator_->free((void**)&lora_mask_buf_); @@ -871,7 +871,7 @@ void LlamaBatch::FreeBuffer() allocator_->free((void**)&local_logits_buf_); if (local_context_logits_buf_) { - allocator_->free((void**)&local_context_logits_buf_); + peer_allocator_->free((void**)&local_context_logits_buf_); } if (context_logits_buf_) { allocator_->free((void**)&context_logits_buf_); @@ -944,6 +944,7 @@ LlamaBatch::LlamaBatch(const EngineParams& params, int cache_block_seq_len, i { stream_ = model_->stream_; allocator_ = model_->allocator_; + peer_allocator_ = model_->peer_allcator_; cublas_wrapper_ = model_->cublas_wrapper_; const int elem_bits = quant_policy ? quant_policy : bitsof; @@ -1172,7 +1173,7 @@ void LlamaBatch::OutputContextLogits(T* cont NcclGuard guard(model_->tensor_para_, stream_, true); FT_CHECK(model_->vocab_size_padded_ % tp == 0); const auto local_vocab_size = model_->vocab_size_padded_ / tp; - local_context_logits_buf_ = (float*)allocator_->reMalloc( + local_context_logits_buf_ = (float*)peer_allocator_->reMalloc( local_context_logits_buf_, sizeof(float) * model_->vocab_size_padded_ * num_token, false); } } diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index 12897931f8..f0345af6d2 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -288,6 +288,7 @@ class LlamaBatch { cudaStream_t stream_{}; cublasMMWrapper* cublas_wrapper_{}; IAllocator* allocator_{}; + IAllocator* peer_allocator_{}; std::thread internal_thread_; diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index d08b0210b0..e42fef226a 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -33,6 +33,7 @@ #include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/models/llama/unified_decoder.h" #include "src/turbomind/utils/Tensor.h" +#include "src/turbomind/utils/allocator.h" #include "src/turbomind/utils/anomaly_handler.h" #include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/logger.h" @@ -64,6 +65,7 @@ LlamaV2::LlamaV2(size_t head_num, cudaStream_t stream, cublasMMWrapper* cublas_wrapper, IAllocator* allocator, + IAllocator* peer_alloctor, bool is_free_buffer_after_forward, cudaDeviceProp* cuda_device_prop): head_num_(head_num), @@ -84,6 +86,7 @@ LlamaV2::LlamaV2(size_t head_num, stream_(stream), cublas_wrapper_(cublas_wrapper), allocator_(allocator), + peer_allcator_(peer_alloctor), is_free_buffer_after_forward_(is_free_buffer_after_forward), cuda_device_prop_(cuda_device_prop), debug_(isDebug()), diff 
--git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index f8638614f9..61d83b90e0 100644 --- a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -75,6 +75,7 @@ class LlamaV2 { cudaStream_t stream, cublasMMWrapper* cublas_wrapper, IAllocator* allocator, + IAllocator* peer_allocator, bool is_free_buffer_after_forward, cudaDeviceProp* cuda_device_prop); @@ -179,6 +180,7 @@ class LlamaV2 { cudaStream_t stream_; cublasMMWrapper* cublas_wrapper_; IAllocator* allocator_; + IAllocator* peer_allcator_; bool is_free_buffer_after_forward_; cudaDeviceProp* cuda_device_prop_; diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index ba38fa785c..34f410d445 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -288,15 +288,16 @@ std::unique_ptr> LlamaTritonModel::createSh ft::check_cuda_error(cudaSetDevice(device_id)); const int comms_rank = device_id % (tensor_para_size_ * pipeline_para_size_); - std::unique_ptr> allocator( - new ft::Allocator(device_id)); - /// TODO: this stream handle is leaked cudaStream_t stream{}; ft::check_cuda_error(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + auto allocator = std::make_unique>(device_id, false); allocator->setStream(stream); + auto peer_allocator = std::make_unique>(device_id, true); + peer_allocator->setStream(stream); + cublasHandle_t cublas_handle; cublasLtHandle_t cublaslt_handle; @@ -353,11 +354,13 @@ std::unique_ptr> LlamaTritonModel::createSh stream, cublas_wrapper.get(), allocator.get(), + peer_allocator.get(), false, // is_free_buffer_after_forward, cuda_device_prop_ptr.get()); return std::make_unique>( LlamaTritonSharedModelInstance{std::move(allocator), + std::move(peer_allocator), std::move(cublas_algo_map), std::move(cublas_wrapper_mutex), std::move(cublas_wrapper), @@ -389,8 +392,7 @@ LlamaTritonModel::createModelInstance(int } } - std::unique_ptr> allocator( - new ft::Allocator(device_id)); + auto allocator = std::make_unique>(device_id, false); allocator->setStream(stream); @@ -441,10 +443,10 @@ template std::string LlamaTritonModel::toString() { std::stringstream ss; - ss << "Model: " - << "\nhead_num: " << head_num_ << "\nkv_head_num: " << kv_head_num_ << "\nsize_per_head: " << size_per_head_ - << "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ - << "\nattn_bias: " << attn_bias_ << "\nmax_batch_size: " << engine_params_.max_batch_size + ss << "Model: " << "\nhead_num: " << head_num_ << "\nkv_head_num: " << kv_head_num_ + << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_ + << "\nvocab_size: " << vocab_size_ << "\nattn_bias: " << attn_bias_ + << "\nmax_batch_size: " << engine_params_.max_batch_size << "\nmax_prefill_token_num: " << engine_params_.max_prefill_token_num << "\nmax_context_token_num: " << engine_params_.max_context_token_num << "\nsession_len: " << engine_params_.session_len << "\nstep_length: " << engine_params_.step_length diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc index e3ce79826d..d133d171ef 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc @@ -250,10 +250,6 @@ void 
LlamaTritonModelInstance::allocateBuffer(const size_t request_batch_size d_output_ids_ = (int*)std::realloc(d_output_ids_, sizeof(int) * request_batch_size * beam_width * session_len); d_sequence_lengths_ = (int*)std::realloc(d_sequence_lengths_, sizeof(int) * request_batch_size * beam_width); - // d_output_log_probs_ = (float*)(allocator_->reMalloc( - // d_output_log_probs_, sizeof(float) * request_batch_size * beam_width * session_len, false)); - // d_cum_log_probs_ = - // (float*)(allocator_->reMalloc(d_cum_log_probs_, sizeof(float) * request_batch_size * beam_width, false)); if (is_return_logits) { d_output_logits_ = (float*)allocator_->reMalloc( d_output_logits_, sizeof(float) * request_batch_size * max_input_len * instance_->llm->vocab_size(), false); diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h index 0d69d785ce..e33b616f73 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h +++ b/src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h @@ -30,6 +30,7 @@ namespace ft = turbomind; template struct LlamaTritonSharedModelInstance { std::unique_ptr> allocator; + std::unique_ptr> peer_allocator; std::unique_ptr cublas_algo_map; std::unique_ptr cublas_wrapper_mutex; std::unique_ptr cublas_wrapper; diff --git a/src/turbomind/utils/allocator.h b/src/turbomind/utils/allocator.h index d995e2a9bc..26cfc7e55d 100644 --- a/src/turbomind/utils/allocator.h +++ b/src/turbomind/utils/allocator.h @@ -50,15 +50,13 @@ namespace turbomind { -enum class AllocatorType -{ +enum class AllocatorType { CUDA, TF, TH }; -enum class ReallocType -{ +enum class ReallocType { INCREASE, REUSE, DECREASE, @@ -69,7 +67,7 @@ class IAllocator { virtual ~IAllocator(){}; virtual void* malloc(size_t size, const bool is_set_zero = true, bool is_host = false) = 0; - virtual void free(void** ptr, bool is_host = false) const = 0; + virtual void free(void** ptr, bool is_host = false) = 0; virtual void setStream(cudaStream_t stream) = 0; virtual cudaStream_t returnStream() = 0; virtual void memSet(void* ptr, const int val, const size_t size) = 0; @@ -125,27 +123,28 @@ class Allocator; template<> class Allocator: public IAllocator { private: - enum class MemoryType - { + enum class MemoryType { HOST, DEVICE }; - const int device_id_; - cudaStream_t stream_ = 0; // initialize as default stream - std::unordered_map>* pointer_mapping_; + const int device_id_; + bool enable_peer_access_{false}; + cudaStream_t stream_ = 0; // initialize as default stream + cudaMemPool_t mempool_{}; + std::unordered_map> pointer_mapping_; bool isExist(void* address) const { - return pointer_mapping_->count(address) > 0; + return pointer_mapping_.count(address) > 0; } ReallocType isReMalloc(void* address, size_t size) const { FT_CHECK(isExist(address)); - if (pointer_mapping_->at(address).first < size) { + if (pointer_mapping_.at(address).first < size) { return ReallocType::INCREASE; } - else if (pointer_mapping_->at(address).first == size) { + else if (pointer_mapping_.at(address).first == size) { return ReallocType::REUSE; } else { @@ -154,53 +153,64 @@ class Allocator: public IAllocator { } public: - Allocator(int device_id): device_id_(device_id) + Allocator(int device_id, bool enable_peer_access): device_id_(device_id), enable_peer_access_(enable_peer_access) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); - pointer_mapping_ = new std::unordered_map>(); + // pointer_mapping_ = new std::unordered_map>(); #if defined(CUDA_MEMORY_POOL_DISABLED) 
TM_LOG_WARNING( "Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free." "Note this may lead to hang with NCCL kernels launched in parallel; if so, try NCCL_LAUNCH_MODE=GROUP"); #else - int device_count = 1; - check_cuda_error(cudaGetDeviceCount(&device_count)); - cudaMemPool_t mempool; - check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool, device_id)); -#if TM_ENABLE_CUSTOM_ALL_REDUCE - cudaMemAccessDesc desc = {}; - int peer_access_available = 0; - for (int i = 0; i < device_count; i++) { - if (i == device_id) { - continue; - } - check_cuda_error(cudaDeviceCanAccessPeer(&peer_access_available, device_id, i)); - if (!peer_access_available) { - TM_LOG_WARNING("Device " + std::to_string(device_id) + " peer access Device " + std::to_string(i) - + " is not available."); - continue; + + if (enable_peer_access) { + cudaMemPoolProps props{}; + props.allocType = cudaMemAllocationTypePinned; + props.handleTypes = cudaMemHandleTypeNone; + props.location.type = cudaMemLocationTypeDevice; + props.location.id = device_id; + cudaMemPoolCreate(&mempool_, &props); + cudaMemAccessDesc desc = {}; + int peer_access_available = 0; + int device_count = 1; + check_cuda_error(cudaGetDeviceCount(&device_count)); + for (int i = 0; i < device_count; i++) { + if (i == device_id) { + continue; + } + check_cuda_error(cudaDeviceCanAccessPeer(&peer_access_available, device_id, i)); + if (!peer_access_available) { + TM_LOG_WARNING("Devicle " + std::to_string(device_id) + " peer access Device " + std::to_string(i) + + " is not available."); + continue; + } + desc.location.type = cudaMemLocationTypeDevice; + desc.location.id = i; + desc.flags = cudaMemAccessFlagsProtReadWrite; + check_cuda_error(cudaMemPoolSetAccess(mempool_, &desc, 1)); } - desc.location.type = cudaMemLocationTypeDevice; - desc.location.id = i; - desc.flags = cudaMemAccessFlagsProtReadWrite; - check_cuda_error(cudaMemPoolSetAccess(mempool, &desc, 1)); } -#endif + else { + check_cuda_error(cudaDeviceGetDefaultMemPool(&mempool_, device_id)); + } // set memory pool threshold to avoid shrinking the pool uint64_t setVal = UINT64_MAX; - check_cuda_error(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &setVal)); + check_cuda_error(cudaMemPoolSetAttribute(mempool_, cudaMemPoolAttrReleaseThreshold, &setVal)); #endif } virtual ~Allocator() { TM_LOG_DEBUG(__PRETTY_FUNCTION__); - while (!pointer_mapping_->empty()) { - auto ptr = pointer_mapping_->begin()->first; - auto size_and_type = pointer_mapping_->begin()->second; + while (!pointer_mapping_.empty()) { + auto ptr = pointer_mapping_.begin()->first; + auto size_and_type = pointer_mapping_.begin()->second; free(&ptr, size_and_type.second == MemoryType::HOST); } - delete pointer_mapping_; + if (enable_peer_access_) { // We own the pool in this case + cudaMemPoolDestroy(mempool_); + mempool_ = {}; + } } void setStream(cudaStream_t stream) @@ -230,7 +240,7 @@ class Allocator: public IAllocator { #if defined(CUDA_MEMORY_POOL_DISABLED) check_cuda_error(cudaMalloc(&ptr, (size_t)(ceil(size / 32.)) * 32)); #else - check_cuda_error(cudaMallocAsync(&ptr, (size_t)(ceil(size / 32.)) * 32, stream_)); + check_cuda_error(cudaMallocFromPoolAsync(&ptr, (size_t)(ceil(size / 32.)) * 32, mempool_, stream_)); #endif } if (is_set_zero) { @@ -239,19 +249,19 @@ class Allocator: public IAllocator { check_cuda_error(getSetDevice(o_device)); TM_LOG_DEBUG("malloc buffer %p with size %ld", ptr, size); - pointer_mapping_->insert({getAddress(ptr), {size, is_host ? 
MemoryType::HOST : MemoryType::DEVICE}}); + pointer_mapping_.insert({getAddress(ptr), {size, is_host ? MemoryType::HOST : MemoryType::DEVICE}}); return ptr; } - void free(void** ptr, bool _ = false) const + void free(void** ptr, bool _ = false) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); void* address = getAddress(*ptr); if (*ptr != nullptr) { int o_device = 0; - if (pointer_mapping_->count(address)) { - const auto is_host = pointer_mapping_->at(address).second == MemoryType::HOST; + if (pointer_mapping_.count(address)) { + const auto is_host = pointer_mapping_.at(address).second == MemoryType::HOST; TM_LOG_DEBUG("Free buffer %p", address); check_cuda_error(getSetDevice(device_id_, &o_device)); if (is_host) { @@ -265,7 +275,7 @@ class Allocator: public IAllocator { #endif } check_cuda_error(getSetDevice(o_device)); - pointer_mapping_->erase(address); + pointer_mapping_.erase(address); } else { TM_LOG_WARNING("pointer_mapping_ does not have information of ptr at %p.", address); From d65f198a97b79ee82efa157629103dfde5563527 Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Thu, 1 Aug 2024 14:17:07 +0000 Subject: [PATCH 6/9] fix lint --- .../triton_backend/llama/LlamaTritonModel.cc | 8 ++++---- src/turbomind/utils/allocator.h | 12 ++++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 34f410d445..87fd2cdf59 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -443,10 +443,10 @@ template std::string LlamaTritonModel::toString() { std::stringstream ss; - ss << "Model: " << "\nhead_num: " << head_num_ << "\nkv_head_num: " << kv_head_num_ - << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_ - << "\nvocab_size: " << vocab_size_ << "\nattn_bias: " << attn_bias_ - << "\nmax_batch_size: " << engine_params_.max_batch_size + ss << "Model: " + << "\nhead_num: " << head_num_ << "\nkv_head_num: " << kv_head_num_ << "\nsize_per_head: " << size_per_head_ + << "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ + << "\nattn_bias: " << attn_bias_ << "\nmax_batch_size: " << engine_params_.max_batch_size << "\nmax_prefill_token_num: " << engine_params_.max_prefill_token_num << "\nmax_context_token_num: " << engine_params_.max_context_token_num << "\nsession_len: " << engine_params_.session_len << "\nstep_length: " << engine_params_.step_length diff --git a/src/turbomind/utils/allocator.h b/src/turbomind/utils/allocator.h index 26cfc7e55d..9e9ea21b3a 100644 --- a/src/turbomind/utils/allocator.h +++ b/src/turbomind/utils/allocator.h @@ -50,13 +50,15 @@ namespace turbomind { -enum class AllocatorType { +enum class AllocatorType +{ CUDA, TF, TH }; -enum class ReallocType { +enum class ReallocType +{ INCREASE, REUSE, DECREASE, @@ -123,7 +125,8 @@ class Allocator; template<> class Allocator: public IAllocator { private: - enum class MemoryType { + enum class MemoryType + { HOST, DEVICE }; @@ -153,7 +156,8 @@ class Allocator: public IAllocator { } public: - Allocator(int device_id, bool enable_peer_access): device_id_(device_id), enable_peer_access_(enable_peer_access) + Allocator(int device_id, bool enable_peer_access = false): + device_id_(device_id), enable_peer_access_(enable_peer_access) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); // pointer_mapping_ = new std::unordered_map>(); From 
7a3cf702936b3a8e9dd7e9129a7e07bcc9bcaa30 Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Thu, 1 Aug 2024 14:36:13 +0000 Subject: [PATCH 7/9] check cuda error --- src/turbomind/utils/allocator.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/turbomind/utils/allocator.h b/src/turbomind/utils/allocator.h index 9e9ea21b3a..1313bb38f1 100644 --- a/src/turbomind/utils/allocator.h +++ b/src/turbomind/utils/allocator.h @@ -173,7 +173,7 @@ class Allocator: public IAllocator { props.handleTypes = cudaMemHandleTypeNone; props.location.type = cudaMemLocationTypeDevice; props.location.id = device_id; - cudaMemPoolCreate(&mempool_, &props); + check_cuda_error(cudaMemPoolCreate(&mempool_, &props)); cudaMemAccessDesc desc = {}; int peer_access_available = 0; int device_count = 1; @@ -212,7 +212,7 @@ class Allocator: public IAllocator { free(&ptr, size_and_type.second == MemoryType::HOST); } if (enable_peer_access_) { // We own the pool in this case - cudaMemPoolDestroy(mempool_); + check_cuda_error(cudaMemPoolDestroy(mempool_)); mempool_ = {}; } } From 0aa0d498880b8b4a18df14f35f462c6a5263a84d Mon Sep 17 00:00:00 2001 From: irexyc Date: Fri, 2 Aug 2024 02:43:41 +0000 Subject: [PATCH 8/9] fix comments --- lmdeploy/messages.py | 4 ++-- lmdeploy/serve/async_engine.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 35753ed1b9..f250e55763 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -261,7 +261,7 @@ class Response: token_ids: (List[int]): the output token ids. logprobs: (List[Dict[int, float]]): the top logprobs for each output position. - index_id (int): it refers to the position index of the input request + index (int): it refers to the position index of the input request batch """ text: str @@ -271,7 +271,7 @@ class Response: finish_reason: Optional[Literal['stop', 'length']] = None token_ids: List[int] = field(default_factory=list) logprobs: List[Dict[int, float]] = None - index_id: int = 0 + index: int = 0 @dataclass diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index cf4759cf09..067d1fc9cc 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -215,7 +215,7 @@ def __init__(self, self.gens_set = set() for i in range(self.instance_num): self.gens_set.add(self.engine.create_instance()) - self._session_ids = count(0) + self._session_id = count(0) def _build_turbomind( self, @@ -390,9 +390,9 @@ def batch_infer( assert len(prompts) == len(gen_config),\ 'input gen_confg length differs from the length of prompts' # noqa prompt_num = len(prompts) - session_ids = [next(self._session_ids) for _ in range(prompt_num)] + session_ids = [next(self._session_id) for _ in range(prompt_num)] outputs = [ - Response('', 0, 0, session_ids[i], index_id=i) + Response('', 0, 0, session_ids[i], index=i) for i in range(prompt_num) ] generators = [] @@ -470,7 +470,7 @@ def stream_infer( gen_config = [gen_config] * len(prompts) assert len(prompts) == len(gen_config),\ 'input gen_confg length differs from the length of prompts' # noqa - session_ids = [next(self._session_ids) for _ in range(len(prompts))] + session_ids = [next(self._session_id) for _ in range(len(prompts))] outputs = Queue() generators = [] for i, prompt in enumerate(prompts): @@ -495,7 +495,7 @@ async def _inner_call(i, generator): out.finish_reason, out.token_ids, out.logprobs, - index_id=i)) + index=i)) async def gather(): await asyncio.gather( From 
36ca39f5ab8d80d1a2dc73500b0dbf57ce1074d1 Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Fri, 2 Aug 2024 05:29:02 +0000 Subject: [PATCH 9/9] fix wrong allocator --- src/turbomind/models/llama/LlamaBatch.cc | 2 +- src/turbomind/utils/allocator.h | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index 91d8bd4218..efec8f609e 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -712,7 +712,7 @@ void LlamaBatch::AllocateBuffer(size_t batch_size, size_t session_len, int ca if (model_->lora_params_.policy == LoraPolicy::kPlora) { lora_mask_buf_ = (int*)allocator_->reMalloc(lora_mask_buf_, sizeof(int) * max_forward_token_num_, false); size_t sz = sizeof(T) * max_forward_token_num_ * (hidden_units + model_->lora_params_.max_wo_r); - context_decoder_output_buf_ = (T*)allocator_->reMalloc(context_decoder_output_buf_, sz, false); + context_decoder_output_buf_ = (T*)peer_allocator_->reMalloc(context_decoder_output_buf_, sz, false); } else { context_decoder_output_buf_ = (T*)peer_allocator_->reMalloc( diff --git a/src/turbomind/utils/allocator.h b/src/turbomind/utils/allocator.h index 1313bb38f1..bdcb9bfc46 100644 --- a/src/turbomind/utils/allocator.h +++ b/src/turbomind/utils/allocator.h @@ -160,7 +160,6 @@ class Allocator: public IAllocator { device_id_(device_id), enable_peer_access_(enable_peer_access) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); - // pointer_mapping_ = new std::unordered_map>(); #if defined(CUDA_MEMORY_POOL_DISABLED) TM_LOG_WARNING( "Async cudaMalloc/Free is not supported before CUDA 11.2. Using Sync cudaMalloc/Free."
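Taken together, the Python-side changes in this series (PATCH 1/9, 4/9 and 8/9) leave batch_infer and stream_infer allocating session ids from an internal counter (count(0) on the engine), so concurrent calls from different threads no longer collide, while each Response reports its position in the request batch through the new "index" field. A short sketch of the resulting usage, assuming the final state of the series; the model path and prompts are placeholders:

# Sketch only: assumes the API after PATCH 8/9. Session ids are drawn from an
# internal monotonic counter, so no session_ids argument is needed;
# Response.index maps each output back to its prompt within a single call,
# and Response.session_id stays unique across threads.
from concurrent.futures import ThreadPoolExecutor

from lmdeploy import GenerationConfig, pipeline

pipe = pipeline('internlm/internlm2-chat-7b')
gen_config = GenerationConfig(max_new_tokens=64)


def run(prompts):
    texts = [''] * len(prompts)
    for out in pipe.stream_infer(prompts, gen_config=gen_config):
        # out.index is the prompt position within this call;
        # out.session_id is the engine-wide session id.
        texts[out.index] += out.text
    return texts


batches = [['hello', 'what is a CUDA memory pool?'], ['tell me a joke']]
with ThreadPoolExecutor(max_workers=2) as pool:
    for result in pool.map(run, batches):
        print(result)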