diff --git a/acro/__init__.py b/acro/__init__.py index b32ac210..6fada08b 100644 --- a/acro/__init__.py +++ b/acro/__init__.py @@ -2,3 +2,6 @@ from .acro import * from .acro_regression import * + +# Directory for storing artifacts +ARTIFACTS_DIR = "acro_artifacts" diff --git a/acro/acro_tables.py b/acro/acro_tables.py index 8b92990a..b73e4a7c 100644 --- a/acro/acro_tables.py +++ b/acro/acro_tables.py @@ -16,7 +16,7 @@ from matplotlib import pyplot as plt from pandas import DataFrame, Series -from . import utils +from . import ARTIFACTS_DIR, utils from .record import Records logger = logging.getLogger("acro") @@ -583,10 +583,10 @@ def survival_plot( # pylint: disable=too-many-arguments plot = survival_func.plot() try: - os.makedirs("acro_artifacts") - logger.debug("Directory acro_artifacts created successfully") + os.makedirs(ARTIFACTS_DIR) + logger.debug(f"Directory {ARTIFACTS_DIR} created successfully") except FileExistsError: # pragma: no cover - logger.debug("Directory acro_artifacts already exists") + logger.debug(f"Directory {ARTIFACTS_DIR} already exists") # create a unique filename with number to avoid overwrite filename, extension = os.path.splitext(filename) @@ -595,10 +595,10 @@ def survival_plot( # pylint: disable=too-many-arguments return None # pragma: no cover increment_number = 0 while os.path.exists( - f"acro_artifacts/{filename}_{increment_number}{extension}" + f"{ARTIFACTS_DIR}/{filename}_{increment_number}{extension}" ): # pragma: no cover increment_number += 1 - unique_filename = f"acro_artifacts/{filename}_{increment_number}{extension}" + unique_filename = f"{ARTIFACTS_DIR}/{filename}_{increment_number}{extension}" # save the plot to the acro artifacts directory plt.savefig(unique_filename) @@ -775,12 +775,12 @@ def hist( # pylint: disable=too-many-arguments,too-many-locals f"The maximum value of the {column} column is: {max_value}" ) - # create the acro_artifacts directory to save the plot in it + # create the artifacts directory to save the plot in it try: - os.makedirs("acro_artifacts") - logger.debug("Directory acro_artifacts created successfully") + os.makedirs(ARTIFACTS_DIR) + logger.debug(f"Directory {ARTIFACTS_DIR} created successfully") except FileExistsError: # pragma: no cover - logger.debug("Directory acro_artifacts already exists") + logger.debug(f"Directory {ARTIFACTS_DIR} already exists") # create a unique filename with number to avoid overwrite filename, extension = os.path.splitext(filename) @@ -789,10 +789,10 @@ def hist( # pylint: disable=too-many-arguments,too-many-locals return None increment_number = 0 while os.path.exists( - f"acro_artifacts/{filename}_{increment_number}{extension}" + f"{ARTIFACTS_DIR}/{filename}_{increment_number}{extension}" ): # pragma: no cover increment_number += 1 - unique_filename = f"acro_artifacts/{filename}_{increment_number}{extension}" + unique_filename = f"{ARTIFACTS_DIR}/{filename}_{increment_number}{extension}" # save the plot to the acro artifacts directory plt.savefig(unique_filename) diff --git a/acro/record.py b/acro/record.py index c1b0a86a..fef57283 100644 --- a/acro/record.py +++ b/acro/record.py @@ -14,6 +14,7 @@ import pandas as pd from pandas import DataFrame +from . import ARTIFACTS_DIR from .version import __version__ logger = logging.getLogger("acro:records") @@ -452,9 +453,9 @@ def finalise(self, path: str, ext: str, interactive: bool = False) -> None: else: raise ValueError("Invalid file extension. Options: {json, xlsx}") self.write_checksums(path) - # check if the directory acro_artifacts exists and delete it - if os.path.exists("acro_artifacts"): - shutil.rmtree("acro_artifacts") + # check if the artifacts directory exists and delete it + if os.path.exists(ARTIFACTS_DIR): + shutil.rmtree(ARTIFACTS_DIR) logger.info("outputs written to: %s", path) def finalise_json(self, path: str) -> None: diff --git a/test/test_initial.py b/test/test_initial.py index f01be6fe..abed160e 100644 --- a/test/test_initial.py +++ b/test/test_initial.py @@ -10,7 +10,15 @@ import pytest import statsmodels.api as sm -from acro import ACRO, acro_tables, add_constant, add_to_acro, record, utils +from acro import ( + ACRO, + ARTIFACTS_DIR, + acro_tables, + add_constant, + add_to_acro, + record, + utils, +) from acro.acro_tables import _rounded_survival_table, crosstab_with_totals from acro.record import Records, load_records @@ -707,7 +715,7 @@ def test_surv_func(acro): assert "cells suppressed" in output.summary # plot - filename = os.path.normpath("acro_artifacts/kaplan-meier_0.png") + filename = os.path.normpath(f"{ARTIFACTS_DIR}/kaplan-meier_0.png") _ = acro.surv_func(data.futime, data.death, output="plot") assert os.path.exists(filename) acro.add_exception("output_0", "I need this") @@ -1077,7 +1085,7 @@ def test_crosstab_with_manual_totals_with_suppression_with_two_aggfunc( def test_histogram_disclosive(data, acro, caplog): """Test a discolsive histogram.""" - filename = os.path.normpath("acro_artifacts/histogram_0.png") + filename = os.path.normpath(f"{ARTIFACTS_DIR}/histogram_0.png") _ = acro.hist(data, "inc_grants") assert os.path.exists(filename) acro.add_exception("output_0", "Let me have it") @@ -1094,7 +1102,7 @@ def test_histogram_disclosive(data, acro, caplog): def test_histogram_non_disclosive(data, acro): """Test a non disclosive histogram.""" - filename = os.path.normpath("acro_artifacts/histogram_0.png") + filename = os.path.normpath(f"{ARTIFACTS_DIR}/histogram_0.png") _ = acro.hist(data, "inc_grants", bins=1) assert os.path.exists(filename) acro.add_exception("output_0", "Let me have it")