Skip to content

Commit

Permalink
fix: update_progress accept values >1
Browse files Browse the repository at this point in the history
  • Loading branch information
12rambau committed Nov 29, 2022
1 parent 8a8196e commit dc7fd95
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
9 changes: 5 additions & 4 deletions sepal_ui/sepalwidgets/alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ def update_progress(self, progress, msg="Progress", **tqdm_args):
self.show()

# cast the progress to float
total = tqdm_args.get("total", 1)
progress = float(progress)
if not (0 <= progress <= 1):
raise ValueError(f"progress should be in [0, 1], {progress} given")
if not (0 <= progress <= total):
raise ValueError(f"progress should be in [0, {total}], {progress} given")

# Prevent adding multiple times
if self.progress_output not in self.children:
Expand All @@ -107,7 +108,7 @@ def update_progress(self, progress, msg="Progress", **tqdm_args):
"bar_format", "{l_bar}{bar}{n_fmt}/{total_fmt}"
)
tqdm_args["dynamic_ncols"] = tqdm_args.pop("dynamic_ncols", tqdm_args)
tqdm_args["total"] = tqdm_args.pop("total", 100)
tqdm_args["total"] = tqdm_args.pop("total", 1)
tqdm_args["desc"] = tqdm_args.pop("desc", msg)
tqdm_args["colour"] = tqdm_args.pop("tqdm_args", getattr(color, self.type))

Expand All @@ -120,7 +121,7 @@ def update_progress(self, progress, msg="Progress", **tqdm_args):
# Initialize bar
self.progress_bar.update(0)

self.progress_bar.update(progress * 100 - self.progress_bar.n)
self.progress_bar.update(progress - self.progress_bar.n)

if progress == 1:
self.progress_bar.close()
Expand Down
9 changes: 8 additions & 1 deletion tests/test_Alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,18 @@ def test_update_progress(self, alert):

# test a random update
alert.update_progress(0.5)
assert alert.progress_bar.n == 50
assert alert.progress_bar.n == 0.5
assert alert.viz is True

# show that a value > 1 raise an error
with pytest.raises(ValueError):
alert.reset()
alert.update_progress(1.5)

# check that if total is set value can be more than 1
alert.reset()
alert.update_progress(50, total=100)
assert alert.progress_bar.n == 50
assert alert.viz is True

return

0 comments on commit dc7fd95

Please # to comment.