diff --git a/internlm/utils/common.py b/internlm/utils/common.py index 956f8e16c..b53aa3ca2 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -5,6 +5,7 @@ import inspect import os import random +import threading from abc import ABC, abstractmethod from contextlib import contextmanager from datetime import datetime @@ -197,18 +198,27 @@ def __call__(self, batch_count): class SingletonMeta(type): """ - Singleton Meta. + Thread-safe Singleton Meta with double-checked locking. + Reference: https://en.wikipedia.org/wiki/Double-checked_locking """ _instances = {} + _lock = threading.Lock() def __call__(cls, *args, **kwargs): + # First check (without locking) for performance reasons if cls not in cls._instances: - cls._instances[cls] = super().__call__(*args, **kwargs) + # Acquire a lock before proceeding to the second check + with cls._lock: + # Second check with lock held to ensure thread safety + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance else: assert ( len(args) == 0 and len(kwargs) == 0 - ), f"{cls.__name__} is a singleton class and a instance has been created." + ), f"{cls.__name__} is a singleton class and an instance has been created." + return cls._instances[cls]