Skip to content

Commit b3a6571

Browse files
ejguan, facebook-github-bot
authored and committed
Properly cleanup unclosed files within generator function (#910)
Summary: There is a case where `file.close` is never called: when a generator function never reaches the end. A simple example would be `zip`-ping two datapipes with different lengths. The longer DataPipe would never reach the end of its generator; it will instead be cleaned up by `gc`, so the line calling `file.close` is never executed. (This is the reason that Vitaly had to create this [hack](https://github.com/pytorch/pytorch/blob/4451eb24e6287dff62ff8a7ec0eda6a6998807b0/torch/utils/data/datapipes/iter/combining.py#L573-L583) to retrieve all remaining data, ensuring the generator function is fully executed.) However, this hack introduces another problem: an infinite datapipe would make `zip` never end, since it would try to deplete the infinite iterator. See: #865 So, in this PR, I am adding a `try-finally` clause to make sure `file.close` is always executed during the destruction of the `generator` object. Then, we no longer need the hack within `zip`. - pytorch/pytorch#89973 - pytorch/vision#6997 - pytorch/pytorch#89974 Pull Request resolved: #910 Reviewed By: wenleix, NivekT Differential Revision: D41633909 Pulled By: ejguan fbshipit-source-id: 5613e131dc8b2962c22d2bc7e3a4bb3039440c48
1 parent 45bc070 commit b3a6571

4 files changed

Lines changed: 42 additions & 36 deletions

File tree

torchdata/datapipes/iter/util/combining.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,12 @@ def __iter__(self) -> Iterator:
115115
else:
116116
yield res
117117
finally:
118-
for remaining in ref_it:
119-
janitor(remaining)
120-
118+
del ref_it
121119
# TODO(633): This should be Exception or warn when debug mode is enabled
122-
if len(self.buffer) > 0:
120+
if self.buffer:
123121
for _, v in self.buffer.items():
124122
janitor(v)
123+
self.buffer.clear()
125124

126125
def __len__(self) -> int:
127126
return len(self.source_datapipe)
@@ -156,7 +155,10 @@ def __setstate__(self, state):
156155
self.buffer = OrderedDict()
157156

158157
def __del__(self):
159-
self.buffer.clear()
158+
if self.buffer:
159+
for _, v in self.buffer.items():
160+
janitor(v)
161+
self.buffer.clear()
160162

161163

162164
@functional_datapipe("zip_with_map")

torchdata/datapipes/iter/util/decompressor.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,13 @@ def _detect_compression_type(self, path: str) -> CompressionType:
9595

9696
def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
9797
for path, file in self.source_datapipe:
98-
file_type = self._detect_compression_type(path)
99-
decompressor = self._DECOMPRESSORS[file_type]
100-
yield path, StreamWrapper(decompressor(file), file, name=path)
101-
if isinstance(file, StreamWrapper):
102-
file.autoclose()
98+
try:
99+
file_type = self._detect_compression_type(path)
100+
decompressor = self._DECOMPRESSORS[file_type]
101+
yield path, StreamWrapper(decompressor(file), file, name=path)
102+
finally:
103+
if isinstance(file, StreamWrapper):
104+
file.autoclose()
103105

104106

105107
@functional_datapipe("extract")

torchdata/datapipes/iter/util/plain_text_reader.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ def skip_lines(self, file: IO) -> Union[Iterator[bytes], Iterator[str]]:
4141
with contextlib.suppress(StopIteration):
4242
for _ in range(self._skip_lines):
4343
next(file)
44-
yield from file
45-
file.close()
44+
try:
45+
yield from file
46+
finally:
47+
file.close()
4648

4749
def strip_newline(self, stream: Union[Iterator[bytes], Iterator[str]]) -> Union[Iterator[bytes], Iterator[str]]:
4850
if not self._strip_newline:
@@ -58,10 +60,9 @@ def strip_newline(self, stream: Union[Iterator[bytes], Iterator[str]]) -> Union[
5860
def decode(self, stream: Union[Iterator[bytes], Iterator[str]]) -> Union[Iterator[bytes], Iterator[str]]:
5961
if not self._decode:
6062
yield from stream
61-
return
62-
63-
for line in stream:
64-
yield line.decode(self._encoding, self._errors) if isinstance(line, bytes) else line
63+
else:
64+
for line in stream:
65+
yield line.decode(self._encoding, self._errors) if isinstance(line, bytes) else line
6566

6667
def return_path(self, stream: Iterator[D], *, path: str) -> Iterator[Union[D, Tuple[str, D]]]:
6768
if not self._return_path:

torchdata/datapipes/iter/util/rararchiveloader.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -90,26 +90,27 @@ def __iter__(self) -> Iterator[Tuple[str, io.BufferedIOBase]]:
9090
_PATCHED = True
9191

9292
for data in self.datapipe:
93-
validate_pathname_binary_tuple(data)
94-
path, stream = data
95-
if isinstance(stream, rarfile.RarExtFile) or (
96-
isinstance(stream, StreamWrapper) and isinstance(stream.file_obj, rarfile.RarExtFile)
97-
):
98-
raise ValueError(
99-
f"Nested RAR archive is not supported by {type(self).__name__}. Please extract outer archive first."
100-
)
101-
102-
rar = rarfile.RarFile(stream)
103-
for info in rar.infolist():
104-
if info.is_dir():
105-
continue
106-
107-
inner_path = os.path.join(path, info.filename)
108-
file_obj = rar.open(info)
109-
110-
yield inner_path, StreamWrapper(file_obj, stream, name=path) # type: ignore[misc]
111-
if isinstance(stream, StreamWrapper):
112-
stream.autoclose()
93+
try:
94+
validate_pathname_binary_tuple(data)
95+
path, stream = data
96+
if isinstance(stream, rarfile.RarExtFile) or (
97+
isinstance(stream, StreamWrapper) and isinstance(stream.file_obj, rarfile.RarExtFile)
98+
):
99+
raise ValueError(
100+
f"Nested RAR archive is not supported by {type(self).__name__}. Please extract outer archive first."
101+
)
102+
103+
rar = rarfile.RarFile(stream)
104+
for info in rar.infolist():
105+
if info.is_dir():
106+
continue
107+
108+
inner_path = os.path.join(path, info.filename)
109+
file_obj = rar.open(info)
110+
yield inner_path, StreamWrapper(file_obj, stream, name=path) # type: ignore[misc]
111+
finally:
112+
if isinstance(stream, StreamWrapper):
113+
stream.autoclose()
113114

114115
def __len__(self) -> int:
115116
if self.length == -1:

0 commit comments

Comments
 (0)