We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c9d5b6d commit c84e924Copy full SHA for c84e924
cacheflow/models/model_utils.py
@@ -37,7 +37,11 @@
37
38
39
def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
40
- config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
+ # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
41
+ # because config.torch_dtype can be None.
42
+ config_dtype = getattr(config, 'torch_dtype', None)
43
+ if config_dtype is None:
44
+ config_dtype = torch.float32
45
if dtype == 'default':
46
if config_dtype == torch.float32:
47
# Following the common practice, we use float16 for float32 models.
0 commit comments