@@ -678,25 +678,17 @@ void launch_fattn(
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
-    const bool is_mla = DV == 512; // TODO better parameterization
-
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
-    GGML_ASSERT(V || is_mla);
-
     const ggml_tensor * mask = dst->src[3];
 
     ggml_tensor * KQV = dst;
 
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
-    GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
-    GGML_ASSERT( K->nb[0] == ggml_element_size(K));
-    GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
-
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
@@ -721,10 +713,10 @@ void launch_fattn(
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    const char * V_data = V ? (const char *) V->data : nullptr;
-    size_t nb21 = V ? V->nb[1] : nb11;
-    size_t nb22 = V ? V->nb[2] : nb12;
-    size_t nb23 = V ? V->nb[3] : nb13;
+    const char * V_data = (const char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
         K_f16.alloc(ggml_nelements(K));
@@ -740,8 +732,7 @@ void launch_fattn(
         nb13 = nb13*bs*sizeof(half)/ts;
     }
 
-    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(V));
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
         V_f16.alloc(ggml_nelements(V));
         to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
         to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);