```python
BASE_SERIALIZATION_CLASSES = {
    "builtins": [
        "Exception", "ValueError", "NotImplementedError", "AttributeError",
        "AssertionError"
    ], # each Exception Error class needs to be added explicitly
    "collections": ["OrderedDict", "defaultdict"],
    "datetime": ["timedelta"],
    "pathlib": ["PosixPath"],
    "functools": ["partial"],
    "transformers4rec.torch.model.base": ["Model", "Head", "PredictionTask"],
    "transformers4rec.torch.block.base": ["SequentialBlock"],
    "transformers4rec.torch.features.sequence": ["TabularSequenceFeatures", "SequenceEmbeddingFeatures"],
    "transformers4rec.torch.features.continuous": ["ContinuousFeatures"],
    "transformers4rec.torch.features.embedding": ["FeatureConfig", "EmbeddingConfig", "TableConfig", "PretrainedEmbeddingFeatures"],
    "transformers4rec.torch.features.sparse": ["SparseFeatures"],
    "transformers4rec.torch.features.tabular": ["TabularFeatures"],
    "transformers4rec.torch.tabular.base": ["FilterFeatures", "AsTabular"],
    "transformers4rec.torch.tabular.aggregation": ["ConcatFeatures"],
    "transformers4rec.torch.block.base": ["Block", "SequentialBlock"],
    "transformers4rec.torch.block.mlp": ["DenseBlock"],
    "transformers4rec.torch.block.transformer": ["TransformerBlock"],
    "transformers4rec.torch.masking": ["CausalLanguageModeling"],
    "transformers4rec.torch.model.base": ["forward_to_prediction_fn", "Model", "Head"],
    "transformers4rec.torch.model.prediction_task": ["NextItemPredictionTask", "_NextItemPredictionTask"],
    "transformers4rec.torch.ranking_metric": ["NDCGAt", "DCGAt", "AvgPrecisionAt", "PrecisionAt", "RecallAt"],
    "transformers4rec.config.transformer": ["XLNetConfig"],
    "torch.nn.modules.container": ["ModuleList","ModuleDict"],
    "torch.nn.modules.loss": ["CrossEntropyLoss"],
    "merlin_standard_lib.schema.schema": ["Schema", "ColumnSchema"],
    "merlin_standard_lib.proto.schema_bp": [
        "FeaturePresence", "FeaturePresenceWithinGroup", "FeatureType", "FixedShape", "ValueCount", "ValueCountList",
        "IntDomain", "FloatDomain", "StringDomain", "BoolDomain", "StructDomain", "NaturalLanguageDomain",
        "FeatureCoverageConstraints", "SequenceLengthConstraints", "ImageDomain", "MIDDomain", "URLDomain",
        "TimeDomain", "TimeOfDayDomain", "DistributionConstraints", "Annotation", "FeatureComparator", "InfinityNorm",
        "JensenShannonDivergence", "UniqueConstraints", "DatasetConstraints", "NumericValueComparator"],
    "torch.nn.init": ["kaiming_normal_", "kaiming_uniform_", "xavier_normal_", "xavier_uniform_", "uniform_", "normal_", "zeros_", "ones_"],
    "torch._utils": ["_rebuild_tensor_v2", "_rebuild_parameter"],
    "torch": ["Size", "device"],
    "torch.storage": ["_load_from_bytes"],
    "torch._C._nn": ["gelu"],
    "torch.nn.module": ["Module"],
    "torch.nn.modules.activation": ["ReLU", "Sigmoid", "Tanh"],
    "torch.nn.modules.linear": ["Linear", "Identity"],
    "torch.nn.modules.conv": ["Conv1d", "Conv2d", "Conv3d"],
    "torch.nn.modules.pooling": ["MaxPool1d", "MaxPool2d", "MaxPool3d"],
    "torch.nn.modules.normalization": ["BatchNorm1d", "BatchNorm2d", "BatchNorm3d", "LayerNorm"],
    "torch.nn.modules.dropout": ["Dropout", "Dropout2d", "Dropout3d"],
    "torch.nn.modules.rnn": ["RNN", "LSTM", "GRU"],
    "torch.nn.modules.sparse": ["Embedding"],
    "torch.optim.adam": ["Adam"],
    "torchmetrics.metric": ["jit_distributed_available"],
    "torchmetrics.utilities.data": ["dim_zero_cat"],
    "transformers.models.xlnet.modeling_xlnet": ["XLNetModel", "XLNetLayer", "XLNetRelativeAttention", "XLNetFeedForward"],
    "transformers.activations": ["GELUActivation"],
    "transformers.modeling_utils": ["SequenceSummary"],
    "builtins": ["getattr"],
}
```
Bug description
The `BASE_SERIALIZATION_CLASSES` allowlist introduced by PR #802 (commit `ab7207cf`, "Sec pic fix") in `transformers4rec/utils/serialization.py` has structural issues: two of them defeat the intended restriction on `Unpickler.find_class`, and a third breaks loading of legitimate checkpoints.

File (at `ab7207cf`): `Transformers4Rec/transformers4rec/utils/serialization.py`, lines 11 to 68.

The file was subsequently removed in PR #808 (`41b14d7b`), but checkpoints produced during the #802 → #807 window still exist in the wild, and the design questions apply to any follow-up allowlist.

(1) The `builtins` key appears twice. Python dict literals are last-write-wins, so only the second declaration, `"builtins": ["getattr"]`, survives. Net result at runtime: `builtins` maps to `["getattr"]`. The Exception subclasses (intentionally allowlisted) are silently dropped, and `builtins.getattr`, a well-known primitive in pickle gadget chains (attribute traversal → code execution), is approved.

(2) `torch.storage._load_from_bytes` is in the allowlist. `_load_from_bytes` wraps `torch.load`, which itself performs unrestricted pickle deserialization when the installed PyTorch version predates the `weights_only=True` default. That nested load goes through PyTorch's own unpickler rather than the restricted one, so the allowlist no longer applies to the bytes it receives. Combined with `getattr`, this provides a reachable path from the restricted unpickler to arbitrary code execution.

(3) `transformers4rec.torch.model.base` is also duplicated. The first declaration includes `PredictionTask`; the second, which wins, is `["forward_to_prediction_fn", "Model", "Head"]`. As a result, any checkpoint containing `PredictionTask` fails to deserialize with a `ValueError` from the restricted unpickler. This is an availability regression rather than a security one, but it suggests the dict literal was not reviewed carefully before merge.

Steps/Code to reproduce bug
Structural evidence (no exploit payload):
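A minimal sketch of that structural check, using the standard-library `ast` module to list every string key that is declared more than once in a dict literal. `SOURCE` is an abbreviated, hand-trimmed copy of the literal quoted above (only the duplicated keys are kept), and `duplicate_dict_keys` is a hypothetical helper, not part of Transformers4Rec.

```python
import ast

# Abbreviated copy of the allowlist literal quoted above; only the
# duplicated keys are kept, in their original relative order.
SOURCE = '''
BASE_SERIALIZATION_CLASSES = {
    "builtins": ["Exception", "ValueError", "NotImplementedError",
                 "AttributeError", "AssertionError"],
    "transformers4rec.torch.model.base": ["Model", "Head", "PredictionTask"],
    "transformers4rec.torch.block.base": ["SequentialBlock"],
    "transformers4rec.torch.block.base": ["Block", "SequentialBlock"],
    "transformers4rec.torch.model.base": ["forward_to_prediction_fn", "Model", "Head"],
    "builtins": ["getattr"],
}
'''


def duplicate_dict_keys(source: str) -> dict[str, list[int]]:
    """Map each string key repeated inside one dict literal to its line numbers."""
    duplicates: dict[str, list[int]] = {}
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.Dict):
            continue
        seen: dict[str, list[int]] = {}
        for key in node.keys:
            # `**spread` entries have key None; only constant string keys matter.
            if isinstance(key, ast.Constant) and isinstance(key.value, str):
                seen.setdefault(key.value, []).append(key.lineno)
        duplicates.update({k: v for k, v in seen.items() if len(v) > 1})
    return duplicates


for key, lines in duplicate_dict_keys(SOURCE).items():
    print(f"{key!r} declared {len(lines)} times (lines {lines})")
```

Besides `builtins` and `transformers4rec.torch.model.base`, the scan also reports `transformers4rec.torch.block.base`; that duplicate is harmless because the winning list `["Block", "SequentialBlock"]` is a superset of the first. To scan the real file at `ab7207cf`, pass its contents instead of `SOURCE`.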
Evaluating the module confirms the surviving values:
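A minimal sketch of that evaluation on an abbreviated copy of the literal (duplicate keys kept on purpose to reproduce the bug). `is_allowed` is a hypothetical stand-in for the restricted `find_class` check; it assumes the check looks up the module key and then tests membership of the name, which is what the dict's shape implies, not necessarily the exact code in `serialization.py`.

```python
# Abbreviated copy of the allowlist; the duplicate keys are intentional here
# so the last-write-wins behaviour can be observed (linters will flag them).
BASE_SERIALIZATION_CLASSES = {
    "builtins": [
        "Exception", "ValueError", "NotImplementedError", "AttributeError",
        "AssertionError"
    ],
    "transformers4rec.torch.model.base": ["Model", "Head", "PredictionTask"],
    "torch.storage": ["_load_from_bytes"],
    "transformers4rec.torch.model.base": ["forward_to_prediction_fn", "Model", "Head"],
    "builtins": ["getattr"],
}


def is_allowed(module: str, name: str) -> bool:
    """Hypothetical find_class check: look up the module key, then the name."""
    return name in BASE_SERIALIZATION_CLASSES.get(module, [])


print(BASE_SERIALIZATION_CLASSES["builtins"])  # → ['getattr']
print(is_allowed("builtins", "getattr"))  # → True
print(is_allowed("builtins", "ValueError"))  # → False
print(is_allowed("transformers4rec.torch.model.base", "PredictionTask"))  # → False
print(is_allowed("torch.storage", "_load_from_bytes"))  # → True
```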
I am intentionally not attaching an end-to-end exploit payload to a public issue; the combination of `builtins.getattr` + `torch.storage._load_from_bytes` is sufficient for any reader familiar with pickle gadget chains to reproduce.

Expected behavior
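- The allowlist does not contain `builtins.getattr` or `torch.storage._load_from_bytes`.
- No module key is declared twice (pyflakes `F601`, which flags dict keys repeated with different values, or a similar lint would have caught this).
- Longer term, checkpoint loading does not depend on a pickle allowlist at all: `safetensors`, or `torch.load(..., weights_only=True)` after bumping the minimum PyTorch version (see the companion `torch.load` issue I am filing).

Environment details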
- Affected commits: `ab7207cf` (PR #802, "Sec pic fix") through `0e31f575` (PR #807, "Fix scaler"); the file was removed at `41b14d7b` (PR #808, "Fix scaler"), but affected checkpoints persist.
- PyTorch: the `_load_from_bytes` gadget exists across all modern versions.

Additional context
Suggested remediation order:
1. Remove `builtins.getattr` and `torch.storage._load_from_bytes`, and de-duplicate all keys.
2. Move checkpoints to a non-pickle format (`safetensors` is the common choice for HF-adjacent projects).
3. Call any remaining `torch.load` with an explicit `weights_only=True` (see the companion issue).

I can prepare separate PRs for each step if the maintainers prefer.
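As an illustration of step 1, here is a minimal sketch following the "Restricting Globals" pattern from the Python `pickle` documentation, not the Transformers4Rec implementation. The names `SAFE_GLOBALS`, `RestrictedUnpickler`, and `restricted_loads` are hypothetical, and the allowlist is deliberately tiny. Storing `(module, name)` pairs in a `frozenset` means a repeated entry is harmless, instead of silently replacing a whole module's list as a repeated dict key does.

```python
import io
import pickle
from collections import OrderedDict
from datetime import timedelta

# Hypothetical hardened allowlist: one (module, name) pair per entry.
# builtins.getattr and torch.storage._load_from_bytes are deliberately absent.
SAFE_GLOBALS = frozenset({
    ("collections", "OrderedDict"),
    ("datetime", "timedelta"),
    ("builtins", "ValueError"),
})


class RestrictedUnpickler(pickle.Unpickler):
    """Resolve only globals whose (module, name) pair is allowlisted."""

    def find_class(self, module: str, name: str):
        if (module, name) in SAFE_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    """Unpickle `data`, refusing any global outside SAFE_GLOBALS."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


# Allowed: an OrderedDict of timedeltas round-trips.
ok = restricted_loads(pickle.dumps(OrderedDict(step=timedelta(seconds=5))))
print(ok)

# Rejected: pickling the builtin getattr only stores a reference to
# builtins.getattr, and the unpickler refuses to resolve it (nothing runs).
try:
    restricted_loads(pickle.dumps(getattr))
except pickle.UnpicklingError as exc:
    print("rejected:", exc)
```

An allowlist only narrows the attack surface; steps 2 and 3 remain the more robust fix.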