Skip to content

Commit

Permalink
Merge pull request #75 from nugraph/bugfix/3d-truth
Browse files Browse the repository at this point in the history
Fix 3D truth formation
  • Loading branch information
vhewes authored Jun 5, 2024
2 parents 7f30ea7 + 78b40e0 commit 0b58a22
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
15 changes: 7 additions & 8 deletions pynuml/process/hitgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def columns(self) -> dict[str, list[str]]:
else:
groups['event_table'] = keys
if self.label_position:
groups["hit_table"].extend(["x_position", "y_position", "z_position"])
groups["edep_table"] = []
return groups

@property
Expand Down Expand Up @@ -98,14 +98,13 @@ def __call__(self, evt: 'pynuml.io.Event') -> tuple[str, Any]:

# charge-weighted average of 3D position
if self.label_position:
edeps = edeps.drop("g4_id", axis="columns")
edeps["x_position"] = edeps.x_position * edeps.energy
edeps["y_position"] = edeps.y_position * edeps.energy
edeps["z_position"] = edeps.z_position * edeps.energy
edeps = edeps[["hit_id", "energy", "x_position", "y_position", "z_position"]]
for col in ["x_position", "y_position", "z_position"]:
edeps.loc[:, col] *= edeps.energy
edeps = edeps.groupby("hit_id").sum()
edeps["x_position"] = edeps.x_position / edeps.energy
edeps["y_position"] = edeps.y_position / edeps.energy
edeps["z_position"] = edeps.z_position / edeps.energy
for col in ["x_position", "y_position", "z_position"]:
edeps.loc[:, col] /= edeps.energy
edeps = edeps.drop("energy", axis="columns")
hits = edeps.merge(hits, on="hit_id", how="right")

hits['filter_label'] = ~hits[energy_col].isnull()
Expand Down
22 changes: 22 additions & 0 deletions tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,25 @@ def test_process_uboone():
continue
plot.plot(data, target='semantic', how='true', filter='show')
plot.plot(data, target='instance', how='true', filter='true')

def test_process_dune_nutau():
"""Test graph processing with DUNE beam nutau dataset"""
f = pynuml.io.File("/raid/nugraph/dune-nutau/NeutrinoML_r00140_s00000_ts814031.h5")
processor = pynuml.process.HitGraphProducer(
file=f,
semantic_labeller=pynuml.labels.StandardLabels(),
event_labeller=pynuml.labels.FlavorLabels(),
label_position=True)
plot = pynuml.plot.GraphPlot(
planes=["u", "v", "y"],
classes=pynuml.labels.StandardLabels().labels[:-1])
f.read_data(0, 100)
evts = f.build_evt()
for evt in evts:
_, data = processor(evt)
if not data:
continue
plot.plot(data, target="filter", how="true", filter="show")
plot.plot(data, target='semantic', how='true', filter='show')
plot.plot(data, target='instance', how='true', filter='true')
plot.plot(data, target="semantic", how="true", filter="show", xyz=True)

0 comments on commit 0b58a22

Please # to comment.