@@ -551,32 +551,29 @@ def get_fam_loss(self, fam_target, s2anet_head_out, reg_loss_type='gwd'):
551551 fam_cls_score1 = fam_cls_score
552552
553553 feat_labels = paddle .to_tensor (feat_labels )
554- if (feat_labels >= 0 ).astype (paddle .int32 ).sum () > 0 :
555- feat_labels_one_hot = paddle .nn .functional .one_hot (
556- feat_labels , self .cls_out_channels + 1 )
557- feat_labels_one_hot = feat_labels_one_hot [:, 1 :]
558- feat_labels_one_hot .stop_gradient = True
559-
560- num_total_samples = paddle .to_tensor (
561- num_total_samples , dtype = 'float32' , stop_gradient = True )
562-
563- fam_cls = F .sigmoid_focal_loss (
564- fam_cls_score1 ,
565- feat_labels_one_hot ,
566- normalizer = num_total_samples ,
567- reduction = 'none' )
568-
569- feat_label_weights = feat_label_weights .reshape (
570- feat_label_weights .shape [0 ], 1 )
571- feat_label_weights = np .repeat (
572- feat_label_weights , self .cls_out_channels , axis = 1 )
573- feat_label_weights = paddle .to_tensor (
574- feat_label_weights , stop_gradient = True )
575-
576- fam_cls = fam_cls * feat_label_weights
577- fam_cls_total = paddle .sum (fam_cls )
578- else :
579- fam_cls_total = paddle .zeros ([0 ], dtype = fam_cls_score1 .dtype )
554+ feat_labels_one_hot = paddle .nn .functional .one_hot (
555+ feat_labels , self .cls_out_channels + 1 )
556+ feat_labels_one_hot = feat_labels_one_hot [:, 1 :]
557+ feat_labels_one_hot .stop_gradient = True
558+
559+ num_total_samples = paddle .to_tensor (
560+ num_total_samples , dtype = 'float32' , stop_gradient = True )
561+
562+ fam_cls = F .sigmoid_focal_loss (
563+ fam_cls_score1 ,
564+ feat_labels_one_hot ,
565+ normalizer = num_total_samples ,
566+ reduction = 'none' )
567+
568+ feat_label_weights = feat_label_weights .reshape (
569+ feat_label_weights .shape [0 ], 1 )
570+ feat_label_weights = np .repeat (
571+ feat_label_weights , self .cls_out_channels , axis = 1 )
572+ feat_label_weights = paddle .to_tensor (
573+ feat_label_weights , stop_gradient = True )
574+
575+ fam_cls = fam_cls * feat_label_weights
576+ fam_cls_total = paddle .sum (fam_cls )
580577 fam_cls_losses .append (fam_cls_total )
581578
582579 # step3: regression loss
@@ -673,31 +670,28 @@ def get_odm_loss(self, odm_target, s2anet_head_out, reg_loss_type='gwd'):
673670 odm_cls_score1 = odm_cls_score
674671
675672 feat_labels = paddle .to_tensor (feat_labels )
676- if (feat_labels >= 0 ).astype (paddle .int32 ).sum () > 0 :
677- feat_labels_one_hot = paddle .nn .functional .one_hot (
678- feat_labels , self .cls_out_channels + 1 )
679- feat_labels_one_hot = feat_labels_one_hot [:, 1 :]
680- feat_labels_one_hot .stop_gradient = True
681-
682- num_total_samples = paddle .to_tensor (
683- num_total_samples , dtype = 'float32' , stop_gradient = True )
684- odm_cls = F .sigmoid_focal_loss (
685- odm_cls_score1 ,
686- feat_labels_one_hot ,
687- normalizer = num_total_samples ,
688- reduction = 'none' )
689-
690- feat_label_weights = feat_label_weights .reshape (
691- feat_label_weights .shape [0 ], 1 )
692- feat_label_weights = np .repeat (
693- feat_label_weights , self .cls_out_channels , axis = 1 )
694- feat_label_weights = paddle .to_tensor (feat_label_weights )
695- feat_label_weights .stop_gradient = True
696-
697- odm_cls = odm_cls * feat_label_weights
698- odm_cls_total = paddle .sum (odm_cls )
699- else :
700- odm_cls_total = paddle .zeros ([0 ], dtype = odm_cls_score1 .dtype )
673+ feat_labels_one_hot = paddle .nn .functional .one_hot (
674+ feat_labels , self .cls_out_channels + 1 )
675+ feat_labels_one_hot = feat_labels_one_hot [:, 1 :]
676+ feat_labels_one_hot .stop_gradient = True
677+
678+ num_total_samples = paddle .to_tensor (
679+ num_total_samples , dtype = 'float32' , stop_gradient = True )
680+ odm_cls = F .sigmoid_focal_loss (
681+ odm_cls_score1 ,
682+ feat_labels_one_hot ,
683+ normalizer = num_total_samples ,
684+ reduction = 'none' )
685+
686+ feat_label_weights = feat_label_weights .reshape (
687+ feat_label_weights .shape [0 ], 1 )
688+ feat_label_weights = np .repeat (
689+ feat_label_weights , self .cls_out_channels , axis = 1 )
690+ feat_label_weights = paddle .to_tensor (feat_label_weights )
691+ feat_label_weights .stop_gradient = True
692+
693+ odm_cls = odm_cls * feat_label_weights
694+ odm_cls_total = paddle .sum (odm_cls )
701695 odm_cls_losses .append (odm_cls_total )
702696
703697 # # step3: regression loss
0 commit comments