diff --git a/tests/sources/tools/perception/object_tracking_2d/fair_mot/test_object_tracking_2d_fair_mot.py b/tests/sources/tools/perception/object_tracking_2d/fair_mot/test_object_tracking_2d_fair_mot.py index ddfe93fe4f..24f3a8db9f 100644 --- a/tests/sources/tools/perception/object_tracking_2d/fair_mot/test_object_tracking_2d_fair_mot.py +++ b/tests/sources/tools/perception/object_tracking_2d/fair_mot/test_object_tracking_2d_fair_mot.py @@ -84,66 +84,66 @@ def tearDownClass(cls): rmdir(os.path.join(cls.temp_dir)) -# def test_fit(self): - -# def test_model(name): -# dataset = MotDataset(self.dataset_path) - -# learner = ObjectTracking2DFairMotLearner( -# iters=1, -# num_epochs=1, -# checkpoint_after_iter=3, -# temp_path=self.temp_dir, -# device=DEVICE, -# ) - -# starting_param = list(learner.model.parameters())[0].clone() - -# learner.fit( -# dataset, -# val_epochs=-1, -# train_split_paths=self.train_split_paths, -# val_split_paths=self.train_split_paths, -# verbose=True, -# ) -# new_param = list(learner.model.parameters())[0].clone() -# self.assertFalse(torch.equal(starting_param, new_param)) - -# print("Fit", name, "ok", file=sys.stderr) - -# for name in self.model_names: -# test_model(name) - -# def test_fit_iterator(self): -# def test_model(name): -# dataset = MotDatasetIterator(self.dataset_path, self.train_split_paths) -# eval_dataset = RawMotDatasetIterator(self.dataset_path, self.train_split_paths) - -# learner = ObjectTracking2DFairMotLearner( -# iters=1, -# num_epochs=1, -# checkpoint_after_iter=3, -# temp_path=self.temp_dir, -# device=DEVICE, -# ) - -# starting_param = list(learner.model.parameters())[0].clone() - -# learner.fit( -# dataset, -# val_dataset=eval_dataset, -# val_epochs=-1, -# train_split_paths=self.train_split_paths, -# val_split_paths=self.train_split_paths, -# verbose=True, -# ) -# new_param = list(learner.model.parameters())[0].clone() -# self.assertFalse(torch.equal(starting_param, new_param)) - -# print("Fit iterator", name, "ok", file=sys.stderr) - -# for name in self.model_names: -# test_model(name) + def test_fit(self): + + def test_model(name): + dataset = MotDataset(self.dataset_path) + + learner = ObjectTracking2DFairMotLearner( + iters=1, + num_epochs=1, + checkpoint_after_iter=3, + temp_path=self.temp_dir, + device=DEVICE, + ) + + starting_param = list(learner.model.parameters())[0].clone() + + learner.fit( + dataset, + val_epochs=-1, + train_split_paths=self.train_split_paths, + val_split_paths=self.train_split_paths, + verbose=True, + ) + new_param = list(learner.model.parameters())[0].clone() + self.assertFalse(torch.equal(starting_param, new_param)) + + print("Fit", name, "ok", file=sys.stderr) + + for name in self.model_names: + test_model(name) + + def test_fit_iterator(self): + def test_model(name): + dataset = MotDatasetIterator(self.dataset_path, self.train_split_paths) + eval_dataset = RawMotDatasetIterator(self.dataset_path, self.train_split_paths) + + learner = ObjectTracking2DFairMotLearner( + iters=1, + num_epochs=1, + checkpoint_after_iter=3, + temp_path=self.temp_dir, + device=DEVICE, + ) + + starting_param = list(learner.model.parameters())[0].clone() + + learner.fit( + dataset, + val_dataset=eval_dataset, + val_epochs=-1, + train_split_paths=self.train_split_paths, + val_split_paths=self.train_split_paths, + verbose=True, + ) + new_param = list(learner.model.parameters())[0].clone() + self.assertFalse(torch.equal(starting_param, new_param)) + + print("Fit iterator", name, "ok", file=sys.stderr) + + for name in self.model_names: + test_model(name) def test_eval(self): def test_model(name):