1414# limitations under the License.
1515
1616import functools as ft
17+ from dataclasses import dataclass
1718from typing import (
1819 Any ,
1920 Optional ,
2021 Union ,
2122)
2223
24+ import numpy as np
2325from neptune_api .api .retrieval import get_multiple_float_series_values_proto
2426from neptune_api .client import AuthenticatedClient
2527from neptune_api .models import FloatTimeSeriesValuesRequest
3739
3840logger = get_logger ()
3941
40- # Tuples are used here to enhance performance
41- FloatPointValue = tuple [float , float , float , bool , float ]
42- (
43- TimestampIndex ,
44- StepIndex ,
45- ValueIndex ,
46- IsPreviewIndex ,
47- PreviewCompletionIndex ,
48- ) = range (5 )
49-
5042TOTAL_POINT_LIMIT : int = 1_000_000
5143
5244
45+ @dataclass (frozen = True , slots = True )
46+ class MetricValues :
47+ steps : np .ndarray
48+ values : np .ndarray
49+ timestamps : Optional [np .ndarray ]
50+ is_preview : Optional [np .ndarray ]
51+ completion_ratio : Optional [np .ndarray ]
52+
53+ @classmethod
54+ def allocate (cls , size : int , include_timestamp : bool , include_preview : bool ) -> "MetricValues" :
55+ return cls (
56+ steps = np .empty (size , dtype = np .float64 ),
57+ values = np .empty (size , dtype = np .float64 ),
58+ timestamps = np .empty (size , dtype = np .float64 ) if include_timestamp else None ,
59+ is_preview = np .empty (size , dtype = bool ) if include_preview else None ,
60+ completion_ratio = np .empty (size , dtype = np .float64 ) if include_preview else None ,
61+ )
62+
63+ @classmethod
64+ def concatenate (cls , metrics_list : list ["MetricValues" ]) -> "MetricValues" :
65+ return cls (
66+ steps = np .concatenate ([m .steps for m in metrics_list ], axis = 0 ),
67+ values = np .concatenate ([m .values for m in metrics_list ], axis = 0 ),
68+ timestamps = np .concatenate ([m .timestamps for m in metrics_list ], axis = 0 )
69+ if metrics_list [0 ].timestamps is not None
70+ else None ,
71+ is_preview = np .concatenate ([m .is_preview for m in metrics_list ], axis = 0 )
72+ if metrics_list [0 ].is_preview is not None
73+ else None ,
74+ completion_ratio = np .concatenate ([m .completion_ratio for m in metrics_list ], axis = 0 )
75+ if metrics_list [0 ].completion_ratio is not None
76+ else None ,
77+ )
78+
79+ @property
80+ def length (self ) -> int :
81+ return len (self .steps )
82+
83+ @classmethod
84+ def length_sum (cls , metrics_list : list ["MetricValues" ]) -> int :
85+ return sum (m .length for m in metrics_list )
86+
87+
5388def fetch_multiple_series_values (
5489 client : AuthenticatedClient ,
5590 run_attribute_definitions : list [identifiers .RunAttributeDefinition ],
5691 include_inherited : bool ,
5792 container_type : ContainerType ,
93+ include_timestamp : bool ,
5894 include_preview : bool ,
5995 step_range : tuple [Union [float , None ], Union [float , None ]] = (None , None ),
6096 tail_limit : Optional [int ] = None ,
61- ) -> dict [identifiers .RunAttributeDefinition , list [ FloatPointValue ] ]:
97+ ) -> dict [identifiers .RunAttributeDefinition , MetricValues ]:
6298 if not run_attribute_definitions :
6399 return {}
64100
@@ -93,25 +129,37 @@ def fetch_multiple_series_values(
93129 "order" : "ascending" if not tail_limit else "descending" ,
94130 }
95131
96- results : dict [identifiers .RunAttributeDefinition , list [FloatPointValue ]] = {
97- run_attribute : [] for run_attribute in run_attribute_definitions
98- }
132+ paged_results : dict [identifiers .RunAttributeDefinition , list [MetricValues ]] = {}
99133
100134 for page_result in util .fetch_pages (
101135 client = client ,
102136 fetch_page = _fetch_metrics_page ,
103- process_page = ft .partial (_process_metrics_page , request_id_to_attribute = request_id_to_attribute ),
137+ process_page = ft .partial (
138+ _process_metrics_page ,
139+ request_id_to_attribute = request_id_to_attribute ,
140+ include_timestamp = include_timestamp ,
141+ include_preview = include_preview ,
142+ reverse_order = tail_limit is not None ,
143+ ),
104144 make_new_page_params = ft .partial (
105145 _make_new_metrics_page_params ,
106146 request_id_to_attribute = request_id_to_attribute ,
107147 tail_limit = tail_limit ,
108- partial_results = results ,
148+ partial_results = paged_results ,
109149 ),
110150 params = params ,
111151 ):
112- for attribute , values in page_result .items :
113- sorted_values = values if tail_limit else reversed (values )
114- results [attribute ].extend (sorted_values )
152+ for definition , metric_values in page_result .items :
153+ paged_results .setdefault (definition , []).append (metric_values )
154+
155+ results : dict [identifiers .RunAttributeDefinition , MetricValues ] = {}
156+ for definition , paged_metric_values in paged_results .items ():
157+ if len (paged_metric_values ) > 1 :
158+ results [definition ] = MetricValues .concatenate (paged_metric_values )
159+ elif len (paged_metric_values ) == 1 :
160+ results [definition ] = paged_metric_values [0 ]
161+ else :
162+ pass
115163
116164 return results
117165
@@ -138,20 +186,32 @@ def _fetch_metrics_page(
138186def _process_metrics_page (
139187 data : ProtoFloatSeriesValuesResponseDTO ,
140188 request_id_to_attribute : dict [str , identifiers .RunAttributeDefinition ],
141- ) -> util .Page [tuple [identifiers .RunAttributeDefinition , list [FloatPointValue ]]]:
189+ include_timestamp : bool ,
190+ include_preview : bool ,
191+ reverse_order : bool ,
192+ ) -> util .Page [tuple [identifiers .RunAttributeDefinition , MetricValues ]]:
142193 result = {}
143194 for series in data .series :
144- run_attribute = request_id_to_attribute [series .requestId ]
145- result [run_attribute ] = [
146- (
147- point .timestamp_millis ,
148- point .step ,
149- point .value ,
150- point .is_preview ,
151- point .completion_ratio ,
152- )
153- for point in series .series .values
154- ]
195+ metric_values = MetricValues .allocate (
196+ size = len (series .series .values ), include_timestamp = include_timestamp , include_preview = include_preview
197+ )
198+
199+ for i , point in enumerate (series .series .values ):
200+ idx = metric_values .length - 1 - i if reverse_order else i
201+
202+ metric_values .steps [idx ] = point .step
203+ metric_values .values [idx ] = point .value
204+ if include_timestamp :
205+ assert metric_values .timestamps
206+ metric_values .timestamps [idx ] = point .timestamp_millis
207+ if include_preview :
208+ assert metric_values .is_preview
209+ assert metric_values .completion_ratio
210+ metric_values .is_preview [idx ] = point .is_preview
211+ metric_values .completion_ratio [idx ] = point .completion_ratio
212+ definition = request_id_to_attribute [series .requestId ]
213+ result [definition ] = metric_values
214+
155215 return util .Page (items = list (result .items ()))
156216
157217
@@ -160,7 +220,7 @@ def _make_new_metrics_page_params(
160220 data : Optional [ProtoFloatSeriesValuesResponseDTO ],
161221 request_id_to_attribute : dict [str , identifiers .RunAttributeDefinition ],
162222 tail_limit : Optional [int ],
163- partial_results : dict [identifiers .RunAttributeDefinition , list [FloatPointValue ]],
223+ partial_results : dict [identifiers .RunAttributeDefinition , list [MetricValues ]],
164224) -> Optional [dict [str , Any ]]:
165225 if data is None : # no past data, we are fetching the first page
166226 for request in params ["requests" ]:
@@ -181,7 +241,9 @@ def _make_new_metrics_page_params(
181241 is_page_full = value_size == prev_per_series_points_limit
182242
183243 attribute = request_id_to_attribute [request_id ]
184- need_more_points = len (partial_results [attribute ]) < tail_limit if tail_limit is not None else True
244+ need_more_points = (
245+ MetricValues .length_sum (partial_results [attribute ]) < tail_limit if tail_limit is not None else True
246+ )
185247
186248 if is_page_full and need_more_points :
187249 new_request_after_steps [request_id ] = series .series .values [- 1 ].step
@@ -201,7 +263,8 @@ def _make_new_metrics_page_params(
201263 per_series_points_limit = max (1 , TOTAL_POINT_LIMIT // len (params ["requests" ]))
202264 if tail_limit is not None :
203265 already_fetched = next (
204- len (partial_results [request_id_to_attribute [request_id ]]) for request_id in new_request_after_steps .keys ()
266+ MetricValues .length_sum (partial_results [request_id_to_attribute [request_id ]])
267+ for request_id in new_request_after_steps .keys ()
205268 ) # assumes the results for all unfinished series have the same length
206269 per_series_points_limit = min (per_series_points_limit , tail_limit - already_fetched )
207270 params ["perSeriesPointsLimit" ] = per_series_points_limit
0 commit comments