From 9127d1f47f367f5c9cc49c73ad73557089d02cb8 Mon Sep 17 00:00:00 2001 From: deroholic <105595360+deroholic@users.noreply.github.com> Date: Mon, 11 Mar 2024 09:58:59 +0100 Subject: [PATCH 1/2] Fix directory creating for saving with multiprocessing (#185) --- mamba_ssm/models/mixer_seq_simple.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py index 2f1d97fd8..cd224738d 100644 --- a/mamba_ssm/models/mixer_seq_simple.py +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -252,8 +252,7 @@ def save_pretrained(self, save_directory): Save the model and its configuration file to a directory. """ # Ensure save_directory exists - if not os.path.exists(save_directory): - os.makedirs(save_directory) + os.makedirs(save_directory, exist_ok=True) # Save the model's state_dict model_path = os.path.join(save_directory, 'pytorch_model.bin') From c2d5b88d2f57c48e8f588c3ab2994c0da695e6cc Mon Sep 17 00:00:00 2001 From: Dmovic <944388576@qq.com> Date: Mon, 18 Mar 2024 08:24:59 +0000 Subject: [PATCH 2/2] fix typos --- csrc/selective_scan/selective_scan_bwd_kernel.cuh | 2 ++ csrc/selective_scan/selective_scan_fwd_kernel.cuh | 2 ++ 2 files changed, 4 insertions(+) diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index b204ab3f6..0b159871c 100644 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -253,6 +253,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { while (left <= right) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { delta_a_exp = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else { @@ -356,6 +357,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { delta_a_exp.real_ = 0.f; delta_a_exp.imag_ = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else { diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 8ecf126da..42a95b9de 100644 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -223,6 +223,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { while (left <= right) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { thread_data[i].x = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else { @@ -248,6 +249,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) { thread_data[i].x = 0.f; thread_data[i].y = 0.f; + break; } else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) { left = ((left + right) >> 1) + 1; } else {