Skip to content

Commit

Permalink
Merge branch 'fix_detectron2' of github.com:opendr-eu/opendr into fix…
Browse files Browse the repository at this point in the history
…_detectron2
  • Loading branch information
ad-daniel committed Aug 29, 2022
2 parents 887dc92 + 82f10d3 commit b64517b
Showing 1 changed file with 60 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,66 +84,66 @@ def tearDownClass(cls):

rmdir(os.path.join(cls.temp_dir))

# def test_fit(self):

# def test_model(name):
# dataset = MotDataset(self.dataset_path)

# learner = ObjectTracking2DFairMotLearner(
# iters=1,
# num_epochs=1,
# checkpoint_after_iter=3,
# temp_path=self.temp_dir,
# device=DEVICE,
# )

# starting_param = list(learner.model.parameters())[0].clone()

# learner.fit(
# dataset,
# val_epochs=-1,
# train_split_paths=self.train_split_paths,
# val_split_paths=self.train_split_paths,
# verbose=True,
# )
# new_param = list(learner.model.parameters())[0].clone()
# self.assertFalse(torch.equal(starting_param, new_param))

# print("Fit", name, "ok", file=sys.stderr)

# for name in self.model_names:
# test_model(name)

# def test_fit_iterator(self):
# def test_model(name):
# dataset = MotDatasetIterator(self.dataset_path, self.train_split_paths)
# eval_dataset = RawMotDatasetIterator(self.dataset_path, self.train_split_paths)

# learner = ObjectTracking2DFairMotLearner(
# iters=1,
# num_epochs=1,
# checkpoint_after_iter=3,
# temp_path=self.temp_dir,
# device=DEVICE,
# )

# starting_param = list(learner.model.parameters())[0].clone()

# learner.fit(
# dataset,
# val_dataset=eval_dataset,
# val_epochs=-1,
# train_split_paths=self.train_split_paths,
# val_split_paths=self.train_split_paths,
# verbose=True,
# )
# new_param = list(learner.model.parameters())[0].clone()
# self.assertFalse(torch.equal(starting_param, new_param))

# print("Fit iterator", name, "ok", file=sys.stderr)

# for name in self.model_names:
# test_model(name)
def test_fit(self):

def test_model(name):
dataset = MotDataset(self.dataset_path)

learner = ObjectTracking2DFairMotLearner(
iters=1,
num_epochs=1,
checkpoint_after_iter=3,
temp_path=self.temp_dir,
device=DEVICE,
)

starting_param = list(learner.model.parameters())[0].clone()

learner.fit(
dataset,
val_epochs=-1,
train_split_paths=self.train_split_paths,
val_split_paths=self.train_split_paths,
verbose=True,
)
new_param = list(learner.model.parameters())[0].clone()
self.assertFalse(torch.equal(starting_param, new_param))

print("Fit", name, "ok", file=sys.stderr)

for name in self.model_names:
test_model(name)

def test_fit_iterator(self):
def test_model(name):
dataset = MotDatasetIterator(self.dataset_path, self.train_split_paths)
eval_dataset = RawMotDatasetIterator(self.dataset_path, self.train_split_paths)

learner = ObjectTracking2DFairMotLearner(
iters=1,
num_epochs=1,
checkpoint_after_iter=3,
temp_path=self.temp_dir,
device=DEVICE,
)

starting_param = list(learner.model.parameters())[0].clone()

learner.fit(
dataset,
val_dataset=eval_dataset,
val_epochs=-1,
train_split_paths=self.train_split_paths,
val_split_paths=self.train_split_paths,
verbose=True,
)
new_param = list(learner.model.parameters())[0].clone()
self.assertFalse(torch.equal(starting_param, new_param))

print("Fit iterator", name, "ok", file=sys.stderr)

for name in self.model_names:
test_model(name)

def test_eval(self):
def test_model(name):
Expand Down

0 comments on commit b64517b

Please # to comment.