Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions acro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@

from .acro import *
from .acro_regression import *

# Directory for storing artifacts
ARTIFACTS_DIR = "acro_artifacts"
24 changes: 12 additions & 12 deletions acro/acro_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from matplotlib import pyplot as plt
from pandas import DataFrame, Series

from . import utils
from . import ARTIFACTS_DIR, utils
from .record import Records

logger = logging.getLogger("acro")
Expand Down Expand Up @@ -583,10 +583,10 @@ def survival_plot( # pylint: disable=too-many-arguments
plot = survival_func.plot()

try:
os.makedirs("acro_artifacts")
logger.debug("Directory acro_artifacts created successfully")
os.makedirs(ARTIFACTS_DIR)
logger.debug(f"Directory {ARTIFACTS_DIR} created successfully")
except FileExistsError: # pragma: no cover
logger.debug("Directory acro_artifacts already exists")
logger.debug(f"Directory {ARTIFACTS_DIR} already exists")

# create a unique filename with number to avoid overwrite
filename, extension = os.path.splitext(filename)
Expand All @@ -595,10 +595,10 @@ def survival_plot( # pylint: disable=too-many-arguments
return None # pragma: no cover
increment_number = 0
while os.path.exists(
f"acro_artifacts/{filename}_{increment_number}{extension}"
f"{ARTIFACTS_DIR}/{filename}_{increment_number}{extension}"
): # pragma: no cover
increment_number += 1
unique_filename = f"acro_artifacts/{filename}_{increment_number}{extension}"
unique_filename = f"{ARTIFACTS_DIR}/{filename}_{increment_number}{extension}"

# save the plot to the acro artifacts directory
plt.savefig(unique_filename)
Expand Down Expand Up @@ -775,12 +775,12 @@ def hist( # pylint: disable=too-many-arguments,too-many-locals
f"The maximum value of the {column} column is: {max_value}"
)

# create the acro_artifacts directory to save the plot in it
# create the artifacts directory to save the plot in it
try:
os.makedirs("acro_artifacts")
logger.debug("Directory acro_artifacts created successfully")
os.makedirs(ARTIFACTS_DIR)
logger.debug(f"Directory {ARTIFACTS_DIR} created successfully")
except FileExistsError: # pragma: no cover
logger.debug("Directory acro_artifacts already exists")
logger.debug(f"Directory {ARTIFACTS_DIR} already exists")

# create a unique filename with number to avoid overwrite
filename, extension = os.path.splitext(filename)
Expand All @@ -789,10 +789,10 @@ def hist( # pylint: disable=too-many-arguments,too-many-locals
return None
increment_number = 0
while os.path.exists(
f"acro_artifacts/{filename}_{increment_number}{extension}"
f"{ARTIFACTS_DIR}/{filename}_{increment_number}{extension}"
): # pragma: no cover
increment_number += 1
unique_filename = f"acro_artifacts/{filename}_{increment_number}{extension}"
unique_filename = f"{ARTIFACTS_DIR}/{filename}_{increment_number}{extension}"

# save the plot to the acro artifacts directory
plt.savefig(unique_filename)
Expand Down
7 changes: 4 additions & 3 deletions acro/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pandas as pd
from pandas import DataFrame

from . import ARTIFACTS_DIR
from .version import __version__

logger = logging.getLogger("acro:records")
Expand Down Expand Up @@ -452,9 +453,9 @@ def finalise(self, path: str, ext: str, interactive: bool = False) -> None:
else:
raise ValueError("Invalid file extension. Options: {json, xlsx}")
self.write_checksums(path)
# check if the directory acro_artifacts exists and delete it
if os.path.exists("acro_artifacts"):
shutil.rmtree("acro_artifacts")
# check if the artifacts directory exists and delete it
if os.path.exists(ARTIFACTS_DIR):
shutil.rmtree(ARTIFACTS_DIR)
logger.info("outputs written to: %s", path)

def finalise_json(self, path: str) -> None:
Expand Down
16 changes: 12 additions & 4 deletions test/test_initial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
import pytest
import statsmodels.api as sm

from acro import ACRO, acro_tables, add_constant, add_to_acro, record, utils
from acro import (
ACRO,
ARTIFACTS_DIR,
acro_tables,
add_constant,
add_to_acro,
record,
utils,
)
from acro.acro_tables import _rounded_survival_table, crosstab_with_totals
from acro.record import Records, load_records

Expand Down Expand Up @@ -707,7 +715,7 @@ def test_surv_func(acro):
assert "cells suppressed" in output.summary

# plot
filename = os.path.normpath("acro_artifacts/kaplan-meier_0.png")
filename = os.path.normpath(f"{ARTIFACTS_DIR}/kaplan-meier_0.png")
_ = acro.surv_func(data.futime, data.death, output="plot")
assert os.path.exists(filename)
acro.add_exception("output_0", "I need this")
Expand Down Expand Up @@ -1077,7 +1085,7 @@ def test_crosstab_with_manual_totals_with_suppression_with_two_aggfunc(

def test_histogram_disclosive(data, acro, caplog):
"""Test a discolsive histogram."""
filename = os.path.normpath("acro_artifacts/histogram_0.png")
filename = os.path.normpath(f"{ARTIFACTS_DIR}/histogram_0.png")
_ = acro.hist(data, "inc_grants")
assert os.path.exists(filename)
acro.add_exception("output_0", "Let me have it")
Expand All @@ -1094,7 +1102,7 @@ def test_histogram_disclosive(data, acro, caplog):

def test_histogram_non_disclosive(data, acro):
"""Test a non disclosive histogram."""
filename = os.path.normpath("acro_artifacts/histogram_0.png")
filename = os.path.normpath(f"{ARTIFACTS_DIR}/histogram_0.png")
_ = acro.hist(data, "inc_grants", bins=1)
assert os.path.exists(filename)
acro.add_exception("output_0", "Let me have it")
Expand Down