1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from __future__ import annotations
16+
1517import argparse
1618import inspect
1719import logging
4850API_DIFF_SPEC_FN = 'dev_pr_diff_api.spec'
4951TEST_TIMEOUT = 10
5052
53+ PAT_API_SPEC_MEMBER = re .compile (r'\((paddle[^,]+)\W*document\W*([0-9a-z]{32})' )
54+ # insert ArgSpec for changing the API's type annotation can trigger the CI
55+ PAT_API_SPEC_SIGNATURE = re .compile (
56+ r'^(paddle[^,]+)\s+\((ArgSpec.*),.*document\W*([0-9a-z]{32})'
57+ )
58+
5159
5260class Result :
5361 # name/key for result
@@ -66,7 +74,7 @@ class Result:
6674 order : int = 0
6775
6876 @classmethod
69- def msg (cls , count : int , env : typing . Set ) -> str :
77+ def msg (cls , count : int , env : set ) -> str :
7078 """Message for logging with api `count` and running `env`."""
7179 raise NotImplementedError
7280
@@ -85,8 +93,8 @@ class MetaResult(type):
8593 def __new__ (
8694 mcs ,
8795 name : str ,
88- bases : typing . Tuple [type , ...],
89- namespace : typing . Dict [str , typing .Any ],
96+ bases : tuple [type , ...],
97+ namespace : dict [str , typing .Any ],
9098 ) -> type :
9199 cls = super ().__new__ (mcs , name , bases , namespace )
92100 if issubclass (cls , Result ):
@@ -104,7 +112,7 @@ def get(mcs, name: str) -> type:
104112 return mcs .__cls_map .get (name )
105113
106114 @classmethod
107- def cls_map (mcs ) -> typing . Dict [str , Result ]:
115+ def cls_map (mcs ) -> dict [str , Result ]:
108116 return mcs .__cls_map
109117
110118
@@ -290,7 +298,7 @@ def prepare(self, test_capacity: set) -> None:
290298 """
291299 pass
292300
293- def run (self , api_name : str , docstring : str ) -> typing . List [TestResult ]:
301+ def run (self , api_name : str , docstring : str ) -> list [TestResult ]:
294302 """Extract codeblocks from docstring, and run the test.
295303 Run only one docstring at a time.
296304
@@ -304,7 +312,7 @@ def run(self, api_name: str, docstring: str) -> typing.List[TestResult]:
304312 raise NotImplementedError
305313
306314 def print_summary (
307- self , test_results : typing . List [TestResult ], whl_error : typing . List [str ]
315+ self , test_results : list [TestResult ], whl_error : list [str ]
308316 ) -> None :
309317 """Post process test results and print test summary.
310318
@@ -333,17 +341,17 @@ def get_api_md5(path):
333341 API_spec = os .path .abspath (os .path .join (os .getcwd (), ".." , path ))
334342 if not os .path .isfile (API_spec ):
335343 return api_md5
336- pat = re .compile (r'\((paddle[^,]+)\W*document\W*([0-9a-z]{32})' )
337- patArgSpec = re .compile (
338- r'^(paddle[^,]+)\s+\(ArgSpec.*document\W*([0-9a-z]{32})'
339- )
344+
340345 with open (API_spec ) as f :
341346 for line in f .readlines ():
342- mo = pat .search (line )
343- if not mo :
344- mo = patArgSpec .search (line )
347+ mo = PAT_API_SPEC_MEMBER .search (line )
348+
345349 if mo :
346350 api_md5 [mo .group (1 )] = mo .group (2 )
351+ else :
352+ mo = PAT_API_SPEC_SIGNATURE .search (line )
353+ api_md5 [mo .group (1 )] = f'{ mo .group (2 )} , { mo .group (3 )} '
354+
347355 return api_md5
348356
349357
@@ -397,18 +405,6 @@ def get_full_api_from_pr_spec():
397405 get_full_api_by_walk ()
398406
399407
400- def get_full_api ():
401- """
402- get all the apis
403- """
404- global API_DIFF_SPEC_FN # readonly
405- from print_signatures import get_all_api_from_modulelist
406-
407- member_dict = get_all_api_from_modulelist ()
408- with open (API_DIFF_SPEC_FN , 'w' ) as f :
409- f .write ("\n " .join (member_dict .keys ()))
410-
411-
412408def extract_code_blocks_from_docstr (docstr , google_style = True ):
413409 """
414410 extract code-blocks from the given docstring.
@@ -599,9 +595,16 @@ def get_test_capacity(run_on_device="cpu"):
599595 return sample_code_test_capacity
600596
601597
602- def get_docstring (full_test = False ):
598+ def get_docstring (
599+ full_test : bool = False ,
600+ filter_api : typing .Callable [[str ], bool ] | None = None ,
601+ ):
603602 '''
604603 this function will get the docstring for test.
604+
605+ Args:
606+ full_test, get all api
607+ filter_api, a function that filter api, if `True` then skip add to `docstrings_to_test`.
605608 '''
606609 import paddle
607610 import paddle .static .quantization # noqa: F401
@@ -616,6 +619,9 @@ def get_docstring(full_test=False):
616619 with open (API_DIFF_SPEC_FN ) as f :
617620 for line in f .readlines ():
618621 api = line .replace ('\n ' , '' )
622+ if filter_api is not None and filter_api (api .strip ()):
623+ continue
624+
619625 try :
620626 api_obj = eval (api )
621627 except AttributeError :
@@ -637,7 +643,7 @@ def get_docstring(full_test=False):
637643 return docstrings_to_test , whl_error
638644
639645
640- def check_old_style (docstrings_to_test : typing . Dict [str , str ]):
646+ def check_old_style (docstrings_to_test : dict [str , str ]):
641647 old_style_apis = []
642648 for api_name , raw_docstring in docstrings_to_test .items ():
643649 for codeblock in extract_code_blocks_from_docstr (
@@ -715,8 +721,8 @@ def exec_gen_doc():
715721
716722
717723def get_test_results (
718- doctester : DocTester , docstrings_to_test : typing . Dict [str , str ]
719- ) -> typing . List [TestResult ]:
724+ doctester : DocTester , docstrings_to_test : dict [str , str ]
725+ ) -> list [TestResult ]:
720726 """Get test results from doctester with docstrings to test."""
721727 _test_style = (
722728 doctester .style
0 commit comments