Skip to content

Commit 790c5c0

Browse files
authored
Merge commit from fork
Changes for users: - (BREAKING) `dotenv.set_key` and `dotenv.unset_key` used to follow symlinks in some situations. This is no longer the case. For that behavior to be restored in all cases, `follow_symlinks=True` should be used. - (BREAKING) In the CLI, `set` and `unset` used to follow symlinks in some situations. This is no longer the case. - (BREAKING) `dotenv.set_key`, `dotenv.unset_key` and the CLI commands `set` and `unset` used to reset the file mode of the modified .env file to `0o600` in some situations. This is no longer the case: The original mode of the file is now preserved. Is the file needed to be created or wasn't a regular file, mode `0o600` is used.
1 parent 43340da commit 790c5c0

File tree

3 files changed

+199
-15
lines changed

3 files changed

+199
-15
lines changed

src/dotenv/cli.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,13 @@ def list_values(ctx: click.Context, output_format: str) -> None:
114114
@click.argument("key", required=True)
115115
@click.argument("value", required=True)
116116
def set_value(ctx: click.Context, key: Any, value: Any) -> None:
117-
"""Store the given key/value."""
117+
"""
118+
Store the given key/value.
119+
120+
This doesn't follow symlinks, to avoid accidentally modifying a file at a
121+
potentially untrusted path.
122+
"""
123+
118124
file = ctx.obj["FILE"]
119125
quote = ctx.obj["QUOTE"]
120126
export = ctx.obj["EXPORT"]
@@ -146,7 +152,12 @@ def get(ctx: click.Context, key: Any) -> None:
146152
@click.pass_context
147153
@click.argument("key", required=True)
148154
def unset(ctx: click.Context, key: Any) -> None:
149-
"""Removes the given key."""
155+
"""
156+
Removes the given key.
157+
158+
This doesn't follow symlinks, to avoid accidentally modifying a file at a
159+
potentially untrusted path.
160+
"""
150161
file = ctx.obj["FILE"]
151162
quote = ctx.obj["QUOTE"]
152163
success, key = unset_key(file, key, quote)

src/dotenv/main.py

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import logging
33
import os
44
import pathlib
5-
import shutil
65
import stat
76
import sys
87
import tempfile
@@ -14,9 +13,7 @@
1413
from .variables import parse_variables
1514

1615
# A type alias for a string path to be used for the paths in this file.
17-
# These paths may flow to `open()` and `shutil.move()`; `shutil.move()`
18-
# only accepts string paths, not byte paths or file descriptors. See
19-
# https://github.com/python/typeshed/pull/6832.
16+
# These paths may flow to `open()` and `os.replace()`.
2017
StrPath = Union[str, "os.PathLike[str]"]
2118

2219
logger = logging.getLogger(__name__)
@@ -142,21 +139,54 @@ def get_key(
142139
def rewrite(
143140
path: StrPath,
144141
encoding: Optional[str],
142+
follow_symlinks: bool = False,
145143
) -> Iterator[Tuple[IO[str], IO[str]]]:
146-
pathlib.Path(path).touch()
144+
if follow_symlinks:
145+
path = os.path.realpath(path)
147146

148-
with tempfile.NamedTemporaryFile(mode="w", encoding=encoding, delete=False) as dest:
147+
try:
148+
source: IO[str] = open(path, encoding=encoding)
149+
try:
150+
path_stat = os.lstat(path)
151+
original_mode: Optional[int] = (
152+
stat.S_IMODE(path_stat.st_mode)
153+
if stat.S_ISREG(path_stat.st_mode)
154+
else None
155+
)
156+
except BaseException:
157+
source.close()
158+
raise
159+
except FileNotFoundError:
160+
source = io.StringIO("")
161+
original_mode = None
162+
163+
with tempfile.NamedTemporaryFile(
164+
mode="w",
165+
encoding=encoding,
166+
delete=False,
167+
prefix=".tmp_",
168+
dir=os.path.dirname(os.path.abspath(path)),
169+
) as dest:
170+
dest_path = pathlib.Path(dest.name)
149171
error = None
172+
150173
try:
151-
with open(path, encoding=encoding) as source:
174+
with source:
152175
yield (source, dest)
153176
except BaseException as err:
154177
error = err
155178

156179
if error is None:
157-
shutil.move(dest.name, path)
180+
try:
181+
if original_mode is not None:
182+
os.chmod(dest_path, original_mode)
183+
184+
os.replace(dest_path, path)
185+
except BaseException:
186+
dest_path.unlink(missing_ok=True)
187+
raise
158188
else:
159-
os.unlink(dest.name)
189+
dest_path.unlink(missing_ok=True)
160190
raise error from None
161191

162192

@@ -167,12 +197,16 @@ def set_key(
167197
quote_mode: str = "always",
168198
export: bool = False,
169199
encoding: Optional[str] = "utf-8",
200+
follow_symlinks: bool = False,
170201
) -> Tuple[Optional[bool], str, str]:
171202
"""
172203
Adds or Updates a key/value to the given .env
173204
174-
If the .env path given doesn't exist, fails instead of risking creating
175-
an orphan .env somewhere in the filesystem
205+
The target .env file is created if it doesn't exist.
206+
207+
This function doesn't follow symlinks by default, to avoid accidentally
208+
modifying a file at a potentially untrusted path. If you don't need this
209+
protection and need symlinks to be followed, use `follow_symlinks`.
176210
"""
177211
if quote_mode not in ("always", "auto", "never"):
178212
raise ValueError(f"Unknown quote_mode: {quote_mode}")
@@ -190,7 +224,10 @@ def set_key(
190224
else:
191225
line_out = f"{key_to_set}={value_out}\n"
192226

193-
with rewrite(dotenv_path, encoding=encoding) as (source, dest):
227+
with rewrite(dotenv_path, encoding=encoding, follow_symlinks=follow_symlinks) as (
228+
source,
229+
dest,
230+
):
194231
replaced = False
195232
missing_newline = False
196233
for mapping in with_warn_for_invalid_lines(parse_stream(source)):
@@ -213,19 +250,27 @@ def unset_key(
213250
key_to_unset: str,
214251
quote_mode: str = "always",
215252
encoding: Optional[str] = "utf-8",
253+
follow_symlinks: bool = False,
216254
) -> Tuple[Optional[bool], str]:
217255
"""
218256
Removes a given key from the given `.env` file.
219257
220258
If the .env path given doesn't exist, fails.
221259
If the given key doesn't exist in the .env, fails.
260+
261+
This function doesn't follow symlinks by default, to avoid accidentally
262+
modifying a file at a potentially untrusted path. If you don't need this
263+
protection and need symlinks to be followed, use `follow_symlinks`.
222264
"""
223265
if not os.path.exists(dotenv_path):
224266
logger.warning("Can't delete from %s - it doesn't exist.", dotenv_path)
225267
return None, key_to_unset
226268

227269
removed = False
228-
with rewrite(dotenv_path, encoding=encoding) as (source, dest):
270+
with rewrite(dotenv_path, encoding=encoding, follow_symlinks=follow_symlinks) as (
271+
source,
272+
dest,
273+
):
229274
for mapping in with_warn_for_invalid_lines(parse_stream(source)):
230275
if mapping.key == key_to_unset:
231276
removed = True

tests/test_main.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,86 @@ def test_set_key_encoding(dotenv_path):
6262
assert dotenv_path.read_text(encoding=encoding) == "a='é'\n"
6363

6464

65+
@pytest.mark.skipif(
66+
sys.platform == "win32", reason="file mode bits behave differently on Windows"
67+
)
68+
def test_set_key_preserves_file_mode(dotenv_path):
69+
dotenv_path.write_text("a=x\n")
70+
dotenv_path.chmod(0o640)
71+
mode_before = stat.S_IMODE(dotenv_path.stat().st_mode)
72+
73+
dotenv.set_key(dotenv_path, "a", "y")
74+
75+
mode_after = stat.S_IMODE(dotenv_path.stat().st_mode)
76+
assert mode_before == mode_after
77+
78+
79+
def test_rewrite_closes_file_handle_on_lstat_failure(tmp_path):
80+
dotenv_path = tmp_path / ".env"
81+
dotenv_path.write_text("a=x\n")
82+
real_open = open
83+
opened_handles = []
84+
85+
def tracking_open(*args, **kwargs):
86+
handle = real_open(*args, **kwargs)
87+
opened_handles.append(handle)
88+
return handle
89+
90+
with mock.patch("dotenv.main.os.lstat", side_effect=FileNotFoundError):
91+
with mock.patch("dotenv.main.open", side_effect=tracking_open):
92+
dotenv.set_key(dotenv_path, "a", "x")
93+
94+
assert opened_handles, "expected at least one file to be opened"
95+
assert all(handle.closed for handle in opened_handles)
96+
97+
98+
@pytest.mark.skipif(
99+
sys.platform == "win32", reason="symlinks require elevated privileges on Windows"
100+
)
101+
def test_set_key_symlink_to_existing_file(tmp_path):
102+
target = tmp_path / "target.env"
103+
target.write_text("a=x\n")
104+
symlink = tmp_path / ".env"
105+
symlink.symlink_to(target)
106+
107+
dotenv.set_key(symlink, "a", "y")
108+
109+
assert target.read_text() == "a=x\n"
110+
assert not symlink.is_symlink()
111+
assert "a='y'" in symlink.read_text()
112+
assert stat.S_IMODE(symlink.stat().st_mode) == 0o600
113+
114+
115+
@pytest.mark.skipif(
116+
sys.platform == "win32", reason="symlinks require elevated privileges on Windows"
117+
)
118+
def test_set_key_symlink_to_missing_file(tmp_path):
119+
target = tmp_path / "nx"
120+
symlink = tmp_path / ".env"
121+
symlink.symlink_to(target)
122+
123+
dotenv.set_key(symlink, "a", "x")
124+
125+
assert not target.exists()
126+
assert not symlink.is_symlink()
127+
assert symlink.read_text() == "a='x'\n"
128+
129+
130+
@pytest.mark.skipif(
131+
sys.platform == "win32", reason="symlinks require elevated privileges on Windows"
132+
)
133+
def test_set_key_follow_symlinks(tmp_path):
134+
target = tmp_path / "target.env"
135+
target.write_text("a=x\n")
136+
symlink = tmp_path / ".env"
137+
symlink.symlink_to(target)
138+
139+
dotenv.set_key(symlink, "a", "y", follow_symlinks=True)
140+
141+
assert target.read_text() == "a='y'\n"
142+
assert symlink.is_symlink()
143+
144+
65145
@pytest.mark.skipif(
66146
sys.platform != "win32" and os.geteuid() == 0,
67147
reason="Root user can access files even with 000 permissions.",
@@ -195,6 +275,54 @@ def test_unset_non_existent_file(tmp_path):
195275
)
196276

197277

278+
@pytest.mark.skipif(
279+
sys.platform == "win32", reason="symlinks require elevated privileges on Windows"
280+
)
281+
def test_unset_key_symlink_to_existing_file(tmp_path):
282+
target = tmp_path / "target.env"
283+
target.write_text("a=x\n")
284+
symlink = tmp_path / ".env"
285+
symlink.symlink_to(target)
286+
287+
dotenv.unset_key(symlink, "a")
288+
289+
assert target.read_text() == "a=x\n"
290+
assert not symlink.is_symlink()
291+
assert symlink.read_text() == ""
292+
293+
294+
@pytest.mark.skipif(
295+
sys.platform == "win32", reason="symlinks require elevated privileges on Windows"
296+
)
297+
def test_unset_key_symlink_to_missing_file(tmp_path):
298+
target = tmp_path / "nx"
299+
symlink = tmp_path / ".env"
300+
symlink.symlink_to(target)
301+
logger = logging.getLogger("dotenv.main")
302+
303+
with mock.patch.object(logger, "warning") as mock_warning:
304+
result = dotenv.unset_key(symlink, "a")
305+
306+
assert result == (None, "a")
307+
assert symlink.is_symlink()
308+
mock_warning.assert_called_once()
309+
310+
311+
@pytest.mark.skipif(
312+
sys.platform == "win32", reason="symlinks require elevated privileges on Windows"
313+
)
314+
def test_unset_key_follow_symlinks(tmp_path):
315+
target = tmp_path / "target.env"
316+
target.write_text("a=b\n")
317+
symlink = tmp_path / ".env"
318+
symlink.symlink_to(target)
319+
320+
dotenv.unset_key(symlink, "a", follow_symlinks=True)
321+
322+
assert target.read_text() == ""
323+
assert symlink.is_symlink()
324+
325+
198326
def prepare_file_hierarchy(path):
199327
"""
200328
Create a temporary folder structure like the following:

0 commit comments

Comments
 (0)