From d79b42e3e5179c18a7b66098309f02f5845b58e6 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Thu, 20 Feb 2025 16:48:22 -0500 Subject: [PATCH 1/2] add: torch function for Proxy class to automatically unproxy when dispatched --- mldaikon/proxy_wrapper/proxy.py | 16 ++++++++++++++-- mldaikon/proxy_wrapper/proxy_basics.py | 6 ++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/mldaikon/proxy_wrapper/proxy.py b/mldaikon/proxy_wrapper/proxy.py index 3b9796b7..7b95b759 100644 --- a/mldaikon/proxy_wrapper/proxy.py +++ b/mldaikon/proxy_wrapper/proxy.py @@ -23,7 +23,7 @@ from mldaikon.utils import get_timestamp_ns, typename from .dumper import json_dumper as dumper -from .proxy_basics import unproxy_arg +from .proxy_basics import unproxy_arg, unproxy_args_kwargs from .proxy_handler import PROXY_SUPPORT_OBJ_TYPES from .proxy_registry import get_global_registry from .utils import print_debug @@ -306,7 +306,7 @@ def __init__( if type(obj) is Proxy: print_debug( - "logger_proxy: " + lambda: "logger_proxy: " + f"Object '{obj.__class__.__name__}' is already a proxy" ) @@ -481,10 +481,13 @@ def __call__(self, *args, **kwargs): def __getattr__(self, name): print_debug(lambda: f"logger_proxy: Accessing attribute '{name}'") + if name == "logdir": return self.__dict__.get("logdir", None) # in order to pass down the dir if name == "_obj": return self.__dict__.get("_obj", None) # in order to pass down the dir + if name == "__torch_function__": + return Proxy._unwrapping__torch_function__ attr = getattr(self._obj, name) if self.__dict__["var_name"] == "": @@ -601,3 +604,12 @@ def print_proxy_dict(self, proxy_dict): self.print_tensor(value) else: print_debug(lambda: f"logger_proxy: {k}: {value}") + + @classmethod + def _unwrapping__torch_function__(cls, func, types, args=(), kwargs=None): + # 🚨 Ensure Proxy does not interfere with PyTorch dispatch + if kwargs is None: + kwargs = {} + + real_args, real_kwargs = unproxy_args_kwargs(args, kwargs) + return func(*real_args, **real_kwargs) diff --git a/mldaikon/proxy_wrapper/proxy_basics.py b/mldaikon/proxy_wrapper/proxy_basics.py index 8d0c047f..7fd9d30a 100644 --- a/mldaikon/proxy_wrapper/proxy_basics.py +++ b/mldaikon/proxy_wrapper/proxy_basics.py @@ -51,6 +51,12 @@ def wrapper(*args, **kwargs): return wrapper +def unproxy_args_kwargs(args, kwargs, inspect_torch_module=False): + args = [unproxy_arg(arg, inspect_torch_module) for arg in args] + kwargs = {k: unproxy_arg(v) for k, v in kwargs.items()} + return args, kwargs + + def type_handle_mldaikon_proxy(x): if hasattr(x, "is_ml_daikon_proxied_obj"): return type(x._obj) From 4286c124cf98928e3f1d67a35d7c77784c220343 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Thu, 20 Feb 2025 18:35:21 -0500 Subject: [PATCH 2/2] fix: renaming Proxy._unwrapping__torch_function__ to Proxy.__torch_function__ for compatibility with . --- mldaikon/proxy_wrapper/proxy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mldaikon/proxy_wrapper/proxy.py b/mldaikon/proxy_wrapper/proxy.py index 7b95b759..27e72a3f 100644 --- a/mldaikon/proxy_wrapper/proxy.py +++ b/mldaikon/proxy_wrapper/proxy.py @@ -487,7 +487,7 @@ def __getattr__(self, name): if name == "_obj": return self.__dict__.get("_obj", None) # in order to pass down the dir if name == "__torch_function__": - return Proxy._unwrapping__torch_function__ + return Proxy.__torch_function__ attr = getattr(self._obj, name) if self.__dict__["var_name"] == "": @@ -606,7 +606,7 @@ def print_proxy_dict(self, proxy_dict): print_debug(lambda: f"logger_proxy: {k}: {value}") @classmethod - def _unwrapping__torch_function__(cls, func, types, args=(), kwargs=None): + def __torch_function__(cls, func, types, args=(), kwargs=None): # 🚨 Ensure Proxy does not interfere with PyTorch dispatch if kwargs is None: kwargs = {}