Commit c84e924

[Minor] Fix a dtype bug (#79)
1 parent c9d5b6d commit c84e924

File tree: 1 file changed, 5 insertions(+), 1 deletion(-)

cacheflow/models/model_utils.py

Lines changed: 5 additions, 1 deletion

```diff
@@ -37,7 +37,11 @@


 def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
-    config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
+    # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
+    # because config.torch_dtype can be None.
+    config_dtype = getattr(config, 'torch_dtype', None)
+    if config_dtype is None:
+        config_dtype = torch.float32
     if dtype == 'default':
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
```
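The bug being fixed is a subtlety of `getattr`: its default is used only when the attribute is *missing*, not when it exists with the value `None`, and Hugging Face configs can carry `torch_dtype = None`. A minimal sketch of the before/after behavior, using a hypothetical `FakeConfig` class and the string `"float32"` as a dependency-free stand-in for `torch.float32`:

```python
class FakeConfig:
    # Hypothetical stand-in for transformers.PretrainedConfig; such
    # configs may have torch_dtype explicitly set to None.
    torch_dtype = None

config = FakeConfig()

# Buggy pattern: getattr's default only applies when the attribute is
# missing, so an explicit None slips through as the "dtype".
buggy = getattr(config, 'torch_dtype', 'float32')
print(buggy)  # None

# Fixed pattern from the commit: treat None the same as absent and
# fall back to float32 explicitly.
fixed = getattr(config, 'torch_dtype', None)
if fixed is None:
    fixed = 'float32'
print(fixed)  # float32
```

With the old code, `config_dtype` could silently become `None` and break the `config_dtype == torch.float32` comparison further down; the two-step form makes the fallback unconditional.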
