Skip to content

Commit

Permalink
Merge pull request #243 from JaerongA/datajoint_pipeline
Browse files Browse the repository at this point in the history
modify VideoSourceTracking table design & remove streams activate function
  • Loading branch information
JaerongA authored Sep 5, 2023
2 parents e2af5e0 + 645c659 commit 4106ca4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 32 deletions.
59 changes: 29 additions & 30 deletions aeon/dj_pipeline/tracking.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import datajoint as dj
import matplotlib.path
import numpy as np
Expand All @@ -12,7 +14,6 @@
streams,
)
from aeon.io import api as io_api
from aeon.io import reader

from . import acquisition, dict_to_uuid, get_schema_name, lab, qc

Expand Down Expand Up @@ -247,10 +248,8 @@ def get_object_position(
class VideoSourceTracking(dj.Imported):
definition = """ # Tracked objects position data from a particular VideoSource for multi-animal experiment using the SLEAP tracking method per chunk
-> acquisition.Chunk
-> streams.VideoSourcePosition
-> streams.VideoSource
-> TrackingParamSet
---
tracking_timestamps: longblob # (datetime) timestamps of the position data
"""

class Point(dj.Part):
Expand All @@ -260,20 +259,21 @@ class Point(dj.Part):
---
point_x: longblob
point_y: longblob
point_confidence: longblob
point_likelihood: longblob
"""

class Pose(dj.Part):
definition = """
-> master
pose_name: varchar(16)
---
class: smallint
class_confidence: longblob
---
class_likelihood: longblob
centroid_x: longblob
centroid_y: longblob
centroid_confidence: longblob
point_collection: varchar(1000) # List of point names
centroid_likelihood: longblob
pose_timestamps: longblob
point_collection=null: varchar(1000) # List of point names
"""

class PointCollection(dj.Part):
Expand All @@ -284,49 +284,48 @@ class PointCollection(dj.Part):

@property
def key_source(self):
ks = acquisition.Chunk * streams.VideoSource * TrackingParamSet
return ks & "experiment_name='multianimal'" & "video_source_name='CameraTop'" & "tracking_paramset_id = 1" # SLEAP & CameraTop
return (acquisition.Chunk & "experiment_name='multianimal'" ) * (streams.VideoSourcePosition & (streams.VideoSource & "video_source_name='CameraTop'")) * (TrackingParamSet & "tracking_paramset_id = 1") # SLEAP & CameraTop

def make(self, key):

from aeon.schema.social import Pose

chunk_start, chunk_end, dir_type = (acquisition.Chunk & key).fetch1(
"chunk_start", "chunk_end", "directory_type"
)
camera = (streams.VideoSource & key).fetch1("video_source_name")

raw_data_dir = acquisition.Experiment.get_data_directory(
key, directory_type=dir_type
)

device = getattr(
acquisition._device_schema_mapping[key["experiment_name"]], camera
)
# This needs to be modified later
sleap_reader = Pose(pattern="", columns=["class", "class_confidence", "centroid_x", "centroid_y", "centroid_confidence"])
tracking_file_path = "/ceph/aeon/aeon/data/processed/test-node1/1234567/2023-08-10T18-31-00/macentroid/test-node1_1234567_2023-08-10T18-31-00_macentroid.bin" # temp file path for testing

sleap_reader = reader.Harp(pattern="", columns=["class", "class_confidence", "centroid_x", "centroid_y", "centroid_confidence"])
tracking_file_path = "/ceph/aeon/aeon/code/scratchpad/ex_ma_tracking/ex_ma_tracking.bin" # temporary
tracking_df = sleap_reader.read(tracking_file_path)
tracking_df = sleap_reader.read(Path(tracking_file_path))

object_positions = []
for obj_name in ["body"]:
pose_list = []
for part_name in ["body"]:

for class_id in tracking_df["class"].unique():

temp_df = tracking_df[tracking_df["class"] == class_id]
class_df = tracking_df[tracking_df["class"] == class_id]

object_positions.append(
pose_list.append(
{
**key,
"object_name": obj_name,
"timestamps": temp_df.index.values,
"pose_name": part_name,
"class": class_id,
"class_confidence": temp_df.class_confidence.values,
"centroid_x": temp_df.centroid_x.values,
"centroid_y": temp_df.centroid_y.values,
"centroid_confidence": temp_df.centroid_confidence.values,
"class_likelihood": class_df["class_likelihood"].values,
"centroid_x": class_df["x"].values,
"centroid_y": class_df["y"].values,
"centroid_likelihood": class_df["part_likelihood"].values,
"pose_timestamps": class_df.index.values,
"point_collection": "",
}
)

self.insert1(key)
self.Object.insert(object_positions)
self.Pose.insert(pose_list)


# ---------- HELPER ------------------
Expand Down
2 changes: 0 additions & 2 deletions aeon/dj_pipeline/utils/streams_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ def main(create_tables=True):
f.write(full_def)

streams = importlib.import_module(f"aeon.dj_pipeline.streams")
streams.schema.activate(schema_name)

if create_tables:
# Create DeviceType tables.
Expand Down Expand Up @@ -315,7 +314,6 @@ def main(create_tables=True):
f.write(full_def)

importlib.reload(streams)
streams.schema.activate(schema_name)

return streams

Expand Down

0 comments on commit 4106ca4

Please # to comment.