diff --git a/bare_metal_billing/billing.py b/bare_metal_billing/billing.py index 938e0f7..9311c3a 100644 --- a/bare_metal_billing/billing.py +++ b/bare_metal_billing/billing.py @@ -88,8 +88,11 @@ def _get_su_type(lease_info: models.BMNodeUsage): def _get_running_time( - lease_info: models.BMNodeUsage, start_time: datetime, end_time: datetime -): + lease_info: models.BMNodeUsage, + start_time: datetime, + end_time: datetime, + excluded_intervals: list[tuple[datetime, datetime]] | None = None, +) -> int: start_time = _clamp_time(lease_info.start_time, start_time, end_time) end_time = ( end_time @@ -97,7 +100,17 @@ def _get_running_time( is None # Assumes lease is still running if no expire time given else _clamp_time(lease_info.expire_time, start_time, end_time) ) - return math.ceil((end_time - start_time).total_seconds() / 3600) + + total_interval_duration = (end_time - start_time).total_seconds() + if excluded_intervals: + for e_interval_start, e_interval_end in excluded_intervals: + e_interval_start = _clamp_time(e_interval_start, start_time, end_time) + e_interval_end = _clamp_time(e_interval_end, start_time, end_time) + total_interval_duration -= ( + e_interval_end - e_interval_start + ).total_seconds() + + return math.ceil(max(0, total_interval_duration) / 3600) def _clamp_time(time, min_time, max_time): @@ -109,7 +122,10 @@ def _clamp_time(time, min_time, max_time): def get_project_invoices( - bm_usage_data: models.BMUsageData, start_time: datetime, end_time: datetime + bm_usage_data: models.BMUsageData, + start_time: datetime, + end_time: datetime, + excluded_time_ranges: list[tuple[datetime, datetime]] | None = None, ) -> list[models.ProjectUsage]: project_usage_dict = {} for lease_info in bm_usage_data.root: @@ -126,7 +142,9 @@ def get_project_invoices( logger.warning( f"Unknown resource class {lease_info.resource_class} (resource {lease_info.resource}) in lease {lease_info.uuid}." ) - su_hours = _get_running_time(lease_info, start_time, end_time) + su_hours = _get_running_time( + lease_info, start_time, end_time, excluded_time_ranges + ) project_usage_dict[project_name].add_usage(su_type, su_hours) diff --git a/bare_metal_billing/main.py b/bare_metal_billing/main.py index 2b82025..45088a6 100644 --- a/bare_metal_billing/main.py +++ b/bare_metal_billing/main.py @@ -1,5 +1,5 @@ from decimal import Decimal -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import json import argparse import logging @@ -14,7 +14,7 @@ def parse_time_from_string(time_str: str) -> datetime: - return datetime.fromisoformat(time_str) + return datetime.fromisoformat(time_str).replace(tzinfo=timezone.utc) def parse_time_argument(arg): @@ -23,14 +23,36 @@ def parse_time_argument(arg): return arg +def parse_time_range(arg: str): + start_str, end_str = arg.split(",") + start_time, end_time = [parse_time_from_string(i) for i in (start_str, end_str)] + if start_time >= end_time: + raise argparse.ArgumentTypeError( + f"Start time {start_time} must be before end time {end_time}." + ) + return start_time, end_time + + +def check_overlapping_intervals(arg_list: list[tuple[datetime, datetime]] | None): + if not arg_list: + return + + sorted_intervals = sorted(arg_list, key=lambda x: x[0]) + for i in range(1, len(sorted_intervals)): + if sorted_intervals[i][0] < sorted_intervals[i - 1][1]: + raise ValueError( + f"Overlapping time ranges: {sorted_intervals[i-1]} and {sorted_intervals[i]}" + ) + + def default_start_argument(): - d = (datetime.today() - timedelta(days=1)).replace(day=1) + d = (datetime.today() - timedelta(days=1)).replace(day=1, tzinfo=timezone.utc) d = d.replace(hour=0, minute=0, second=0, microsecond=0) return d def default_end_argument(): - d = datetime.today() + d = datetime.today().replace(tzinfo=timezone.utc) d = d.replace(hour=0, minute=0, second=0, microsecond=0) return d @@ -100,6 +122,13 @@ def main(): "to 0 for each SU's resources" ), ) + parser.add_argument( + "--excluded-time-ranges", + type=parse_time_range, + default=[], + nargs="+", + help="List of time ranges excluded from billing, in format of ','. In UTC time", + ) args = parser.parse_args() @@ -107,6 +136,9 @@ def main(): logger.info(f"Interval for processing {args.start} - {args.end}.") logger.info(f"Invoice file will be saved to {args.output_file}.") + check_overlapping_intervals(args.excluded_time_ranges) + excluded_time_ranges = args.excluded_time_ranges + su_rates_dict = {} if args.use_nerc_rates: nerc_repo_rates = load_from_url() @@ -129,7 +161,9 @@ def main(): input_bm_json = json.load(f) input_invoice = models.BMUsageData.model_validate(input_bm_json) - project_invoices = billing.get_project_invoices(input_invoice, args.start, args.end) + project_invoices = billing.get_project_invoices( + input_invoice, args.start, args.end, excluded_time_ranges + ) invoice_writer = billing.InvoiceWriter( args.invoice_month, diff --git a/bare_metal_billing/tests/test_billing.py b/bare_metal_billing/tests/test_billing.py index 2b7fcaa..f64027c 100644 --- a/bare_metal_billing/tests/test_billing.py +++ b/bare_metal_billing/tests/test_billing.py @@ -1,8 +1,10 @@ import tempfile +from argparse import ArgumentTypeError from unittest import TestCase -from datetime import datetime +from datetime import datetime, timezone from bare_metal_billing import billing, models +from bare_metal_billing.main import parse_time_range, check_overlapping_intervals HOURS_IN_DAY = 24 @@ -226,3 +228,109 @@ def test_get_su_hours(self): ), 4, ) + + def _get_lease_fixture(self): + """Fixture used in tests for excluded time ranges""" + return self._get_bm_usage_data( + ["P1"], + start_times=[datetime(2020, 3, 15, 0, 0, 0)], + expire_times=[datetime(2020, 3, 17, 0, 0, 0)], + resource_classes=["fc430"], + ).root[0] + + def test_single_excluded_interval(self): + test_args_list = [ + ( + (datetime(2020, 3, 16, 9, 30, 0), datetime(2020, 3, 16, 10, 30, 0)), + HOURS_IN_DAY * 2 - 1, + ), # Exclusion within active interval + ( + (datetime(2020, 3, 13, 0, 0, 0), datetime(2020, 3, 16, 0, 0, 0)), + HOURS_IN_DAY * 1, + ), # Exclusion starts before active interval + ( + (datetime(2020, 3, 16, 0, 0, 0), datetime(2020, 3, 18, 0, 0, 0)), + HOURS_IN_DAY, + ), # Exclusion ends after active interval + ( + (datetime(2020, 3, 1, 0, 0, 0), datetime(2020, 3, 30, 0, 0, 0)), + 0, + ), # Entire active interval excluded + ] + + for excluded_interval, expected_hours in test_args_list: + hours = billing._get_running_time( + self._get_lease_fixture(), + datetime(2020, 3, 15, 0, 0, 0), + datetime(2020, 3, 17, 0, 0, 0), + [excluded_interval], + ) + self.assertEqual(hours, expected_hours) + + def test_running_time_excluded_intervals_outside_active(self): + excluded_intervals = [ + (datetime(2020, 3, 1, 0, 0, 0), datetime(2020, 3, 5, 0, 0, 0)), + (datetime(2020, 3, 10, 0, 0, 0), datetime(2020, 3, 11, 0, 0, 0)), + (datetime(2020, 3, 20, 0, 0, 0), datetime(2020, 3, 25, 0, 0, 0)), + ] + hours = billing._get_running_time( + self._get_lease_fixture(), + datetime(2020, 3, 12, 0, 0, 0), + datetime(2020, 3, 19, 0, 0, 0), + excluded_intervals, + ) + self.assertEqual(hours, HOURS_IN_DAY * 2) + + def test_running_time_multiple_excluded_intervals(self): + excluded_intervals = [ + (datetime(2020, 3, 13, 0, 0, 0), datetime(2020, 3, 15, 0, 0, 0)), + (datetime(2020, 3, 16, 0, 0, 0), datetime(2020, 3, 17, 0, 0, 0)), + (datetime(2020, 3, 18, 0, 0, 0), datetime(2020, 3, 20, 0, 0, 0)), + ] + lease_info = self._get_bm_usage_data( + ["P1"], + start_times=[datetime(2020, 3, 14, 0, 0, 0)], + expire_times=[datetime(2020, 3, 19, 0, 0, 0)], + resource_classes=["fc430"], + ).root[0] + hours = billing._get_running_time( + lease_info, + datetime(2020, 3, 14, 0, 0, 0), + datetime(2020, 3, 19, 0, 0, 0), + excluded_intervals, + ) + self.assertEqual(hours, HOURS_IN_DAY * 2) + + +class TestParseExcludedTimeRanges(BillingTestBase): + def test_valid_excluded_time_ranges(self): + valid_input = "2023-01-01T06:00:00,2023-01-02T12:00:00" + result = parse_time_range(valid_input) + self.assertEqual( + result, + ( + datetime(2023, 1, 1, 6, 0, 0, tzinfo=timezone.utc), + datetime(2023, 1, 2, 12, 0, 0, tzinfo=timezone.utc), + ), + ) + + def test_invalid_excluded_time_ranges_format(self): + invalid_input = "foo" + with self.assertRaises(ValueError): + parse_time_range(invalid_input) + + def test_invalid_excluded_time_ranges_order(self): + # End time before start time + invalid_input = "2023-01-02T00:00:00,2023-01-01T00:00:00" + with self.assertRaises(ArgumentTypeError): + parse_time_range(invalid_input) + + def test_overlapping_excluded_time_ranges(self): + invalid_input = [ + ( + datetime(2023, 1, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2023, 1, 2, 0, 0, 0, tzinfo=timezone.utc), + ) + ] * 2 + with self.assertRaises(ValueError): + check_overlapping_intervals(invalid_input)