Skip to content

Commit c314f43

Browse files
committed
Merge branch 'fix-pad-inconsistency-feature-extractor' of https://github.com/nvidia/nemo into fix-pad-inconsistency-feature-extractor
2 parents 5cd3bd7 + c8e467a commit c314f43

File tree

3 files changed

+41
-15
lines changed

3 files changed

+41
-15
lines changed

nemo/collections/asr/models/confidence_ensemble.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
from dataclasses import dataclass
2020
from typing import Dict, List, Optional, Union
2121

22-
import joblib
22+
try:
23+
from joblib.numpy_pickle_utils import _read_fileobject as _validate_joblib_file
24+
except ImportError:
25+
from joblib.numpy_pickle_utils import _validate_fileobject_and_memmap as _validate_joblib_file
2326
import numpy as np
2427
import torch
2528
from lightning.pytorch import Trainer
@@ -205,13 +208,16 @@ def find_class(self, module, name):
205208
warnings.simplefilter("ignore")
206209
# First try to load with our custom unpickler
207210
try:
208-
with open(file_path, 'rb') as f:
209-
unpickler = RestrictedUnpickler(f)
210-
model = unpickler.load()
211-
except (pickle.UnpicklingError, AttributeError):
212-
# If that fails, try loading with joblib's default loader first
213-
# then validate the loaded object
214-
model = joblib.load(file_path)
211+
with open(file_path, 'rb') as rawf:
212+
with _validate_joblib_file(rawf, file_path, mmap_mode=None) as stream:
213+
if isinstance(stream, tuple):
214+
stream = stream[0]
215+
216+
if isinstance(stream, str):
217+
with open(stream, "rb") as f:
218+
model = RestrictedUnpickler(f).load()
219+
else:
220+
model = RestrictedUnpickler(stream).load()
215221

216222
# Validate the loaded object is a sklearn Pipeline
217223
if not isinstance(model, Pipeline):
@@ -222,6 +228,9 @@ def find_class(self, module, name):
222228
if not (isinstance(step_obj, (StandardScaler, LogisticRegression))):
223229
raise ValueError(f"Unauthorized pipeline step: {type(step_obj)}")
224230

231+
except (pickle.UnpicklingError, AttributeError) as e:
232+
raise SecurityError(f"Failed to safely load model: {e}")
233+
225234
return model
226235

227236
except Exception as e:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
bitsandbytes==0.45.3 ; (platform_machine == 'x86_64' and platform_system != 'Darwin')
1+
bitsandbytes==0.45.5 ; (platform_machine == 'x86_64' and platform_system != 'Darwin')
22
# liger-kernel ; (platform_machine == 'x86_64' and platform_system != 'Darwin')

tests/collections/asr/test_confidence_ensembles.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
import joblib
1616
import pytest
1717
from omegaconf import DictConfig, ListConfig
18+
from sklearn.linear_model import LogisticRegression
19+
from sklearn.pipeline import Pipeline
20+
from sklearn.preprocessing import StandardScaler
1821

1922
from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel, EncDecRNNTModel
2023
from nemo.collections.asr.models.confidence_ensemble import ConfidenceEnsembleModel
@@ -98,26 +101,34 @@ class TestConfidenceEnsembles:
98101

99102
@pytest.mark.unit
100103
@pytest.mark.parametrize(
101-
"model_class0", [EncDecCTCModel, EncDecRNNTModel, EncDecHybridRNNTCTCModel],
104+
"model_class0",
105+
[EncDecCTCModel, EncDecRNNTModel, EncDecHybridRNNTCTCModel],
102106
)
103107
@pytest.mark.parametrize(
104-
"model_class1", [EncDecCTCModel, EncDecRNNTModel, EncDecHybridRNNTCTCModel],
108+
"model_class1",
109+
[EncDecCTCModel, EncDecRNNTModel, EncDecHybridRNNTCTCModel],
105110
)
106111
def test_model_creation_2models(self, tmp_path, model_class0, model_class1):
107112
"""Basic test to check that ensemble of 2 models can be created."""
108113
model_config0 = get_model_config(model_class0)
109114
model_config1 = get_model_config(model_class1)
110115

111116
# dummy pickle file for the model selection block
112-
joblib.dump({}, tmp_path / 'dummy.pkl')
117+
pipe = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression())])
118+
joblib.dump(pipe, tmp_path / 'dummy.pkl')
113119

114120
# default confidence
115121
confidence_config = ConfidenceConfig(
116122
# we keep frame confidences and apply aggregation manually to get full-utterance confidence
117123
preserve_frame_confidence=True,
118124
exclude_blank=True,
119125
aggregation="mean",
120-
method_cfg=ConfidenceMethodConfig(name="entropy", entropy_type="renyi", alpha=0.25, entropy_norm="lin",),
126+
method_cfg=ConfidenceMethodConfig(
127+
name="entropy",
128+
entropy_type="renyi",
129+
alpha=0.25,
130+
entropy_norm="lin",
131+
),
121132
)
122133

123134
# just checking that no errors are raised when creating the model
@@ -140,15 +151,21 @@ def test_model_creation_5models(self, tmp_path):
140151
model_configs = [get_model_config(EncDecCTCModel) for _ in range(5)]
141152

142153
# dummy pickle file for the model selection block
143-
joblib.dump({}, tmp_path / 'dummy.pkl')
154+
pipe = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression())])
155+
joblib.dump(pipe, tmp_path / 'dummy.pkl')
144156

145157
# default confidence
146158
confidence_config = ConfidenceConfig(
147159
# we keep frame confidences and apply aggregation manually to get full-utterance confidence
148160
preserve_frame_confidence=True,
149161
exclude_blank=True,
150162
aggregation="mean",
151-
method_cfg=ConfidenceMethodConfig(name="entropy", entropy_type="renyi", alpha=0.25, entropy_norm="lin",),
163+
method_cfg=ConfidenceMethodConfig(
164+
name="entropy",
165+
entropy_type="renyi",
166+
alpha=0.25,
167+
entropy_norm="lin",
168+
),
152169
)
153170

154171
# just checking that no errors are raised when creating the model

0 commit comments

Comments
 (0)