1515import joblib
1616import pytest
1717from omegaconf import DictConfig , ListConfig
18+ from sklearn .linear_model import LogisticRegression
19+ from sklearn .pipeline import Pipeline
20+ from sklearn .preprocessing import StandardScaler
1821
1922from nemo .collections .asr .models import EncDecCTCModel , EncDecHybridRNNTCTCModel , EncDecRNNTModel
2023from nemo .collections .asr .models .confidence_ensemble import ConfidenceEnsembleModel
@@ -98,26 +101,34 @@ class TestConfidenceEnsembles:
98101
99102 @pytest .mark .unit
100103 @pytest .mark .parametrize (
101- "model_class0" , [EncDecCTCModel , EncDecRNNTModel , EncDecHybridRNNTCTCModel ],
104+ "model_class0" ,
105+ [EncDecCTCModel , EncDecRNNTModel , EncDecHybridRNNTCTCModel ],
102106 )
103107 @pytest .mark .parametrize (
104- "model_class1" , [EncDecCTCModel , EncDecRNNTModel , EncDecHybridRNNTCTCModel ],
108+ "model_class1" ,
109+ [EncDecCTCModel , EncDecRNNTModel , EncDecHybridRNNTCTCModel ],
105110 )
106111 def test_model_creation_2models (self , tmp_path , model_class0 , model_class1 ):
107112 """Basic test to check that ensemble of 2 models can be created."""
108113 model_config0 = get_model_config (model_class0 )
109114 model_config1 = get_model_config (model_class1 )
110115
111116 # dummy pickle file for the model selection block
112- joblib .dump ({}, tmp_path / 'dummy.pkl' )
117+ pipe = Pipeline ([("scaler" , StandardScaler ()), ("clf" , LogisticRegression ())])
118+ joblib .dump (pipe , tmp_path / 'dummy.pkl' )
113119
114120 # default confidence
115121 confidence_config = ConfidenceConfig (
116122 # we keep frame confidences and apply aggregation manually to get full-utterance confidence
117123 preserve_frame_confidence = True ,
118124 exclude_blank = True ,
119125 aggregation = "mean" ,
120- method_cfg = ConfidenceMethodConfig (name = "entropy" , entropy_type = "renyi" , alpha = 0.25 , entropy_norm = "lin" ,),
126+ method_cfg = ConfidenceMethodConfig (
127+ name = "entropy" ,
128+ entropy_type = "renyi" ,
129+ alpha = 0.25 ,
130+ entropy_norm = "lin" ,
131+ ),
121132 )
122133
123134 # just checking that no errors are raised when creating the model
@@ -140,15 +151,21 @@ def test_model_creation_5models(self, tmp_path):
140151 model_configs = [get_model_config (EncDecCTCModel ) for _ in range (5 )]
141152
142153 # dummy pickle file for the model selection block
143- joblib .dump ({}, tmp_path / 'dummy.pkl' )
154+ pipe = Pipeline ([("scaler" , StandardScaler ()), ("clf" , LogisticRegression ())])
155+ joblib .dump (pipe , tmp_path / 'dummy.pkl' )
144156
145157 # default confidence
146158 confidence_config = ConfidenceConfig (
147159 # we keep frame confidences and apply aggregation manually to get full-utterance confidence
148160 preserve_frame_confidence = True ,
149161 exclude_blank = True ,
150162 aggregation = "mean" ,
151- method_cfg = ConfidenceMethodConfig (name = "entropy" , entropy_type = "renyi" , alpha = 0.25 , entropy_norm = "lin" ,),
163+ method_cfg = ConfidenceMethodConfig (
164+ name = "entropy" ,
165+ entropy_type = "renyi" ,
166+ alpha = 0.25 ,
167+ entropy_norm = "lin" ,
168+ ),
152169 )
153170
154171 # just checking that no errors are raised when creating the model
0 commit comments