Skip to content

Commit a9e4939

Browse files
authored
feat: create and enforce blocked file extension list for custom outputs (#377)
* feat: create and enforce blocked file extension list for custom outputs * test: add coverage for blocked extension checks in plot outputs * refactor: centralise blocked extension check into utils helper * style: collapse Records.__init__ docstring to satisfy codacy --------- Signed-off-by: Hasaan A. <hasaana2005@gmail.com>
1 parent c5e8c89 commit a9e4939

9 files changed

Lines changed: 113 additions & 12 deletions

acro/acro.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,14 @@ def __init__(self, config: str = "default", suppress: bool = False) -> None:
6060
Tables.__init__(self, suppress)
6161
Regression.__init__(self, config)
6262
self.config: dict[str, Any] = {}
63-
self.results: Records = Records()
6463
self.suppress: bool = suppress
6564
path: pathlib.Path = pathlib.Path(__file__).with_name(config + ".yaml")
6665
logger.debug("path: %s", path)
6766
with open(path, encoding="utf-8") as handle:
6867
self.config = yaml.load(handle, Loader=yaml.loader.SafeLoader)
68+
self.results: Records = Records(
69+
blocked_extensions=self.config.get("blocked_extensions", [])
70+
)
6971
logger.info("version: %s", __version__)
7072
logger.info("config: %s", self.config)
7173
logger.info("automatic suppression: %s", self.suppress)
@@ -138,7 +140,7 @@ def print_outputs(self) -> str:
138140
"""
139141
return self.results.print()
140142

141-
def custom_output(self, filename: str, comment: str = "") -> None:
143+
def custom_output(self, filename: str, comment: str = "") -> bool:
142144
"""Add an unsupported output to the results dictionary.
143145
144146
Parameters
@@ -147,8 +149,13 @@ def custom_output(self, filename: str, comment: str = "") -> None:
147149
The name of the file that will be added to the list of the outputs.
148150
comment : str
149151
An optional comment.
152+
153+
Returns
154+
-------
155+
bool
156+
False if the file extension is blocked, True otherwise.
150157
"""
151-
self.results.add_custom(filename, comment)
158+
return self.results.add_custom(filename, comment)
152159

153160
def rename_output(self, old: str, new: str) -> None:
154161
"""Rename an output.

acro/acro_stata_parser.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -376,12 +376,10 @@ def add_custom_output(varlist: list[str]) -> str:
376376
except IndexError:
377377
return "syntax error: please pass the name of the output to be added"
378378

379-
# .gph extension contain data
380-
_, file_extension = os.path.splitext(the_output)
381-
if file_extension == ".gph":
382-
return "Warning: .gph files may not be exported as they contain data."
383379
comment_str = " ".join(varlist)
384-
stata_config.stata_acro.custom_output(the_output, comment_str)
380+
if not stata_config.stata_acro.custom_output(the_output, comment_str):
381+
_, ext = os.path.splitext(the_output)
382+
return f"Warning: {ext} files are not allowed and cannot be exported."
385383
outcome = f"file {the_output} with comment {comment_str} added to session."
386384
return outcome
387385

acro/acro_tables.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,8 @@ def survival_plot(
578578
summary: str,
579579
) -> tuple[Any, str] | None:
580580
"""Create the survival plot according to the status of suppressing."""
581+
if utils.is_blocked_extension(filename, self.results.blocked_extensions):
582+
return None
581583
if self.suppress:
582584
survival_table = _rounded_survival_table(survival_table)
583585
plot = survival_table.plot(y="rounded_survival_fun", xlim=0, ylim=0)
@@ -703,6 +705,8 @@ def hist(
703705
The name of the file where the histogram is saved.
704706
"""
705707
logger.debug("hist()")
708+
if utils.is_blocked_extension(filename, self.results.blocked_extensions):
709+
return None
706710
command: str = utils.get_command("hist()", stack())
707711

708712
if isinstance(data, list): # pragma: no cover
@@ -848,6 +852,8 @@ def pie(
848852
The path to the saved pie chart file.
849853
"""
850854
logger.debug("pie()")
855+
if utils.is_blocked_extension(filename, self.results.blocked_extensions):
856+
return None
851857
command: str = utils.get_command("pie()", stack())
852858

853859
# COMPUTE PRE-CATEGORY COUNTS

acro/default.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,13 @@ survival_safe_threshold: 10
2929

3030
# consider zeros to be disclosive
3131
zeros_are_disclosive: True
32+
33+
################################################################################
34+
# Blocked file extensions #
35+
################################################################################
36+
# File extensions that are not allowed in custom outputs or plots.
37+
# Extensions are case-insensitive and must include the leading dot.
38+
blocked_extensions:
39+
- .svg
40+
- .gph
3241
...

acro/record.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pandas import DataFrame
1616

1717
from .constants import ARTIFACTS_DIR
18+
from .utils import is_blocked_extension
1819
from .version import __version__
1920

2021
logger = logging.getLogger("acro:records")
@@ -210,10 +211,13 @@ def __str__(self) -> str:
210211
class Records:
211212
"""Stores data related to a collection of output records."""
212213

213-
def __init__(self) -> None:
214+
def __init__(self, blocked_extensions: list[str] | None = None) -> None:
214215
"""Construct a new object for storing multiple records."""
215216
self.results: dict[str, Record] = {}
216217
self.output_id: int = 0
218+
self.blocked_extensions: list[str] = [
219+
ext.lower() for ext in (blocked_extensions or [])
220+
]
217221

218222
def add(
219223
self,
@@ -323,7 +327,7 @@ def get_index(self, index: int) -> Record:
323327
key = list(self.results.keys())[index]
324328
return self.results[key]
325329

326-
def add_custom(self, filename: str, comment: str | None = None) -> None:
330+
def add_custom(self, filename: str, comment: str | None = None) -> bool:
327331
"""Add an unsupported output to the results dictionary.
328332
329333
Parameters
@@ -332,7 +336,14 @@ def add_custom(self, filename: str, comment: str | None = None) -> None:
332336
The name of the file that will be added to the list of the outputs.
333337
comment : str | None, default None
334338
An optional comment.
339+
340+
Returns
341+
-------
342+
bool
343+
False if the file extension is blocked, True otherwise.
335344
"""
345+
if is_blocked_extension(filename, self.blocked_extensions):
346+
return False
336347
if os.path.exists(filename):
337348
output = Record(
338349
uid=f"output_{self.output_id}",
@@ -353,6 +364,7 @@ def add_custom(self, filename: str, comment: str | None = None) -> None:
353364
logger.info(
354365
"WARNING: Unable to add %s because the file does not exist", filename
355366
) # pragma: no cover
367+
return True
356368

357369
def rename(self, old: str, new: str) -> None:
358370
"""Rename an output.

acro/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,27 @@
33
from __future__ import annotations
44

55
import logging
6+
import os
67
from inspect import FrameInfo, getframeinfo
78

89
import pandas as pd
910

1011
logger = logging.getLogger("acro")
1112

1213

14+
def is_blocked_extension(filename: str, blocked_extensions: list[str]) -> bool:
15+
"""Return True and log a warning if the file's extension is blocked."""
16+
_, ext = os.path.splitext(filename)
17+
if ext.lower() in blocked_extensions:
18+
logger.warning(
19+
"Blocked file extension %s. Files with extension %s are not allowed.",
20+
filename,
21+
ext,
22+
)
23+
return True
24+
return False
25+
26+
1327
def get_command(default: str, stack_list: list[FrameInfo]) -> str:
1428
"""Return the calling source line as a string.
1529

test/test_initial.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,61 @@ def test_custom_output(acro):
506506
shutil.rmtree(PATH)
507507

508508

509+
def test_blocked_extension(acro, tmp_path):
510+
"""Test that blocked file extensions are rejected in custom output."""
511+
# create temporary files with blocked extensions
512+
svg_file = tmp_path / "test.svg"
513+
svg_file.write_text("<svg></svg>")
514+
gph_file = tmp_path / "test.gph"
515+
gph_file.write_text("data")
516+
517+
# blocked extensions should be rejected
518+
acro.custom_output(str(svg_file))
519+
acro.custom_output(str(gph_file))
520+
assert len(acro.results.results) == 0
521+
522+
# allowed extensions should be accepted
523+
txt_file = tmp_path / "test.txt"
524+
txt_file.write_text("hello")
525+
acro.custom_output(str(txt_file))
526+
assert len(acro.results.results) == 1
527+
528+
# case-insensitive check
529+
svg_upper = tmp_path / "test.SVG"
530+
svg_upper.write_text("<svg></svg>")
531+
acro.custom_output(str(svg_upper))
532+
assert len(acro.results.results) == 1
533+
534+
535+
def test_blocked_extension_hist(data, acro):
536+
"""Test that blocked file extensions are rejected for histograms."""
537+
result = acro.hist(data, "inc_grants", bins=1, filename="hist.svg")
538+
assert result is None
539+
assert len(acro.results.results) == 0
540+
541+
542+
def test_blocked_extension_pie(data, acro):
543+
"""Test that blocked file extensions are rejected for pie charts."""
544+
result = acro.pie(data, "grant_type", filename="pie.svg")
545+
assert result is None
546+
assert len(acro.results.results) == 0
547+
548+
549+
def test_blocked_extension_survival(acro):
550+
"""Test that blocked file extensions are rejected for survival plots."""
551+
result = acro.survival_plot(
552+
survival_table=pd.DataFrame(),
553+
survival_func=None,
554+
filename="surv.svg",
555+
status="pass",
556+
sdc={},
557+
command="test",
558+
summary="test",
559+
)
560+
assert result is None
561+
assert len(acro.results.results) == 0
562+
563+
509564
def test_missing(data, acro, monkeypatch):
510565
"""Pivot table and Crosstab with negative values."""
511566
acro_tables.CHECK_MISSING_VALUES = True

test/test_stata17_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def test_stata_custom_output_invalid():
472472
options="nototals",
473473
stata_version="17",
474474
)
475-
correct = "Warning: .gph files may not be exported as they contain data."
475+
correct = "Warning: .gph files are not allowed and cannot be exported."
476476
assert ret == correct, f" we got : {ret}\nexpected:{correct}"
477477
newres = stata_config.stata_acro.results.__dict__
478478
assert newres == previous, (

test/test_stata_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ def test_stata_custom_output_invalid():
562562
options="nototals",
563563
stata_version="17",
564564
)
565-
correct = "Warning: .gph files may not be exported as they contain data."
565+
correct = "Warning: .gph files are not allowed and cannot be exported."
566566
assert ret == correct, f" we got : {ret}\nexpected:{correct}"
567567
newres = stata_config.stata_acro.results.__dict__
568568
assert newres == previous, (

0 commit comments

Comments
 (0)