Skip to content

Commit fb26ec2

Browse files
authored
Merge pull request #788 from sevmag/restructure-Graph_Definition
Restructure graph definition
2 parents 941c9cd + 6b3b60c commit fb26ec2

File tree

22 files changed

+662
-216
lines changed

22 files changed

+662
-216
lines changed

src/graphnet/data/curated_datamodule.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@
1111
from glob import glob
1212

1313
from .datamodule import GraphNeTDataModule
14-
from graphnet.models.graphs import GraphDefinition
14+
from graphnet.models.data_representation import (
15+
GraphDefinition,
16+
DataRepresentation,
17+
)
1518
from graphnet.data.dataset import ParquetDataset, SQLiteDataset
1619

20+
from graphnet.utilities.logging import Logger
21+
1722

1823
class CuratedDataset(GraphNeTDataModule):
1924
"""Generic base class for curated datasets.
@@ -26,8 +31,9 @@ class CuratedDataset(GraphNeTDataModule):
2631

2732
def __init__(
2833
self,
29-
graph_definition: GraphDefinition,
3034
download_dir: str,
35+
data_representation: Optional[DataRepresentation] = None,
36+
graph_definition: Optional[GraphDefinition] = None,
3137
truth: Optional[List[str]] = None,
3238
features: Optional[List[str]] = None,
3339
backend: str = "parquet",
@@ -38,8 +44,10 @@ def __init__(
3844
"""Construct CuratedDataset.
3945
4046
Args:
41-
graph_definition: Method that defines the data representation.
4247
download_dir: Directory to download dataset to.
48+
data_representation: Method that defines the data representation.
49+
graph_definition: Method that defines the graph representation.
50+
NOTE: DEPRECATED. Use `data_representation` instead.
4351
truth (Optional): List of event-level truth to include. Will
4452
include all available information if not given.
4553
features (Optional): List of input features from pulsemap to use.
@@ -54,9 +62,17 @@ def __init__(
5462
test_dataloader_kwargs (Optional): Arguments for the test
5563
DataLoader. Default None.
5664
"""
65+
if (data_representation is None) & (graph_definition is not None):
66+
data_representation = graph_definition
67+
elif (data_representation is None) & (graph_definition is None):
68+
# Code stops
69+
raise TypeError(
70+
"__init__() missing 1 required keyword argument:"
71+
"'data_representation'"
72+
)
73+
self._data_representation = data_representation
5774
# From user
5875
self._download_dir = download_dir
59-
self._graph_definition = graph_definition
6076
self._backend = backend.lower()
6177

6278
# Checks
@@ -85,6 +101,15 @@ def __init__(
85101
test_selection=test_selec,
86102
)
87103

104+
if graph_definition is not None:
105+
# Code continues after warning
106+
self.warning(
107+
"DeprecationWarning: Argument `graph_definition` will be"
108+
" deprecated in GraphNeT 2.0. Please use `data_representation`"
109+
" instead."
110+
""
111+
)
112+
88113
@abstractmethod
89114
def prepare_data(self) -> None:
90115
"""Download and prepare data."""
@@ -249,6 +274,28 @@ def dataset_dir(self) -> str:
249274
)
250275
return dataset_dir
251276

277+
# DEPRECATION: REMOVE AT 2.0 LAUNCH
278+
# See https://github.com/graphnet-team/graphnet/issues/647
279+
@property
def _graph_definition(self) -> DataRepresentation:
    """Return the data representation under its deprecated name.

    Deprecated alias of `_data_representation`, kept so that code still
    reading `_graph_definition` keeps working until the alias is removed.
    Every access logs a deprecation notice before returning.

    Returns:
        The object stored in `self._data_representation`.
    """
    message = (
        "DeprecationWarning: `_graph_definition` will be deprecated in"
        " GraphNeT 2.0. Please use `_data_representation` instead."
    )
    # This property is needed for the call in `_prepare_args`, which can
    # happen before the Logger initialisation of this object. Until
    # `_logger` exists, presumably `self.warning` cannot be used, so the
    # notice is emitted through a standalone Logger instead.
    if hasattr(self, "_logger"):
        self.warning(message)
    else:
        # Same notice as above: the deprecated item is this attribute,
        # not the `graphnet.models.graphs` module.
        Logger(log_folder=None).warning_once(message)
    return self._data_representation  # type: ignore
298+
252299

253300
class ERDAHostedDataset(CuratedDataset):
254301
"""A base class for dataset/datamodule hosted at ERDA.

src/graphnet/data/dataset/dataset.py

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131
)
3232
from graphnet.exceptions.exceptions import ColumnMissingException
3333
from graphnet.utilities.logging import Logger
34-
from graphnet.models.graphs import GraphDefinition
35-
34+
from graphnet.models.data_representation import (
35+
GraphDefinition,
36+
DataRepresentation,
37+
)
3638
from graphnet.utilities.config.parsing import (
3739
get_all_grapnet_classes,
3840
)
@@ -58,11 +60,11 @@ def load_module(class_name: str) -> Type:
5860
return namespace_classes[class_name]
5961

6062

61-
def parse_graph_definition(cfg: dict) -> GraphDefinition:
62-
"""Construct GraphDefinition from DatasetConfig."""
63-
assert cfg["graph_definition"] is not None
63+
def parse_data_representation(data_rep_cfg: dict) -> DataRepresentation:
64+
"""Construct DataRepresentation from DatasetConfig."""
65+
assert data_rep_cfg is not None
6466

65-
args = cfg["graph_definition"]["arguments"]
67+
args = data_rep_cfg["arguments"]
6668
classes = {}
6769
for arg in args.keys():
6870
if isinstance(args[arg], dict):
@@ -75,10 +77,8 @@ def parse_graph_definition(cfg: dict) -> GraphDefinition:
7577

7678
new_cfg = deepcopy(args)
7779
new_cfg.update(classes)
78-
graph_definition = load_module(cfg["graph_definition"]["class_name"])(
79-
**new_cfg
80-
)
81-
return graph_definition
80+
data_representation = load_module(data_rep_cfg["class_name"])(**new_cfg)
81+
return data_representation
8282

8383

8484
def parse_labels(cfg: dict) -> Dict[str, Label]:
@@ -122,9 +122,18 @@ def from_config( # type: ignore[override]
122122
"`DatasetConfig`"
123123
)
124124

125-
assert (
126-
"graph_definition" in source.dict().keys()
127-
), "`DatasetConfig` incompatible with current GraphNeT version."
125+
if "data_representation" not in source.dict().keys():
126+
if "graph_definition" in source.dict().keys():
127+
Logger(log_folder=None).warning_once(
128+
"DeprecationWarning: Field `graph_definition` will be"
129+
" deprecated in GraphNeT 2.0. Please use "
130+
"`data_representation` instead."
131+
)
132+
else:
133+
raise TypeError(
134+
"`DatasetConfig` incompatible with "
135+
"current GraphNeT version."
136+
)
128137

129138
# Parse set of `selection`.
130139
if isinstance(source.selection, dict):
@@ -137,8 +146,24 @@ def from_config( # type: ignore[override]
137146
return cls._construct_dataset_from_list_of_strings(source)
138147

139148
cfg = source.dict()
140-
if cfg["graph_definition"] is not None:
141-
cfg["graph_definition"] = parse_graph_definition(cfg)
149+
150+
if (
151+
"data_representation" in cfg
152+
and cfg["data_representation"] is not None
153+
):
154+
cfg["data_representation"] = parse_data_representation(
155+
cfg["data_representation"]
156+
)
157+
elif "graph_definition" in cfg and cfg["graph_definition"] is not None:
158+
Logger(log_folder=None).warning_once(
159+
"DeprecationWarning: Field `graph_definition` will be"
160+
" deprecated in GraphNeT 2.0. Please use "
161+
"`data_representation` instead."
162+
)
163+
cfg["graph_definition"] = parse_data_representation(
164+
cfg["graph_definition"]
165+
)
166+
142167
if cfg["labels"] is not None:
143168
cfg["labels"] = parse_labels(cfg)
144169

@@ -216,11 +241,12 @@ def _resolve_graphnet_paths(
216241
def __init__(
217242
self,
218243
path: Union[str, List[str]],
219-
graph_definition: GraphDefinition,
220244
pulsemaps: Union[str, List[str]],
221245
features: List[str],
222246
truth: List[str],
223247
*,
248+
graph_definition: Optional[GraphDefinition] = None,
249+
data_representation: Optional[DataRepresentation] = None,
224250
node_truth: Optional[List[str]] = None,
225251
index_column: str = "event_no",
226252
truth_table: str = "truth",
@@ -278,8 +304,13 @@ def __init__(
278304
subset of events when resolving a string-based selection (e.g.,
279305
`"10000 random events ~ event_no % 5 > 0"` or `"20% random
280306
events ~ event_no % 5 > 0"`).
281-
graph_definition: Method that defines the graph representation.
307+
data_representation: Method that defines the data representation.
282308
labels: Dictionary of labels to be added to the dataset.
309+
310+
graph_definition: Method that defines the graph representation.
311+
NOTE: DEPRECATED. Use `data_representation` instead.
312+
# DEPRECATION: REMOVE AT 2.0 LAUNCH
313+
# See https://github.com/graphnet-team/graphnet/issues/647
283314
"""
284315
# Base class constructor
285316
super().__init__(name=__name__, class_name=self.__class__.__name__)
@@ -303,9 +334,26 @@ def __init__(
303334
self._index_column = index_column
304335
self._truth_table = truth_table
305336
self._loss_weight_default_value = loss_weight_default_value
306-
self._graph_definition = deepcopy(graph_definition)
337+
338+
if data_representation is None:
339+
if graph_definition is not None:
340+
data_representation = graph_definition
341+
# Code continues after warning
342+
self.warning(
343+
"DeprecationWarning: Argument `graph_definition` "
344+
"will be deprecated in GraphNeT 2.0. "
345+
"Please use `data_representation` instead."
346+
)
347+
else:
348+
# Code stops
349+
raise TypeError(
350+
"__init__() missing 1 required keyword argument:"
351+
"'data_representation'"
352+
)
353+
354+
self._data_representation = deepcopy(data_representation)
307355
self._labels = labels
308-
self._string_column = graph_definition._detector.string_index_name
356+
self._string_column = data_representation._detector.string_index_name
309357

310358
if node_truth is not None:
311359
assert isinstance(node_truth_table, str)
@@ -390,6 +438,17 @@ def truth_table(self) -> str:
390438
"""Name of the table containing event-level truth information."""
391439
return self._truth_table
392440

441+
# DEPRECATION PROPERTY: REMOVE AT 2.0 LAUNCH
442+
# See https://github.com/graphnet-team/graphnet/issues/647
443+
@property
def _graph_definition(self) -> DataRepresentation:
    """Return the data representation under its deprecated name.

    Deprecated alias of `_data_representation`. Each access logs a
    deprecation notice through `self.warning` and then returns the
    stored representation.

    Returns:
        The object stored in `self._data_representation`.
    """
    notice = (
        "DeprecationWarning: `_graph_definition` will be deprecated in"
        " GraphNeT 2.0. Please use `_data_representation` instead."
    )
    # Warn first, then read the attribute, so the notice is always logged.
    self.warning(notice)
    representation = self._data_representation
    return representation
451+
393452
# Abstract method(s)
394453
@abstractmethod
395454
def _init(self) -> None:
@@ -647,8 +706,8 @@ def _create_graph(
647706

648707
assert isinstance(features, np.ndarray)
649708
# Construct graph data object
650-
assert self._graph_definition is not None
651-
graph = self._graph_definition(
709+
assert self._data_representation is not None
710+
graph = self._data_representation(
652711
input_features=node_features,
653712
input_feature_names=self._features,
654713
truth_dicts=truth_dicts,

src/graphnet/data/dataset/parquet/parquet_dataset.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
from bisect import bisect_right
1919
from collections import OrderedDict
2020

21-
from graphnet.models.graphs import GraphDefinition
21+
from graphnet.models.data_representation import (
22+
GraphDefinition,
23+
DataRepresentation,
24+
)
2225
from graphnet.data.dataset import Dataset
2326
from graphnet.exceptions.exceptions import ColumnMissingException
2427

@@ -29,11 +32,12 @@ class ParquetDataset(Dataset):
2932
def __init__(
3033
self,
3134
path: str,
32-
graph_definition: GraphDefinition,
3335
pulsemaps: Union[str, List[str]],
3436
features: List[str],
3537
truth: List[str],
3638
*,
39+
data_representation: Optional[DataRepresentation] = None,
40+
graph_definition: Optional[GraphDefinition] = None,
3741
node_truth: Optional[List[str]] = None,
3842
index_column: str = "event_no",
3943
truth_table: str = "truth",
@@ -92,7 +96,9 @@ def __init__(
9296
subset of events when resolving a string-based selection (e.g.,
9397
`"10000 random events ~ event_no % 5 > 0"` or `"20% random
9498
events ~ event_no % 5 > 0"`).
99+
data_representation: Method that defines the data representation.
95100
graph_definition: Method that defines the graph representation.
101+
NOTE: DEPRECATED. Use `data_representation` instead.
96102
cache_size: Number of files to cache in memory.
97103
Must be at least 1. Defaults to 1.
98104
labels: Dictionary of labels to be added to the dataset.
@@ -116,6 +122,7 @@ def __init__(
116122
loss_weight_default_value=loss_weight_default_value,
117123
seed=seed,
118124
graph_definition=graph_definition,
125+
data_representation=data_representation,
119126
labels=labels,
120127
)
121128

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Modules for constructing data.
2+
3+
`DataRepresentation` defines the basic structure for representing data.
4+
`GraphDefinition` defines graphs with different nodes and their features, as
5+
well as the edges between them.
6+
"""
7+
8+
from .data_representation import DataRepresentation
9+
from .graphs import (
10+
GraphDefinition,
11+
KNNGraph,
12+
EdgelessGraph,
13+
KNNGraphRRWP,
14+
KNNGraphRWSE,
15+
NodeDefinition,
16+
NodesAsPulses,
17+
PercentileClusters,
18+
NodeAsDOMTimeSeries,
19+
IceMixNodes,
20+
)

0 commit comments

Comments
 (0)