Skip to content

Commit 54bb70e

Browse files
authored
refactor EnvironBuilder file handling and related code (#3101)
2 parents 7055130 + 4f15773 commit 54bb70e

File tree

12 files changed

+363
-439
lines changed

12 files changed

+363
-439
lines changed

CHANGES.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ Unreleased
2626
part of the HTTP spec. :pr:`3092`
2727
- ``EnvironBuilder.close`` closes all open files in ``files`` rather than only
2828
the first for each key. :pr:`3092`
29+
- ``EnvironBuilder`` can be used as a ``with`` context manager. :pr:`3101`
30+
- ``EnvironBuilder.files.add_file`` will detect the filename when passed an
31+
IO object. :pr:`3101`
32+
- Added the ``EnvironBuilder.files.close`` method to close all files.
33+
``EnvironBuilder.files.clear`` will call ``close``. :pr:`3101`
2934
- ``Map`` takes a ``subdomain_matching`` parameter to disable subdomain
3035
matching. In ``bind_to_environ``, the ``server_name`` parameter is not used
3136
if ``host_matching`` is enabled. If ``default_subdomain`` is set, it is used
@@ -34,6 +39,10 @@ Unreleased
3439
validated against ``request.trusted_hosts``. An invalid host will raise a
3540
400 error. :issue:`3007`
3641
- Watchdog reloader is more efficient at ignoring events. :issue:`3090`
42+
- If multipart parsing fails after some files have already been parsed, they
43+
are closed to prevent a ``ResourceWarning``. :pr:`3101`
44+
- ``SpooledTemporaryFile`` is always used for multipart file parsing.
45+
:pr:`3101`
3746

3847

3948
Version 3.1.5

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,6 @@ default-groups = ["dev", "pre-commit", "tests", "typing"]
9494
testpaths = ["tests"]
9595
filterwarnings = [
9696
"error",
97-
# TODO fix these
98-
"always::pytest.PytestUnraisableExceptionWarning",
9997
]
10098
markers = ["dev_server: tests that start the dev server"]
10199

src/werkzeug/datastructures/file_storage.py

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,7 @@ def __init__(
3232
):
3333
self.name = name
3434
self.stream = stream or BytesIO()
35-
36-
# If no filename is provided, attempt to get the filename from
37-
# the stream object. Python names special streams like
38-
# ``<stderr>`` with angular brackets, skip these streams.
39-
if filename is None:
40-
filename = getattr(stream, "name", None)
41-
42-
if filename is not None:
43-
filename = fsdecode(filename)
44-
45-
if filename and filename[0] == "<" and filename[-1] == ">":
46-
filename = None
47-
else:
48-
filename = fsdecode(filename)
49-
50-
self.filename = filename
35+
self.filename = _guess_filename(self.stream, filename)
5136

5237
if headers is None:
5338
headers = Headers()
@@ -163,9 +148,8 @@ def __repr__(self) -> str:
163148

164149

165150
class FileMultiDict(MultiDict[str, FileStorage]):
166-
"""A special :class:`MultiDict` that has convenience methods to add
167-
files to it. This is used for :class:`EnvironBuilder` and generally
168-
useful for unittesting.
151+
"""A :class:`MultiDict` for managing form data file values. Used by
152+
:class:`.EnvironBuilder` for tests.
169153
170154
.. versionadded:: 0.5
171155
"""
@@ -177,13 +161,20 @@ def add_file(
177161
filename: str | None = None,
178162
content_type: str | None = None,
179163
) -> None:
180-
"""Adds a new file to the dict. `file` can be a file name or
181-
a :class:`file`-like or a :class:`FileStorage` object.
182-
183-
:param name: the name of the field.
184-
:param file: a filename or :class:`file`-like object
185-
:param filename: an optional filename
186-
:param content_type: an optional content type
164+
"""Add a file to the given key. Can be passed a filename or IO object,
165+
which will construct a :class:`.FileStorage` object.
166+
167+
:param name: The key to add the file to.
168+
:param file: The file to add. Constructs a :class:`FileStorage` object
169+
if the value is not one.
170+
:param filename: The filename to set for the field. Defaults to ``file``
171+
if it's a filename or ``file.name`` if it's an IO object.
172+
:param content_type: The content type to set for the field. Defaults to
173+
guessing based on the filename, falling back to
174+
``application/octet-stream``.
175+
176+
.. versionchanged:: 3.2
177+
The filename is detected from an IO object.
187178
"""
188179
if isinstance(file, FileStorage):
189180
self.add(name, file)
@@ -196,8 +187,9 @@ def add_file(
196187
file_obj: t.IO[bytes] = open(file, "rb")
197188
else:
198189
file_obj = file # type: ignore[assignment]
190+
filename = _guess_filename(file_obj, filename)
199191

200-
if filename and content_type is None:
192+
if filename is not None and content_type is None:
201193
content_type = (
202194
mimetypes.guess_type(filename)[0] or "application/octet-stream"
203195
)
@@ -214,6 +206,30 @@ def close(self) -> None:
214206
if not value.closed:
215207
value.close()
216208

209+
def clear(self) -> None:
210+
"""Call :meth:`close`, then remove all items.
211+
212+
.. versionadded:: 3.2
213+
"""
214+
self.close()
215+
super().clear()
216+
217+
218+
def _guess_filename(stream: t.IO[t.Any], filename: str | None) -> str | None:
219+
if filename is not None:
220+
return fsdecode(filename)
221+
222+
filename = getattr(stream, "name", None)
223+
224+
if filename is not None:
225+
filename = fsdecode(filename)
226+
227+
# Python names special streams like `<stderr>`, ignore these.
228+
if filename[:1] == "<" and filename[-1:] == ">":
229+
filename = None
230+
231+
return filename
232+
217233

218234
# circular dependencies
219235
from .. import http # noqa: E402

src/werkzeug/formparser.py

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
import typing as t
4-
from io import BytesIO
4+
from tempfile import SpooledTemporaryFile
5+
from types import TracebackType
56
from urllib.parse import parse_qsl
67

78
from ._internal import _plain_int
@@ -19,18 +20,8 @@
1920
from .wsgi import get_content_length
2021
from .wsgi import get_input_stream
2122

22-
# there are some platforms where SpooledTemporaryFile is not available.
23-
# In that case we need to provide a fallback.
24-
try:
25-
from tempfile import SpooledTemporaryFile
26-
except ImportError:
27-
from tempfile import TemporaryFile
28-
29-
SpooledTemporaryFile = None # type: ignore
30-
3123
if t.TYPE_CHECKING:
32-
import typing as te
33-
24+
import typing_extensions as te
3425
from _typeshed.wsgi import WSGIEnvironment
3526

3627
t_parse_result = tuple[
@@ -56,14 +47,7 @@ def default_stream_factory(
5647
filename: str | None,
5748
content_length: int | None = None,
5849
) -> t.IO[bytes]:
59-
max_size = 1024 * 500
60-
61-
if SpooledTemporaryFile is not None:
62-
return t.cast(t.IO[bytes], SpooledTemporaryFile(max_size=max_size, mode="rb+"))
63-
elif total_content_length is None or total_content_length > max_size:
64-
return t.cast(t.IO[bytes], TemporaryFile("rb+"))
65-
66-
return BytesIO()
50+
return SpooledTemporaryFile(max_size=1024 * 500, mode="rb+")
6751

6852

6953
def parse_form_data(
@@ -253,18 +237,19 @@ def _parse_multipart(
253237
content_length: int | None,
254238
options: dict[str, str],
255239
) -> t_parse_result:
256-
parser = MultiPartParser(
257-
stream_factory=self.stream_factory,
258-
max_form_memory_size=self.max_form_memory_size,
259-
max_form_parts=self.max_form_parts,
260-
cls=self.cls,
261-
)
262240
boundary = options.get("boundary", "").encode("ascii")
263241

264242
if not boundary:
265243
raise ValueError("Missing boundary")
266244

267-
form, files = parser.parse(stream, boundary, content_length)
245+
with MultiPartParser(
246+
stream_factory=self.stream_factory,
247+
max_form_memory_size=self.max_form_memory_size,
248+
max_form_parts=self.max_form_parts,
249+
cls=self.cls,
250+
) as parser:
251+
form, files = parser.parse(stream, boundary, content_length)
252+
268253
return stream, form, files
269254

270255
def _parse_urlencoded(
@@ -305,15 +290,26 @@ def __init__(
305290
stream_factory = default_stream_factory
306291

307292
self.stream_factory = stream_factory
293+
self._files: list[t.IO[bytes]] = []
308294

309295
if cls is None:
310296
cls = t.cast("type[MultiDict[str, t.Any]]", MultiDict)
311297

312298
self.cls = cls
313299
self.buffer_size = buffer_size
314300

315-
def fail(self, message: str) -> te.NoReturn:
316-
raise ValueError(message)
301+
def __enter__(self) -> te.Self:
302+
return self
303+
304+
def __exit__(
305+
self,
306+
exc_type: type[BaseException] | None,
307+
exc_val: BaseException | None,
308+
exc_tb: TracebackType | None,
309+
) -> None:
310+
if exc_val is not None:
311+
for file in self._files:
312+
file.close()
317313

318314
def get_part_charset(self, headers: Headers) -> str:
319315
# Figure out input charset for current part
@@ -346,6 +342,7 @@ def start_file_streaming(
346342
content_type=content_type,
347343
content_length=content_length,
348344
)
345+
self._files.append(container)
349346
return container
350347

351348
def parse(

0 commit comments

Comments
 (0)