@@ -440,17 +440,23 @@ def weight_loader(self,
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return

-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
+
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)

-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 2:
+                self.qweight = param.materialize_nested()
+            return

         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
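Worth noting on the hunk above: the GGUF branch no longer waits for an `UninitializedParameter`; each rank now takes its tensor-parallel slice of the loaded shard up front via `narrow`. A minimal sketch of that slicing, assuming a hypothetical 2-rank setup; `output_dim`, `shard_size`, and `start_idx` mirror the diff, while `shard_for_rank` and the toy tensor are illustrative only:

```python
import torch

def shard_for_rank(loaded_weight: torch.Tensor, output_dim: int,
                   tp_rank: int, tp_size: int) -> torch.Tensor:
    # Each rank keeps a contiguous 1/tp_size slice along output_dim,
    # the same narrow(output_dim, start_idx, shard_size) call as above.
    shard_size = loaded_weight.size(output_dim) // tp_size
    start_idx = tp_rank * shard_size
    return loaded_weight.narrow(output_dim, start_idx, shard_size)

full = torch.arange(8 * 4).reshape(8, 4)
print(shard_for_rank(full, output_dim=0, tp_rank=0, tp_size=2).shape)  # torch.Size([4, 4])
print(shard_for_rank(full, output_dim=1, tp_rank=1, tp_size=2))        # right two columns
```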
@@ -515,18 +521,6 @@ def weight_loader(self,
                 shard_offset = loaded_weight.shape[output_dim] * \
                     loaded_shard_id

-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             start_idx = tp_rank * shard_size
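The block deleted above (per-shard `shard_size` bookkeeping plus an `input_dim` narrow) is dead code after this change: the new `is_gguf_weight` branch earlier in `weight_loader` returns before control reaches it. A rough sketch of the container pattern that replaces it, with the caveat that the real `materialize_nested()` is defined on vLLM's GGUF parameter class; the plain `torch.cat` stand-in here, and the `ShardContainer` name, are assumptions for illustration:

```python
import torch

class ShardContainer:
    """Hypothetical stand-in for the bookkeeping fields used in the diff."""

    def __init__(self, num_shards: int, output_dim: int = 0):
        self.shard_id = []        # arrival order of shard ids
        self.shard_id_map = {}    # shard id -> index into data_container
        self.data_container = []  # the per-rank weight slices themselves
        self.num_shards = num_shards
        self.output_dim = output_dim

    def load(self, loaded_shard_id, loaded_weight):
        # Mirrors the diff: record where this shard landed, then
        # materialize only once every expected shard has arrived.
        self.shard_id.append(loaded_shard_id)
        self.shard_id_map[loaded_shard_id] = len(self.data_container)
        self.data_container.append(loaded_weight)
        if len(self.data_container) == self.num_shards:
            return self.materialize_nested()
        return None

    def materialize_nested(self):
        # Assumption: stitch the shards back together along output_dim.
        return torch.cat(self.data_container, dim=self.output_dim)

qkv = ShardContainer(num_shards=3)
for shard, rows in (("q", 4), ("k", 2), ("v", 2)):
    qweight = qkv.load(shard, torch.zeros(rows, 8))
print(qweight.shape)  # torch.Size([8, 8]), produced when "v" arrives
```

The `== 2` and `== 3` thresholds in the diff correspond to the number of fused shards: gate/up for the merged MLP projection, q/k/v for attention.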
@@ -783,17 +777,23 @@ def weight_loader(self,
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return

-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()

-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
+
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 3:
+                self.qweight = param.materialize_nested()
+            return

         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
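For context on what the deleted branch did: it eagerly materialized one flat buffer sized for the widest quantized row among the shard weight types. A worked example of that row-size arithmetic, assuming two GGML quantization types whose `(block_size, type_size)` pairs I believe are fixed by the format (Q4_0 packs 32 weights into 18 bytes, Q8_0 into 34); the string keys are a simplification of `gguf.constants.GGML_QUANT_SIZES`, which is keyed by enum values:

```python
# Subset of GGML_QUANT_SIZES, keyed by name here purely for illustration.
GGML_QUANT_SIZES = {"Q4_0": (32, 18), "Q8_0": (32, 34)}

ori_shape = (4096, 4096)  # (rows, cols) of the unquantized weight
row_size = []
for block_size, type_size in GGML_QUANT_SIZES.values():
    # bytes per quantized row = blocks per row * bytes per block
    row_size.append(ori_shape[1] // block_size * type_size)

q_shape = (ori_shape[0], max(row_size))
print(row_size)  # [2304, 4352]
print(q_shape)   # (4096, 4352), padded to fit the widest row type
```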
@@ -883,18 +883,6 @@ def weight_loader(self,
                 shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                     param, orig_qkv_offsets, loaded_shard_id)

-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":