Skip to content

Commit f986cdd

Browse files
committed
Improve memory efficiency of metrics retrieval
1 parent 8006428 commit f986cdd

6 files changed

Lines changed: 212 additions & 99 deletions

File tree

src/neptune_query/internal/composition/fetch_metrics.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
split,
4747
)
4848
from ..retrieval.metrics import (
49-
FloatPointValue,
49+
MetricValues,
5050
fetch_multiple_series_values,
5151
)
5252
from ..retrieval.search import ContainerType
@@ -96,6 +96,7 @@ def fetch_metrics(
9696
project_identifier=project_identifier,
9797
step_range=step_range,
9898
lineage_to_the_root=lineage_to_the_root,
99+
include_timestamp=include_time is not None,
99100
include_point_previews=include_point_previews,
100101
tail_limit=tail_limit,
101102
executor=executor,
@@ -124,10 +125,11 @@ def _fetch_metrics(
124125
fetch_attribute_definitions_executor: Executor,
125126
step_range: tuple[Optional[float], Optional[float]],
126127
lineage_to_the_root: bool,
128+
include_timestamp: bool,
127129
include_point_previews: bool,
128130
tail_limit: Optional[int],
129131
container_type: ContainerType,
130-
) -> tuple[dict[identifiers.RunAttributeDefinition, list[FloatPointValue]], dict[identifiers.SysId, str]]:
132+
) -> tuple[dict[identifiers.RunAttributeDefinition, MetricValues], dict[identifiers.SysId, str]]:
131133
sys_id_label_mapping: dict[identifiers.SysId, str] = {}
132134

133135
def go_fetch_sys_attrs() -> Generator[list[identifiers.SysId], None, None]:
@@ -170,6 +172,7 @@ def go_fetch_sys_attrs() -> Generator[list[identifiers.SysId], None, None]:
170172
client=client,
171173
run_attribute_definitions=run_attribute_definitions_split,
172174
include_inherited=lineage_to_the_root,
175+
include_timestamp=include_timestamp,
173176
include_preview=include_point_previews,
174177
container_type=container_type,
175178
step_range=step_range,
@@ -180,13 +183,12 @@ def go_fetch_sys_attrs() -> Generator[list[identifiers.SysId], None, None]:
180183
),
181184
)
182185

183-
results: Generator[
184-
dict[identifiers.RunAttributeDefinition, list[FloatPointValue]], None, None
185-
] = concurrency.gather_results(output)
186+
results: Generator[dict[identifiers.RunAttributeDefinition, MetricValues], None, None] = concurrency.gather_results(
187+
output
188+
)
186189

187-
metrics_data: dict[identifiers.RunAttributeDefinition, list[FloatPointValue]] = {}
188-
for result in results:
189-
for run_attribute_definition, metric_points in result.items():
190-
metrics_data.setdefault(run_attribute_definition, []).extend(metric_points)
190+
metrics_data: dict[identifiers.RunAttributeDefinition, MetricValues] = {
191+
definition: metric_values for result in results for definition, metric_values in result.items()
192+
}
191193

192194
return metrics_data, sys_id_label_mapping

src/neptune_query/internal/output_format.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from . import identifiers
3232
from .retrieval import (
3333
metric_buckets,
34-
metrics,
3534
series,
3635
)
3736
from .retrieval.attribute_types import (
@@ -40,13 +39,7 @@
4039
Histogram,
4140
)
4241
from .retrieval.attribute_values import AttributeValue
43-
from .retrieval.metrics import (
44-
IsPreviewIndex,
45-
PreviewCompletionIndex,
46-
StepIndex,
47-
TimestampIndex,
48-
ValueIndex,
49-
)
42+
from .retrieval.metrics import MetricValues
5043
from .retrieval.search import ContainerType
5144
from .util import _validate_allowed_value
5245

@@ -142,7 +135,7 @@ def transform_column_names(df: pd.DataFrame) -> pd.DataFrame:
142135

143136

144137
def create_metrics_dataframe(
145-
metrics_data: dict[identifiers.RunAttributeDefinition, list[metrics.FloatPointValue]],
138+
metrics_data: dict[identifiers.RunAttributeDefinition, MetricValues],
146139
sys_id_label_mapping: dict[identifiers.SysId, str],
147140
*,
148141
type_suffix_in_column_names: bool,
@@ -176,15 +169,13 @@ def path_display_name(attr_def: identifiers.RunAttributeDefinition) -> str:
176169
paths_with_data: set[str] = set()
177170

178171
# Collect which (experiment, path) pairs have data and the set of observed steps per run.
179-
for definition, points in metrics_data.items():
180-
if not points:
172+
for definition, metric_values in metrics_data.items():
173+
if metric_values.length == 0:
181174
continue
182175

183176
paths_with_data.add(path_display_name(definition))
184-
185177
step_set = run_to_observed_steps.setdefault(sys_id_label_mapping[definition.run_identifier.sys_id], set())
186-
for point in points:
187-
step_set.add(point[StepIndex])
178+
step_set.update(metric_values.steps)
188179

189180
index_data = IndexData.from_observed_steps(
190181
observed_steps=run_to_observed_steps,
@@ -205,21 +196,21 @@ def path_display_name(attr_def: identifiers.RunAttributeDefinition) -> str:
205196
)
206197

207198
# Write every metric point directly into the pre-allocated buffers.
208-
for definition, points in metrics_data.items():
209-
if not points:
199+
for definition, metric_values in metrics_data.items():
200+
if metric_values.length == 0:
210201
continue
211202

212203
step_to_row_index: dict[float, int] = index_data.lookup_rows(sys_id=definition.run_identifier.sys_id)
204+
rows = np.array([step_to_row_index[step] for step in metric_values.steps], dtype=np.uint)
205+
213206
buffer: PathBuffer = path_buffers[path_display_name(definition)]
214-
for point in points:
215-
row_idx: int = step_to_row_index[point[StepIndex]]
216-
buffer.value[row_idx] = point[ValueIndex]
217-
if buffer.absolute_time is not None:
218-
buffer.absolute_time[row_idx] = point[TimestampIndex]
219-
if buffer.is_preview is not None:
220-
buffer.is_preview[row_idx] = point[IsPreviewIndex]
221-
if buffer.preview_completion is not None:
222-
buffer.preview_completion[row_idx] = point[PreviewCompletionIndex]
207+
buffer.value[rows] = metric_values.values
208+
if buffer.absolute_time is not None:
209+
buffer.absolute_time[rows] = metric_values.timestamps
210+
if buffer.is_preview is not None:
211+
buffer.is_preview[rows] = metric_values.is_preview
212+
if buffer.preview_completion is not None:
213+
buffer.preview_completion[rows] = metric_values.completion_ratio
223214

224215
return _assemble_wide_dataframe(
225216
index_data=index_data,

src/neptune_query/internal/retrieval/metrics.py

Lines changed: 97 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
# limitations under the License.
1515

1616
import functools as ft
17+
from dataclasses import dataclass
1718
from typing import (
1819
Any,
1920
Optional,
2021
Union,
2122
)
2223

24+
import numpy as np
2325
from neptune_api.api.retrieval import get_multiple_float_series_values_proto
2426
from neptune_api.client import AuthenticatedClient
2527
from neptune_api.models import FloatTimeSeriesValuesRequest
@@ -37,28 +39,62 @@
3739

3840
logger = get_logger()
3941

40-
# Tuples are used here to enhance performance
41-
FloatPointValue = tuple[float, float, float, bool, float]
42-
(
43-
TimestampIndex,
44-
StepIndex,
45-
ValueIndex,
46-
IsPreviewIndex,
47-
PreviewCompletionIndex,
48-
) = range(5)
49-
5042
TOTAL_POINT_LIMIT: int = 1_000_000
5143

5244

45+
@dataclass(frozen=True, slots=True)
46+
class MetricValues:
47+
steps: np.ndarray
48+
values: np.ndarray
49+
timestamps: Optional[np.ndarray]
50+
is_preview: Optional[np.ndarray]
51+
completion_ratio: Optional[np.ndarray]
52+
53+
@classmethod
54+
def allocate(cls, size: int, include_timestamp: bool, include_preview: bool) -> "MetricValues":
55+
return cls(
56+
steps=np.empty(size, dtype=np.float64),
57+
values=np.empty(size, dtype=np.float64),
58+
timestamps=np.empty(size, dtype=np.float64) if include_timestamp else None,
59+
is_preview=np.empty(size, dtype=bool) if include_preview else None,
60+
completion_ratio=np.empty(size, dtype=np.float64) if include_preview else None,
61+
)
62+
63+
@classmethod
64+
def concatenate(cls, metrics_list: list["MetricValues"]) -> "MetricValues":
65+
return cls(
66+
steps=np.concatenate([m.steps for m in metrics_list], axis=0),
67+
values=np.concatenate([m.values for m in metrics_list], axis=0),
68+
timestamps=np.concatenate([m.timestamps for m in metrics_list], axis=0)
69+
if metrics_list[0].timestamps is not None
70+
else None,
71+
is_preview=np.concatenate([m.is_preview for m in metrics_list], axis=0)
72+
if metrics_list[0].is_preview is not None
73+
else None,
74+
completion_ratio=np.concatenate([m.completion_ratio for m in metrics_list], axis=0)
75+
if metrics_list[0].completion_ratio is not None
76+
else None,
77+
)
78+
79+
@property
80+
def length(self) -> int:
81+
return len(self.steps)
82+
83+
@classmethod
84+
def length_sum(cls, metrics_list: list["MetricValues"]) -> int:
85+
return sum(m.length for m in metrics_list)
86+
87+
5388
def fetch_multiple_series_values(
5489
client: AuthenticatedClient,
5590
run_attribute_definitions: list[identifiers.RunAttributeDefinition],
5691
include_inherited: bool,
5792
container_type: ContainerType,
93+
include_timestamp: bool,
5894
include_preview: bool,
5995
step_range: tuple[Union[float, None], Union[float, None]] = (None, None),
6096
tail_limit: Optional[int] = None,
61-
) -> dict[identifiers.RunAttributeDefinition, list[FloatPointValue]]:
97+
) -> dict[identifiers.RunAttributeDefinition, MetricValues]:
6298
if not run_attribute_definitions:
6399
return {}
64100

@@ -93,25 +129,37 @@ def fetch_multiple_series_values(
93129
"order": "ascending" if not tail_limit else "descending",
94130
}
95131

96-
results: dict[identifiers.RunAttributeDefinition, list[FloatPointValue]] = {
97-
run_attribute: [] for run_attribute in run_attribute_definitions
98-
}
132+
paged_results: dict[identifiers.RunAttributeDefinition, list[MetricValues]] = {}
99133

100134
for page_result in util.fetch_pages(
101135
client=client,
102136
fetch_page=_fetch_metrics_page,
103-
process_page=ft.partial(_process_metrics_page, request_id_to_attribute=request_id_to_attribute),
137+
process_page=ft.partial(
138+
_process_metrics_page,
139+
request_id_to_attribute=request_id_to_attribute,
140+
include_timestamp=include_timestamp,
141+
include_preview=include_preview,
142+
reverse_order=tail_limit is not None,
143+
),
104144
make_new_page_params=ft.partial(
105145
_make_new_metrics_page_params,
106146
request_id_to_attribute=request_id_to_attribute,
107147
tail_limit=tail_limit,
108-
partial_results=results,
148+
partial_results=paged_results,
109149
),
110150
params=params,
111151
):
112-
for attribute, values in page_result.items:
113-
sorted_values = values if tail_limit else reversed(values)
114-
results[attribute].extend(sorted_values)
152+
for definition, metric_values in page_result.items:
153+
paged_results.setdefault(definition, []).append(metric_values)
154+
155+
results: dict[identifiers.RunAttributeDefinition, MetricValues] = {}
156+
for definition, paged_metric_values in paged_results.items():
157+
if len(paged_metric_values) > 1:
158+
results[definition] = MetricValues.concatenate(paged_metric_values)
159+
elif len(paged_metric_values) == 1:
160+
results[definition] = paged_metric_values[0]
161+
else:
162+
pass
115163

116164
return results
117165

@@ -138,20 +186,32 @@ def _fetch_metrics_page(
138186
def _process_metrics_page(
139187
data: ProtoFloatSeriesValuesResponseDTO,
140188
request_id_to_attribute: dict[str, identifiers.RunAttributeDefinition],
141-
) -> util.Page[tuple[identifiers.RunAttributeDefinition, list[FloatPointValue]]]:
189+
include_timestamp: bool,
190+
include_preview: bool,
191+
reverse_order: bool,
192+
) -> util.Page[tuple[identifiers.RunAttributeDefinition, MetricValues]]:
142193
result = {}
143194
for series in data.series:
144-
run_attribute = request_id_to_attribute[series.requestId]
145-
result[run_attribute] = [
146-
(
147-
point.timestamp_millis,
148-
point.step,
149-
point.value,
150-
point.is_preview,
151-
point.completion_ratio,
152-
)
153-
for point in series.series.values
154-
]
195+
metric_values = MetricValues.allocate(
196+
size=len(series.series.values), include_timestamp=include_timestamp, include_preview=include_preview
197+
)
198+
199+
for i, point in enumerate(series.series.values):
200+
idx = metric_values.length - 1 - i if reverse_order else i
201+
202+
metric_values.steps[idx] = point.step
203+
metric_values.values[idx] = point.value
204+
if include_timestamp:
205+
assert metric_values.timestamps
206+
metric_values.timestamps[idx] = point.timestamp_millis
207+
if include_preview:
208+
assert metric_values.is_preview
209+
assert metric_values.completion_ratio
210+
metric_values.is_preview[idx] = point.is_preview
211+
metric_values.completion_ratio[idx] = point.completion_ratio
212+
definition = request_id_to_attribute[series.requestId]
213+
result[definition] = metric_values
214+
155215
return util.Page(items=list(result.items()))
156216

157217

@@ -160,7 +220,7 @@ def _make_new_metrics_page_params(
160220
data: Optional[ProtoFloatSeriesValuesResponseDTO],
161221
request_id_to_attribute: dict[str, identifiers.RunAttributeDefinition],
162222
tail_limit: Optional[int],
163-
partial_results: dict[identifiers.RunAttributeDefinition, list[FloatPointValue]],
223+
partial_results: dict[identifiers.RunAttributeDefinition, list[MetricValues]],
164224
) -> Optional[dict[str, Any]]:
165225
if data is None: # no past data, we are fetching the first page
166226
for request in params["requests"]:
@@ -181,7 +241,9 @@ def _make_new_metrics_page_params(
181241
is_page_full = value_size == prev_per_series_points_limit
182242

183243
attribute = request_id_to_attribute[request_id]
184-
need_more_points = len(partial_results[attribute]) < tail_limit if tail_limit is not None else True
244+
need_more_points = (
245+
MetricValues.length_sum(partial_results[attribute]) < tail_limit if tail_limit is not None else True
246+
)
185247

186248
if is_page_full and need_more_points:
187249
new_request_after_steps[request_id] = series.series.values[-1].step
@@ -201,7 +263,8 @@ def _make_new_metrics_page_params(
201263
per_series_points_limit = max(1, TOTAL_POINT_LIMIT // len(params["requests"]))
202264
if tail_limit is not None:
203265
already_fetched = next(
204-
len(partial_results[request_id_to_attribute[request_id]]) for request_id in new_request_after_steps.keys()
266+
MetricValues.length_sum(partial_results[request_id_to_attribute[request_id]])
267+
for request_id in new_request_after_steps.keys()
205268
) # assumes the results for all unfinished series have the same length
206269
per_series_points_limit = min(per_series_points_limit, tail_limit - already_fetched)
207270
params["perSeriesPointsLimit"] = per_series_points_limit

0 commit comments

Comments
 (0)