diff --git a/pynuml/process/hitgraph.py b/pynuml/process/hitgraph.py index 3abc0e4..a36e378 100644 --- a/pynuml/process/hitgraph.py +++ b/pynuml/process/hitgraph.py @@ -58,7 +58,7 @@ def columns(self) -> dict[str, list[str]]: else: groups['event_table'] = keys if self.label_position: - groups["hit_table"].extend(["x_position", "y_position", "z_position"]) + groups["edep_table"] = [] return groups @property @@ -98,14 +98,13 @@ def __call__(self, evt: 'pynuml.io.Event') -> tuple[str, Any]: # charge-weighted average of 3D position if self.label_position: - edeps = edeps.drop("g4_id", axis="columns") - edeps["x_position"] = edeps.x_position * edeps.energy - edeps["y_position"] = edeps.y_position * edeps.energy - edeps["z_position"] = edeps.z_position * edeps.energy + edeps = edeps[["hit_id", "energy", "x_position", "y_position", "z_position"]] + for col in ["x_position", "y_position", "z_position"]: + edeps.loc[:, col] *= edeps.energy edeps = edeps.groupby("hit_id").sum() - edeps["x_position"] = edeps.x_position / edeps.energy - edeps["y_position"] = edeps.y_position / edeps.energy - edeps["z_position"] = edeps.z_position / edeps.energy + for col in ["x_position", "y_position", "z_position"]: + edeps.loc[:, col] /= edeps.energy + edeps = edeps.drop("energy", axis="columns") hits = edeps.merge(hits, on="hit_id", how="right") hits['filter_label'] = ~hits[energy_col].isnull() diff --git a/tests/test_process.py b/tests/test_process.py index f340606..b7fcc08 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -20,3 +20,25 @@ def test_process_uboone(): continue plot.plot(data, target='semantic', how='true', filter='show') plot.plot(data, target='instance', how='true', filter='true') + +def test_process_dune_nutau(): + """Test graph processing with DUNE beam nutau dataset""" + f = pynuml.io.File("/raid/nugraph/dune-nutau/NeutrinoML_r00140_s00000_ts814031.h5") + processor = pynuml.process.HitGraphProducer( + file=f, + semantic_labeller=pynuml.labels.StandardLabels(), + event_labeller=pynuml.labels.FlavorLabels(), + label_position=True) + plot = pynuml.plot.GraphPlot( + planes=["u", "v", "y"], + classes=pynuml.labels.StandardLabels().labels[:-1]) + f.read_data(0, 100) + evts = f.build_evt() + for evt in evts: + _, data = processor(evt) + if not data: + continue + plot.plot(data, target="filter", how="true", filter="show") + plot.plot(data, target='semantic', how='true', filter='show') + plot.plot(data, target='instance', how='true', filter='true') + plot.plot(data, target="semantic", how="true", filter="show", xyz=True)