Skip to content

Commit fd0a834

Browse files
committed
refactor: fix last commit according to review feedback
1 parent c5e51d4 commit fd0a834

File tree

3 files changed

+88
-107
lines changed

3 files changed

+88
-107
lines changed
Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import importlib.util
4+
import shutil
45
import sys
56
from pathlib import Path
67

@@ -15,30 +16,21 @@ def test_race_condition_simulation(tmp_path):
1516
1617
The test verifies that no NameError is raised for _DISTUTILS_PATCH.
1718
"""
18-
venv_path = tmp_path
19-
2019
# Create the _virtualenv.py file
21-
virtualenv_file = venv_path / "_virtualenv.py"
22-
source_file = Path(__file__).parent.parent / "src" / "virtualenv" / "create" / "via_global_ref" / "_virtualenv.py"
23-
24-
if not source_file.exists():
25-
return # Skip test if source file doesn't exist
20+
virtualenv_file = tmp_path / "_virtualenv.py"
21+
source_file = Path(__file__).parents[2] / "src" / "virtualenv" / "create" / "via_global_ref" / "_virtualenv.py"
2622

27-
content = source_file.read_text(encoding="utf-8")
28-
virtualenv_file.write_text(content, encoding="utf-8")
23+
shutil.copy(source_file, virtualenv_file)
2924

3025
# Create the _virtualenv.pth file
31-
pth_file = venv_path / "_virtualenv.pth"
26+
pth_file = tmp_path / "_virtualenv.pth"
3227
pth_file.write_text("import _virtualenv", encoding="utf-8")
3328

34-
# Simulate the race condition by alternating between importing and overwriting
29+
# Simulate the race condition by repeatedly importing
3530
errors = []
3631
for _ in range(5):
37-
# Overwrite the file
38-
virtualenv_file.write_text(content, encoding="utf-8")
39-
4032
# Try to import it
41-
sys.path.insert(0, str(venv_path))
33+
sys.path.insert(0, str(tmp_path))
4234
try:
4335
if "_virtualenv" in sys.modules:
4436
del sys.modules["_virtualenv"]
@@ -52,11 +44,7 @@ def test_race_condition_simulation(tmp_path):
5244
if "_DISTUTILS_PATCH" in str(e):
5345
errors.append(str(e))
5446
finally:
55-
if str(venv_path) in sys.path:
56-
sys.path.remove(str(venv_path))
57-
58-
# Clean up
59-
if "_virtualenv" in sys.modules:
60-
del sys.modules["_virtualenv"]
47+
if str(tmp_path) in sys.path:
48+
sys.path.remove(str(tmp_path))
6149

6250
assert not errors, f"Race condition detected: {errors}"
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import sys
2+
3+
class _Finder:
4+
fullname = None
5+
lock = []
6+
7+
def find_spec(self, fullname, path, target=None):
8+
# This should handle the NameError gracefully
9+
try:
10+
distutils_patch = _DISTUTILS_PATCH # noqa: F821
11+
except NameError:
12+
return None
13+
if fullname in distutils_patch and self.fullname is None:
14+
return None
15+
return None
16+
17+
@staticmethod
18+
def exec_module(old, module):
19+
old(module)
20+
try:
21+
distutils_patch = _DISTUTILS_PATCH # noqa: F821
22+
except NameError:
23+
return
24+
if module.__name__ in distutils_patch:
25+
pass # Would call patch_dist(module)
26+
27+
@staticmethod
28+
def load_module(old, name):
29+
module = old(name)
30+
try:
31+
distutils_patch = _DISTUTILS_PATCH # noqa: F821
32+
except NameError:
33+
return module
34+
if module.__name__ in distutils_patch:
35+
pass # Would call patch_dist(module)
36+
return module
37+
38+
finder = _Finder()

tests/unit/create/via_global_ref/test_race_condition.py

Lines changed: 41 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,97 +1,52 @@
11
from __future__ import annotations
22

33
import sys
4-
import tempfile
54
from pathlib import Path
6-
from textwrap import dedent
75

86

9-
def test_virtualenv_py_race_condition_find_spec():
7+
def test_virtualenv_py_race_condition_find_spec(tmp_path):
108
"""Test that _Finder.find_spec handles NameError gracefully when _DISTUTILS_PATCH is not defined."""
119
# Create a temporary file with partial _virtualenv.py content (simulating race condition)
12-
with tempfile.TemporaryDirectory() as tmpdir:
13-
venv_file = Path(tmpdir) / "_virtualenv_test.py"
14-
15-
# Write a partial version of _virtualenv.py that has _Finder but not _DISTUTILS_PATCH
16-
# This simulates the state during a race condition where the file is being rewritten
17-
partial_content = dedent("""
18-
import sys
19-
20-
class _Finder:
21-
fullname = None
22-
lock = []
23-
24-
def find_spec(self, fullname, path, target=None):
25-
# This should handle the NameError gracefully
26-
try:
27-
distutils_patch = _DISTUTILS_PATCH # noqa: F821
28-
except NameError:
29-
return None
30-
if fullname in distutils_patch and self.fullname is None:
31-
return None
32-
return None
33-
34-
@staticmethod
35-
def exec_module(old, module):
36-
old(module)
37-
try:
38-
distutils_patch = _DISTUTILS_PATCH # noqa: F821
39-
except NameError:
40-
return
41-
if module.__name__ in distutils_patch:
42-
pass # Would call patch_dist(module)
43-
44-
@staticmethod
45-
def load_module(old, name):
46-
module = old(name)
47-
try:
48-
distutils_patch = _DISTUTILS_PATCH # noqa: F821
49-
except NameError:
50-
return module
51-
if module.__name__ in distutils_patch:
52-
pass # Would call patch_dist(module)
53-
return module
54-
55-
finder = _Finder()
56-
""")
57-
58-
venv_file.write_text(partial_content, encoding="utf-8")
59-
60-
# Add the directory to sys.path temporarily
61-
sys.path.insert(0, tmpdir)
62-
try:
63-
# Import the module
64-
import _virtualenv_test # noqa: PLC0415
65-
66-
# Get the finder instance
67-
finder = _virtualenv_test.finder
68-
69-
# Try to call find_spec - this should not raise NameError
70-
result = finder.find_spec("distutils.dist", None)
71-
assert result is None, "find_spec should return None when _DISTUTILS_PATCH is not defined"
72-
73-
# Create a mock module object
74-
class MockModule:
75-
__name__ = "distutils.dist"
76-
77-
# Try to call exec_module - this should not raise NameError
78-
def mock_old_exec(_x):
79-
pass
80-
81-
finder.exec_module(mock_old_exec, MockModule())
82-
83-
# Try to call load_module - this should not raise NameError
84-
def mock_old_load(_name):
85-
return MockModule()
86-
87-
result = finder.load_module(mock_old_load, "distutils.dist")
88-
assert result.__name__ == "distutils.dist"
89-
90-
finally:
91-
# Clean up
92-
sys.path.remove(tmpdir)
93-
if "_virtualenv_test" in sys.modules:
94-
del sys.modules["_virtualenv_test"]
10+
venv_file = tmp_path / "_virtualenv_test.py"
11+
12+
# Write a partial version of _virtualenv.py that has _Finder but not _DISTUTILS_PATCH
13+
# This simulates the state during a race condition where the file is being rewritten
14+
helper_file = Path(__file__).parent / "_test_race_condition_helper.py"
15+
partial_content = helper_file.read_text(encoding="utf-8")
16+
17+
venv_file.write_text(partial_content, encoding="utf-8")
18+
19+
sys.path.insert(0, str(tmp_path))
20+
try:
21+
import _virtualenv_test # noqa: PLC0415
22+
23+
finder = _virtualenv_test.finder
24+
25+
# Try to call find_spec - this should not raise NameError
26+
result = finder.find_spec("distutils.dist", None)
27+
assert result is None, "find_spec should return None when _DISTUTILS_PATCH is not defined"
28+
29+
# Create a mock module object
30+
class MockModule:
31+
__name__ = "distutils.dist"
32+
33+
# Try to call exec_module - this should not raise NameError
34+
def mock_old_exec(_x):
35+
pass
36+
37+
finder.exec_module(mock_old_exec, MockModule())
38+
39+
# Try to call load_module - this should not raise NameError
40+
def mock_old_load(_name):
41+
return MockModule()
42+
43+
result = finder.load_module(mock_old_load, "distutils.dist")
44+
assert result.__name__ == "distutils.dist"
45+
46+
finally:
47+
sys.path.remove(str(tmp_path))
48+
if "_virtualenv_test" in sys.modules:
49+
del sys.modules["_virtualenv_test"]
9550

9651

9752
def test_virtualenv_py_normal_operation():

0 commit comments

Comments
 (0)