3131)
3232from graphnet .exceptions .exceptions import ColumnMissingException
3333from graphnet .utilities .logging import Logger
34- from graphnet .models .graphs import GraphDefinition
35-
34+ from graphnet .models .data_representation import (
35+ GraphDefinition ,
36+ DataRepresentation ,
37+ )
3638from graphnet .utilities .config .parsing import (
3739 get_all_grapnet_classes ,
3840)
@@ -58,11 +60,11 @@ def load_module(class_name: str) -> Type:
5860 return namespace_classes [class_name ]
5961
6062
61- def parse_graph_definition ( cfg : dict ) -> GraphDefinition :
62- """Construct GraphDefinition from DatasetConfig."""
63- assert cfg [ "graph_definition" ] is not None
63+ def parse_data_representation ( data_rep_cfg : dict ) -> DataRepresentation :
64+ """Construct DataRepresentation from DatasetConfig."""
65+ assert data_rep_cfg is not None
6466
65- args = cfg [ "graph_definition" ] ["arguments" ]
67+ args = data_rep_cfg ["arguments" ]
6668 classes = {}
6769 for arg in args .keys ():
6870 if isinstance (args [arg ], dict ):
@@ -75,10 +77,8 @@ def parse_graph_definition(cfg: dict) -> GraphDefinition:
7577
7678 new_cfg = deepcopy (args )
7779 new_cfg .update (classes )
78- graph_definition = load_module (cfg ["graph_definition" ]["class_name" ])(
79- ** new_cfg
80- )
81- return graph_definition
80+ data_representation = load_module (data_rep_cfg ["class_name" ])(** new_cfg )
81+ return data_representation
8282
8383
8484def parse_labels (cfg : dict ) -> Dict [str , Label ]:
@@ -122,9 +122,18 @@ def from_config( # type: ignore[override]
122122 "`DatasetConfig`"
123123 )
124124
125- assert (
126- "graph_definition" in source .dict ().keys ()
127- ), "`DatasetConfig` incompatible with current GraphNeT version."
125+ if "data_representation" not in source .dict ().keys ():
126+ if "graph_definition" in source .dict ().keys ():
127+ Logger (log_folder = None ).warning_once (
128+ "DeprecationWarning: Field `graph_definition` will be"
129+ " deprecated in GraphNeT 2.0. Please use "
130+ "`data_representation` instead."
131+ )
132+ else :
133+ raise TypeError (
134+ "`DatasetConfig` incompatible with "
135+ "current GraphNeT version."
136+ )
128137
129138 # Parse set of `selection``.
130139 if isinstance (source .selection , dict ):
@@ -137,8 +146,24 @@ def from_config( # type: ignore[override]
137146 return cls ._construct_dataset_from_list_of_strings (source )
138147
139148 cfg = source .dict ()
140- if cfg ["graph_definition" ] is not None :
141- cfg ["graph_definition" ] = parse_graph_definition (cfg )
149+
150+ if (
151+ "data_representation" in cfg
152+ and cfg ["data_representation" ] is not None
153+ ):
154+ cfg ["data_representation" ] = parse_data_representation (
155+ cfg ["data_representation" ]
156+ )
157+ elif "graph_definition" in cfg and cfg ["graph_definition" ] is not None :
158+ Logger (log_folder = None ).warning_once (
159+ "DeprecationWarning: Field `graph_definition` will be"
160+ " deprecated in GraphNeT 2.0. Please use "
161+ "`data_representation` instead."
162+ )
163+ cfg ["graph_definition" ] = parse_data_representation (
164+ cfg ["graph_definition" ]
165+ )
166+
142167 if cfg ["labels" ] is not None :
143168 cfg ["labels" ] = parse_labels (cfg )
144169
@@ -216,11 +241,12 @@ def _resolve_graphnet_paths(
216241 def __init__ (
217242 self ,
218243 path : Union [str , List [str ]],
219- graph_definition : GraphDefinition ,
220244 pulsemaps : Union [str , List [str ]],
221245 features : List [str ],
222246 truth : List [str ],
223247 * ,
248+ graph_definition : Optional [GraphDefinition ] = None ,
249+ data_representation : Optional [DataRepresentation ] = None ,
224250 node_truth : Optional [List [str ]] = None ,
225251 index_column : str = "event_no" ,
226252 truth_table : str = "truth" ,
@@ -278,8 +304,13 @@ def __init__(
278304 subset of events when resolving a string-based selection (e.g.,
279305 `"10000 random events ~ event_no % 5 > 0"` or `"20% random
280306 events ~ event_no % 5 > 0"`).
281- graph_definition : Method that defines the graph representation.
307+ data_representation : Method that defines the data representation.
282308 labels: Dictionary of labels to be added to the dataset.
309+
310+ graph_definition: Method that defines the graph representation.
311+ NOTE: DEPRECATED Use `data_representation` instead.
312+ # DEPRECATION: REMOVE AT 2.0 LAUNCH
313+ # See https://github.com/graphnet-team/graphnet/issues/647
283314 """
284315 # Base class constructor
285316 super ().__init__ (name = __name__ , class_name = self .__class__ .__name__ )
@@ -303,9 +334,26 @@ def __init__(
303334 self ._index_column = index_column
304335 self ._truth_table = truth_table
305336 self ._loss_weight_default_value = loss_weight_default_value
306- self ._graph_definition = deepcopy (graph_definition )
337+
338+ if data_representation is None :
339+ if graph_definition is not None :
340+ data_representation = graph_definition
341+ # Code continues after warning
342+ self .warning (
343+ "DeprecationWarning: Argument `graph_definition` "
344+ "will be deprecated in GraphNeT 2.0. "
345+ "Please use `data_representation` instead."
346+ )
347+ else :
348+ # Code stops
349+ raise TypeError (
350+ "__init__() missing 1 required keyword argument:"
351+ "'data_representation'"
352+ )
353+
354+ self ._data_representation = deepcopy (data_representation )
307355 self ._labels = labels
308- self ._string_column = graph_definition ._detector .string_index_name
356+ self ._string_column = data_representation ._detector .string_index_name
309357
310358 if node_truth is not None :
311359 assert isinstance (node_truth_table , str )
@@ -390,6 +438,17 @@ def truth_table(self) -> str:
390438 """Name of the table containing event-level truth information."""
391439 return self ._truth_table
392440
441+ # DEPRECATION PROPERTY: REMOVE AT 2.0 LAUNCH
442+ # See https://github.com/graphnet-team/graphnet/issues/647
443+ @property
444+ def _graph_definition (self ) -> DataRepresentation :
445+ """Return the graph definition."""
446+ self .warning (
447+ "DeprecationWarning: `_graph_definition` will be deprecated in"
448+ " GraphNeT 2.0. Please use `_data_representation` instead."
449+ )
450+ return self ._data_representation
451+
393452 # Abstract method(s)
394453 @abstractmethod
395454 def _init (self ) -> None :
@@ -647,8 +706,8 @@ def _create_graph(
647706
648707 assert isinstance (features , np .ndarray )
649708 # Construct graph data object
650- assert self ._graph_definition is not None
651- graph = self ._graph_definition (
709+ assert self ._data_representation is not None
710+ graph = self ._data_representation (
652711 input_features = node_features ,
653712 input_feature_names = self ._features ,
654713 truth_dicts = truth_dicts ,
0 commit comments