Skip to content

Commit 941c9cd

Browse files
authored
Merge pull request #783 from sevmag/new_dataset_snowstorm
add snowstorm_dataset and IceCubeHostedDataset class
2 parents 0734941 + 78843af commit 941c9cd

File tree

5 files changed

+336
-1
lines changed

5 files changed

+336
-1
lines changed

src/graphnet/data/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,8 @@
99
from .pre_configured import I3ToParquetConverter
1010
from .pre_configured import I3ToSQLiteConverter
1111
from .datamodule import GraphNeTDataModule
12-
from .curated_datamodule import CuratedDataset, ERDAHostedDataset
12+
from .curated_datamodule import (
13+
CuratedDataset,
14+
ERDAHostedDataset,
15+
IceCubeHostedDataset,
16+
)

src/graphnet/data/constants.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,27 @@ class FEATURES:
2929
"sensor_pos_z",
3030
"t",
3131
]
32+
# Pulse-level input features available in the SnowStorm SQLite files
# (per-DOM position, charge/time, and per-DOM quality/geometry flags).
SNOWSTORM = [
    "dom_x",
    "dom_y",
    "dom_z",
    "charge",
    "dom_time",
    "width",
    "pmt_area",
    "rde",
    "is_bright_dom",
    "is_bad_dom",
    "is_saturated_dom",
    "is_errata_dom",
    "event_time",
    "hlc",
    "awtd",
    "string",
    "pmt_number",
    "dom_number",
    "dom_type",
]
3253
KAGGLE = ["x", "y", "z", "time", "charge", "auxiliary"]
3354
LIQUIDO = ["sipm_x", "sipm_y", "sipm_z", "t"]
3455

@@ -84,6 +105,57 @@ class TRUTH:
84105
"primary_hadron_1_energy",
85106
"total_energy",
86107
]
108+
# Event-level truth variables available in the SnowStorm SQLite files:
# primary-particle kinematics, filter/processing-level booleans, and
# MC/GN labelling quantities.
SNOWSTORM = [
    "energy",
    "position_x",
    "position_y",
    "position_z",
    "azimuth",
    "zenith",
    "pid",
    "event_time",
    "interaction_type",
    "elasticity",
    "RunID",
    "SubrunID",
    "EventID",
    "SubEventID",
    "dbang_decay_length",
    "track_length",
    "stopped_muon",
    "energy_track",
    "energy_cascade",
    "inelasticity",
    "DeepCoreFilter_13",
    "CascadeFilter_13",
    "MuonFilter_13",
    "OnlineL2Filter_17",
    "L3_oscNext_bool",
    "L4_oscNext_bool",
    "L5_oscNext_bool",
    "L6_oscNext_bool",
    "L7_oscNext_bool",
    "Homogenized_QTot",
    "MCLabelClassification",
    "MCLabelCoincidentMuons",
    "MCLabelBgMuonMCPE",
    "MCLabelBgMuonMCPECharge",
    "GNLabelTrackEnergyDeposited",
    "GNLabelTrackEnergyOnEntrance",
    "GNLabelTrackEnergyOnEntrancePrimary",
    "GNLabelTrackEnergyDepositedPrimary",
    "GNLabelEnergyPrimary",
    "GNLabelCascadeEnergyDepositedPrimary",
    "GNLabelCascadeEnergyDeposited",
    "GNLabelEnergyDepositedTotal",
    "GNLabelEnergyDepositedPrimary",
    "GNLabelHighestEInIceParticleIsChild",
    "GNLabelHighestEInIceParticleDistance",
    "GNLabelHighestEInIceParticleEFraction",
    "GNLabelHighestEInIceParticleEOnEntrance",
    "GNLabelHighestEDaughterDistance",
    "GNLabelHighestEDaughterEFraction",
]
87159
KAGGLE = ["zenith", "azimuth"]
88160
LIQUIDO = [
89161
"vertex_x",

src/graphnet/data/curated_datamodule.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Dict, Any, Optional, List, Tuple, Union
99
from abc import abstractmethod
1010
import os
11+
from glob import glob
1112

1213
from .datamodule import GraphNeTDataModule
1314
from graphnet.models.graphs import GraphDefinition
@@ -280,3 +281,75 @@ def prepare_data(self) -> None:
280281
os.system(f"wget -O {file_path} {self._mirror}/{file_hash}")
281282
os.system(f"tar -xf {file_path} -C {self.dataset_dir}")
282283
os.system(f"rm {file_path}")
284+
285+
286+
class IceCubeHostedDataset(CuratedDataset):
    """A base class for dataset/datamodule hosted on the IceCube cluster.

    Inheriting subclasses will need to do:
    - fill out the `_zipped_files` attribute, which
      should be a list of paths to files that are compressed using `tar` with
      extension ".tar.gz" and are stored on the IceCube Cluster in "/data/".
    - implement the `_get_dir_name` method, which should return the
      directory name where the files resulting from the unzipping of a
      compressed file should end up.
    """

    # HTTPS mirror of the IceCube data warehouse; requires IceCube credentials.
    _mirror = "https://convey.icecube.wisc.edu"

    def prepare_data(self) -> None:
        """Download and unpack any archives not already present locally."""
        assert hasattr(self, "_zipped_files") and (
            len(self._zipped_files) > 0
        ), "Subclasses must define a non-empty `_zipped_files` list."

        # Check which files still need to be downloaded
        files_to_dl = self._resolve_downloads()
        if not files_to_dl:
            return

        # Download files. `wget --ask-password` prompts on the terminal, so
        # this method requires an interactive session.
        # NOTE(review): paths are interpolated into shell commands; they come
        # from class attributes rather than user input, but
        # `subprocess.run([...])` with a list would be safer than os.system.
        USER = input("Username: ")
        source_file_paths = " ".join(
            [f"{self._mirror}{f}" for f in files_to_dl]
        )
        os.system(
            f"wget -P {self.dataset_dir} --user={USER} "
            + f"--ask-password {source_file_paths}"
        )

        # Unzip each archive into a scratch directory, then move only the
        # SQLite database files into their final location.
        for file in glob(os.path.join(self.dataset_dir, "*.tar.gz")):
            tmp_dir = os.path.join(self.dataset_dir, "tmp")
            # exist_ok=True: tolerate leftovers from an interrupted run
            # (plain os.mkdir would raise FileExistsError).
            os.makedirs(tmp_dir, exist_ok=True)
            os.system(f"tar -xzf {file} -C {tmp_dir}")
            unzip_dir = self._get_dir_name(file)
            os.makedirs(unzip_dir, exist_ok=True)
            for db_file in glob(
                os.path.join(tmp_dir, "**/*.db"), recursive=True
            ):
                os.system(f"mv {db_file} {unzip_dir}")

            # Clean up the archive and the scratch directory.
            os.system(f"rm {file}")
            os.system(f"rm -r {tmp_dir}")

    @abstractmethod
    def _get_dir_name(self, source_file_path: str) -> str:
        """Get directory name from source file path.

        E.g. if `source_file_path` is "/data/set/file.tar.gz",
        return os.path.join(self.dataset_dir, source_file_path.split("/")[-2])
        to have 'set' as the directory name where all files resulting from the
        unzipping of `source_file_path` end up. If no substructure is desired,
        just return `self.dataset_dir`
        """
        raise NotImplementedError

    def _resolve_downloads(self) -> List[str]:
        """Resolve which files still need to be downloaded."""
        if not os.path.exists(self.dataset_dir):
            return self._zipped_files
        # A source archive is considered already downloaded once its target
        # extraction directory exists locally.
        return [
            src
            for src in self._zipped_files
            if not os.path.exists(self._get_dir_name(src))
        ]

src/graphnet/datasets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22

33
from .test_dataset import TestDataset
44
from .prometheus_datasets import TRIDENTSmall, BaikalGVDSmall, PONESmall
5+
from .snowstorm_dataset import SnowStormDataset
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
"""Snowstorm dataset module hosted on the IceCube Collaboration servers."""
2+
3+
import pandas as pd
4+
import re
5+
import os
6+
from typing import Dict, Any, Optional, List, Tuple, Union
7+
from glob import glob
8+
from sklearn.model_selection import train_test_split
9+
10+
from graphnet.data.constants import FEATURES, TRUTH
11+
from graphnet.data.curated_datamodule import IceCubeHostedDataset
12+
from graphnet.data.utilities import query_database
13+
from graphnet.models.graphs import GraphDefinition
14+
15+
# RunIDs of the SnowStorm simulation sets hosted on the cluster:
# three blocks of nine consecutive runs each.
AVAILABLE_RUN_IDS = [
    run_id
    for first in (22010, 22042, 22078)
    for run_id in range(first, first + 9)
]
20+
21+
22+
class SnowStormDataset(IceCubeHostedDataset):
    """IceCube SnowStorm simulation dataset.

    More information can be found at
    https://wiki.icecube.wisc.edu/index.php/SnowStorm_MC#File_Locations
    This is an IceCube Collaboration simulation dataset.
    Requires a username and password.
    """

    _experiment = "IceCube SnowStorm dataset"
    _creator = "Aske Rosted"
    _citation = "arXiv:1909.01530"
    _available_backends = ["sqlite"]

    _pulsemaps = ["SRTInIcePulses"]
    _truth_table = "truth"
    _pulse_truth = None
    _features = FEATURES.SNOWSTORM
    _event_truth = TRUTH.SNOWSTORM
    # Location of the "<RunID>.tar.gz" archives on the IceCube cluster.
    _data_root_dir = "/data/ana/graphnet/Snowstorm_l2"

    def __init__(
        self,
        run_ids: List[int],
        graph_definition: GraphDefinition,
        download_dir: str,
        truth: Optional[List[str]] = None,
        features: Optional[List[str]] = None,
        train_dataloader_kwargs: Optional[Dict[str, Any]] = None,
        validation_dataloader_kwargs: Optional[Dict[str, Any]] = None,
        test_dataloader_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Construct SnowStormDataset.

        Args:
            run_ids: List of RunIDs to include.
            graph_definition: Method that defines the data representation.
            download_dir: Directory to download dataset to.
            truth (Optional): List of event-level truth to include. Will
                include all available information if not given.
            features (Optional): List of input features from pulsemap to use.
                If not given, all available features will be used.
            train_dataloader_kwargs (Optional): Arguments for the training
                DataLoader. Default None.
            validation_dataloader_kwargs (Optional): Arguments for the
                validation DataLoader. Default None.
            test_dataloader_kwargs (Optional): Arguments for the test
                DataLoader. Default None.
        """
        assert all(
            [i in AVAILABLE_RUN_IDS for i in run_ids]
        ), f"RunIDs must be in {AVAILABLE_RUN_IDS}. You provided {run_ids}"
        self._run_ids = run_ids
        # One compressed archive per requested RunID on the cluster.
        self._zipped_files = [
            os.path.join(self._data_root_dir, f"{s}.tar.gz") for s in run_ids
        ]

        super().__init__(
            graph_definition=graph_definition,
            download_dir=download_dir,
            truth=truth,
            features=features,
            backend="sqlite",
            train_dataloader_kwargs=train_dataloader_kwargs,
            validation_dataloader_kwargs=validation_dataloader_kwargs,
            test_dataloader_kwargs=test_dataloader_kwargs,
        )

    def _prepare_args(
        self, backend: str, features: List[str], truth: List[str]
    ) -> Tuple[Dict[str, Any], Union[List[int], None], Union[List[int], None]]:
        """Prepare arguments for dataset.

        Returns the dataset constructor kwargs plus per-database train/val
        and test event selections (90/10 split, fixed seed).
        """
        assert backend == "sqlite"

        # Collect every extracted database file for the requested RunIDs.
        dataset_paths = []
        for rid in self._run_ids:
            dataset_paths += glob(
                os.path.join(self.dataset_dir, str(rid), "**/*.db"),
                recursive=True,
            )

        # Extract the RunID from each path, which is laid out as
        # "<dataset_dir>/<RunID>/...".
        pattern = rf"{re.escape(self.dataset_dir)}/(\d+)/.*"

        # get event numbers from all datasets
        event_no = []
        # Fix: the original initialized this dict twice in a row; once is
        # enough (the second, un-annotated assignment was redundant).
        event_counts: Dict[str, int] = {}
        for path in dataset_paths:

            # Extract the ID
            match = re.search(pattern, path)
            assert match
            run_id = match.group(1)

            query_df = query_database(
                database=path,
                query=f"SELECT event_no FROM {self._truth_table}",
            )
            query_df["path"] = path
            event_no.append(query_df)

            # save event count for description
            event_counts[run_id] = (
                event_counts.get(run_id, 0) + query_df.shape[0]
            )

        event_no = pd.concat(event_no, axis=0)

        # split the non-unique event numbers into train/val and test
        train_val, test = train_test_split(
            event_no,
            test_size=0.10,
            random_state=42,
            shuffle=True,
        )

        train_val = train_val.groupby("path")
        test = test.groupby("path")

        # Parse into the right format for CuratedDataset: one selection list
        # per database path, in the same order as `dataset_paths`.
        # NOTE(review): `get_group` raises KeyError if a database contributed
        # no events to one of the splits — assumed not to happen in practice;
        # confirm for very small databases.
        train_val_selection = []
        test_selection = []
        for path in dataset_paths:
            train_val_selection.append(
                train_val["event_no"].get_group(path).tolist()
            )
            test_selection.append(test["event_no"].get_group(path).tolist())

        dataset_args = {
            "truth_table": self._truth_table,
            "pulsemaps": self._pulsemaps,
            "path": dataset_paths,
            "graph_definition": self._graph_definition,
            "features": features,
            "truth": truth,
        }

        self._create_comment(event_counts)

        return dataset_args, train_val_selection, test_selection

    @classmethod
    def _create_comment(cls, event_counts: Optional[Dict[str, int]] = None) -> None:
        """Compose the dataset comment listing the event count per RunID.

        Stores the result on the class as `_comments`.
        """
        # Fix: avoid a mutable default argument ({}); None is the sentinel
        # for "no counts available". Calling with no argument behaves as
        # before.
        if event_counts is None:
            event_counts = {}
        fixed_string = (
            " Simulation produced by the IceCube Collaboration, "
            + "https://wiki.icecube.wisc.edu/index.php/SnowStorm_MC#File_Locations"  # noqa: E501
        )
        tot = 0
        runid_string = ""
        for k, v in event_counts.items():
            runid_string += f"RunID {k} contains {v:10d} events\n"
            tot += v
        cls._comments = (
            f"Contains ~{tot/1e6:.1f} million events:\n"
            + runid_string
            + fixed_string
        )

    def _get_dir_name(self, source_file_path: str) -> str:
        """Return the extraction directory for `source_file_path`.

        Archives are named "<RunID>.tar.gz", so each RunID gets its own
        subdirectory of `dataset_dir`.
        """
        file_name = os.path.basename(source_file_path).split(".")[0]
        return str(os.path.join(self.dataset_dir, file_name))

0 commit comments

Comments
 (0)