
Commit c31582f

add snowstorm_dataset and IceCubeHostedDataset class
1 parent 79d7baf commit c31582f

5 files changed: 310 additions & 1 deletion

src/graphnet/data/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -9,4 +9,8 @@
 from .pre_configured import I3ToParquetConverter
 from .pre_configured import I3ToSQLiteConverter
 from .datamodule import GraphNeTDataModule
-from .curated_datamodule import CuratedDataset, ERDAHostedDataset
+from .curated_datamodule import (
+    CuratedDataset,
+    ERDAHostedDataset,
+    IceCubeHostedDataset,
+)
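
With the re-export in place, the new base class is importable from the package root together with the existing curated-dataset classes; a minimal, purely illustrative import:

    from graphnet.data import CuratedDataset, ERDAHostedDataset, IceCubeHostedDataset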

src/graphnet/data/constants.py

Lines changed: 72 additions & 0 deletions
@@ -29,6 +29,27 @@ class FEATURES:
         "sensor_pos_z",
         "t",
     ]
+    SNOWSTORM = [
+        "dom_x",
+        "dom_y",
+        "dom_z",
+        "charge",
+        "dom_time",
+        "width",
+        "pmt_area",
+        "rde",
+        "is_bright_dom",
+        "is_bad_dom",
+        "is_saturated_dom",
+        "is_errata_dom",
+        "event_time",
+        "hlc",
+        "awtd",
+        "string",
+        "pmt_number",
+        "dom_number",
+        "dom_type",
+    ]
     KAGGLE = ["x", "y", "z", "time", "charge", "auxiliary"]
     LIQUIDO = ["sipm_x", "sipm_y", "sipm_z", "t"]

@@ -84,6 +105,57 @@ class TRUTH:
         "primary_hadron_1_energy",
         "total_energy",
     ]
+    SNOWSTORM = [
+        "energy",
+        "position_x",
+        "position_y",
+        "position_z",
+        "azimuth",
+        "zenith",
+        "pid",
+        "event_time",
+        "interaction_type",
+        "elasticity",
+        "RunID",
+        "SubrunID",
+        "EventID",
+        "SubEventID",
+        "dbang_decay_length",
+        "track_length",
+        "stopped_muon",
+        "energy_track",
+        "energy_cascade",
+        "inelasticity",
+        "DeepCoreFilter_13",
+        "CascadeFilter_13",
+        "MuonFilter_13",
+        "OnlineL2Filter_17",
+        "L3_oscNext_bool",
+        "L4_oscNext_bool",
+        "L5_oscNext_bool",
+        "L6_oscNext_bool",
+        "L7_oscNext_bool",
+        "Homogenized_QTot",
+        "MCLabelClassification",
+        "MCLabelCoincidentMuons",
+        "MCLabelBgMuonMCPE",
+        "MCLabelBgMuonMCPECharge",
+        "GNLabelTrackEnergyDeposited",
+        "GNLabelTrackEnergyOnEntrance",
+        "GNLabelTrackEnergyOnEntrancePrimary",
+        "GNLabelTrackEnergyDepositedPrimary",
+        "GNLabelEnergyPrimary",
+        "GNLabelCascadeEnergyDepositedPrimary",
+        "GNLabelCascadeEnergyDeposited",
+        "GNLabelEnergyDepositedTotal",
+        "GNLabelEnergyDepositedPrimary",
+        "GNLabelHighestEInIceParticleIsChild",
+        "GNLabelHighestEInIceParticleDistance",
+        "GNLabelHighestEInIceParticleEFraction",
+        "GNLabelHighestEInIceParticleEOnEntrance",
+        "GNLabelHighestEDaughterDistance",
+        "GNLabelHighestEDaughterEFraction",
+    ]
     KAGGLE = ["zenith", "azimuth"]
     LIQUIDO = [
         "vertex_x",

src/graphnet/data/curated_datamodule.py

Lines changed: 73 additions & 0 deletions
@@ -8,6 +8,7 @@
 from typing import Dict, Any, Optional, List, Tuple, Union
 from abc import abstractmethod
 import os
+from glob import glob

 from .datamodule import GraphNeTDataModule
 from graphnet.models.graphs import GraphDefinition
@@ -280,3 +281,75 @@ def prepare_data(self) -> None:
             os.system(f"wget -O {file_path} {self._mirror}/{file_hash}")
             os.system(f"tar -xf {file_path} -C {self.dataset_dir}")
             os.system(f"rm {file_path}")
+
+
+class IceCubeHostedDataset(CuratedDataset):
+    """A base class for datasets/datamodules hosted on the IceCube cluster.
+
+    Inheriting subclasses need to:
+    - fill out the `_zipped_files` attribute, which
+    should be a list of paths to files that are compressed using `tar` with
+    extension ".tar.gz" and are stored on the IceCube cluster in "/data/".
+    - implement the `_get_dir_name` method, which should return the
+    directory name where the files resulting from the unzipping of a
+    compressed file should end up.
+    """
+
+    _mirror = "https://convey.icecube.wisc.edu"
+
+    def prepare_data(self) -> None:
+        """Prepare the dataset for training."""
+        assert hasattr(self, "_zipped_files") and (len(self._zipped_files) > 0)
+
+        # Check which files still need to be downloaded
+        files_to_dl = self._resolve_downloads()
+        if files_to_dl == []:
+            return
+
+        # Download files
+        USER = input("Username: ")
+        source_file_paths = " ".join(
+            [f"{self._mirror}{f}" for f in files_to_dl]
+        )
+        os.system(
+            f"wget -P {self.dataset_dir} --user={USER} "
+            + f"--ask-password {source_file_paths}"
+        )
+
+        # Unzip files
+        for file in glob(os.path.join(self.dataset_dir, "*.tar.gz")):
+            tmp_dir = os.path.join(self.dataset_dir, "tmp")
+            os.mkdir(tmp_dir)
+            os.system(f"tar -xzf {file} -C {tmp_dir}")
+            unzip_dir = self._get_dir_name(file)
+            os.makedirs(unzip_dir)
+            for db_file in glob(
+                os.path.join(tmp_dir, "**/*.db"), recursive=True
+            ):
+                os.system(f"mv {db_file} {unzip_dir}")
+
+            os.system(f"rm {file}")
+            os.system(f"rm -r {tmp_dir}")
+
+    @abstractmethod
+    def _get_dir_name(self, source_file_path: str) -> str:
+        """Get directory name from source file path.
+
+        E.g. if `source_file_path` is "/data/set/file.tar.gz",
+        return os.path.join(self.dataset_dir, source_file_path.split("/")[-2])
+        to have 'set' as the directory name where all files resulting from the
+        unzipping of `source_file_path` end up. If no substructure is desired,
+        just return `self.dataset_dir`.
+        """
+        raise NotImplementedError
+
+    def _resolve_downloads(self) -> List[str]:
+        """Resolve which files still need to be downloaded."""
+        if not os.path.exists(self.dataset_dir):
+            return self._zipped_files
+        dir_names = [self._get_dir_name(f) for f in self._zipped_files]
+        ret = []
+        for i, dir in enumerate(dir_names):
+            if not os.path.exists(dir):
+                ret.append(self._zipped_files[i])
+        return ret
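
To illustrate the contract spelled out in the docstring above, here is a minimal hypothetical subclass; the class name and archive paths are invented for the example, and a real dataset would additionally provide the usual CuratedDataset attributes and a `_prepare_args` implementation, as SnowStormDataset does below:

    import os

    from graphnet.data.curated_datamodule import IceCubeHostedDataset


    class MyHostedDataset(IceCubeHostedDataset):
        """Hypothetical subclass showing the two hooks the base class expects."""

        # Archives on the IceCube cluster to fetch (example paths, not real ones).
        _zipped_files = [
            "/data/ana/example/set_a.tar.gz",
            "/data/ana/example/set_b.tar.gz",
        ]

        def _get_dir_name(self, source_file_path: str) -> str:
            # Unpack each archive into a subdirectory named after the archive,
            # e.g. ".../set_a" for ".../set_a.tar.gz".
            file_name = os.path.basename(source_file_path).split(".")[0]
            return os.path.join(self.dataset_dir, file_name)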

src/graphnet/datasets/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@

 from .test_dataset import TestDataset
 from .prometheus_datasets import TRIDENTSmall, BaikalGVDSmall, PONESmall
+from .snowstorm_dataset import SnowStormDataset

src/graphnet/datasets/snowstorm_dataset.py

Lines changed: 159 additions & 0 deletions

@@ -0,0 +1,159 @@
+"""Snowstorm dataset module hosted on the IceCube Collaboration servers."""
+
+import pandas as pd
+import re
+import os
+from typing import Dict, Any, Optional, List, Tuple, Union
+from glob import glob
+from sklearn.model_selection import train_test_split
+
+from graphnet.data.constants import FEATURES, TRUTH
+from graphnet.data.curated_datamodule import IceCubeHostedDataset
+from graphnet.data.utilities import query_database
+from graphnet.models.graphs import GraphDefinition
+
+
+class SnowStormDataset(IceCubeHostedDataset):
+    """IceCube SnowStorm simulation dataset.
+
+    More information can be found at
+    https://wiki.icecube.wisc.edu/index.php/SnowStorm_MC#File_Locations
+    This is an IceCube Collaboration simulation dataset.
+    Requires a username and password.
+    """
+
+    _experiment = "IceCube SnowStorm dataset"
+    _creator = "Severin Magel"
+    _citation = "arXiv:1909.01530"
+    _available_backends = ["sqlite"]
+
+    _pulsemaps = ["SRTInIcePulses"]
+    _truth_table = "truth"
+    _pulse_truth = None
+    _features = FEATURES.SNOWSTORM
+    _event_truth = TRUTH.SNOWSTORM
+    _data_root_dir = "/data/ana/graphnet/Snowstorm_l2"
+
+    def __init__(
+        self,
+        run_ids: List[int],
+        graph_definition: GraphDefinition,
+        download_dir: str,
+        truth: Optional[List[str]] = None,
+        features: Optional[List[str]] = None,
+        train_dataloader_kwargs: Optional[Dict[str, Any]] = None,
+        validation_dataloader_kwargs: Optional[Dict[str, Any]] = None,
+        test_dataloader_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """Initialize SnowStorm dataset."""
+        self._run_ids = run_ids
+        self._zipped_files = [
+            os.path.join(self._data_root_dir, f"{s}.tar.gz") for s in run_ids
+        ]
+
+        super().__init__(
+            graph_definition=graph_definition,
+            download_dir=download_dir,
+            truth=truth,
+            features=features,
+            backend="sqlite",
+            train_dataloader_kwargs=train_dataloader_kwargs,
+            validation_dataloader_kwargs=validation_dataloader_kwargs,
+            test_dataloader_kwargs=test_dataloader_kwargs,
+        )
+
+    def _prepare_args(
+        self, backend: str, features: List[str], truth: List[str]
+    ) -> Tuple[Dict[str, Any], Union[List[int], None], Union[List[int], None]]:
+        """Prepare arguments for dataset."""
+        assert backend == "sqlite"
+        dataset_paths = []
+        for rid in self._run_ids:
+            dataset_paths += glob(
+                os.path.join(self.dataset_dir, str(rid), "**/*.db"),
+                recursive=True,
+            )
+
+        # Get event numbers from all datasets
+        event_no = []
+
+        # Pattern for extracting the RunID from a database path
+        pattern = rf"{re.escape(self.dataset_dir)}/(\d+)/.*"
+        event_counts: Dict[str, int] = {}
+        event_counts = {}
+        for path in dataset_paths:
+
+            # Extract the ID
+            match = re.search(pattern, path)
+            assert match
+            run_id = match.group(1)
+
+            query_df = query_database(
+                database=path,
+                query=f"SELECT event_no FROM {self._truth_table}",
+            )
+            query_df["path"] = path
+            event_no.append(query_df)
+
+            # Save event count for description
+            if run_id in event_counts:
+                event_counts[run_id] += query_df.shape[0]
+            else:
+                event_counts[run_id] = query_df.shape[0]
+
+        event_no = pd.concat(event_no, axis=0)
+
+        # Split the non-unique event numbers into train/val and test
+        train_val, test = train_test_split(
+            event_no,
+            test_size=0.10,
+            random_state=42,
+            shuffle=True,
+        )
+
+        train_val = train_val.groupby("path")
+        test = test.groupby("path")
+
+        # Parse into the right format for CuratedDataset
+        train_val_selection = []
+        test_selection = []
+        for path in dataset_paths:
+            train_val_selection.append(
+                train_val["event_no"].get_group(path).tolist()
+            )
+            test_selection.append(test["event_no"].get_group(path).tolist())
+
+        dataset_args = {
+            "truth_table": self._truth_table,
+            "pulsemaps": self._pulsemaps,
+            "path": dataset_paths,
+            "graph_definition": self._graph_definition,
+            "features": features,
+            "truth": truth,
+        }
+
+        self._create_comment(event_counts)
+
+        return dataset_args, train_val_selection, test_selection
+
+    @classmethod
+    def _create_comment(cls, event_counts: Dict[str, int] = {}) -> None:
+        """Summarize the number of events in each RunID."""
+        fixed_string = (
+            " Simulation produced by the IceCube Collaboration, "
+            + "https://wiki.icecube.wisc.edu/index.php/SnowStorm_MC#File_Locations"  # noqa: E501
+        )
+        tot = 0
+        runid_string = ""
+        for k, v in event_counts.items():
+            runid_string += f"RunID {k} contains {v:10d} events\n"
+            tot += v
+        cls._comments = (
+            f"Contains ~{tot/1e6:.1f} million events:\n"
+            + runid_string
+            + fixed_string
+        )
+
+    def _get_dir_name(self, source_file_path: str) -> str:
+        file_name = os.path.basename(source_file_path).split(".")[0]
+        return str(os.path.join(self.dataset_dir, file_name))
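
A sketch of how the new dataset might be used end to end, assuming the standard GraphNeT KNNGraph graph definition and IceCube86 detector classes; the run ID, download path, and dataloader settings are placeholders, and downloading prompts for IceCube credentials:

    from graphnet.datasets import SnowStormDataset
    from graphnet.models.detector.icecube import IceCube86
    from graphnet.models.graphs import KNNGraph

    # Graph definition built on the IceCube-86 detector geometry.
    graph_definition = KNNGraph(detector=IceCube86())

    # Run ID and download directory are placeholders for this example.
    datamodule = SnowStormDataset(
        run_ids=[22010],
        graph_definition=graph_definition,
        download_dir="/path/to/local/storage",
        train_dataloader_kwargs={"batch_size": 128, "num_workers": 8},
    )
    # Train/validation/test dataloaders are then exposed through the inherited
    # GraphNeTDataModule interface.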
