Skip to content

Commit

Permalink
Keras: Saving history in a JSON file (#861)
Browse files Browse the repository at this point in the history
* added test and history saving

* Update src/huggingface_hub/keras_mixin.py

Co-authored-by: Nathan Raw <nxr9266@g.rit.edu>

Co-authored-by: Nathan Raw <nxr9266@g.rit.edu>
  • Loading branch information
2 people authored and LysandreJik committed May 24, 2022
1 parent a8b6f14 commit 0eac01b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
11 changes: 7 additions & 4 deletions src/huggingface_hub/keras_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,13 @@ def _extract_hyperparameters_from_keras(model):
return hyperparameters


def _parse_model_history(model):
def _parse_model_history(model, save_directory):
lines = None
if model.history is not None:
if model.history.history != {}:
path = os.path.join(save_directory, "history.json")
with open(path, "w", encoding="utf-8") as f:
json.dump(model.history.history, f, indent=2, sort_keys=True)
lines = []
logs = model.history.history
num_epochs = len(logs["loss"])
Expand Down Expand Up @@ -79,8 +82,8 @@ def _plot_network(model, save_directory):
)


def _write_metrics(model, model_card):
lines = _parse_model_history(model)
def _write_metrics(model, model_card, save_directory):
lines = _parse_model_history(model, save_directory)
if lines is not None:
model_card += "\n| Epochs |"

Expand Down Expand Up @@ -128,7 +131,7 @@ def _create_model_card(
)
model_card += "\n"
model_card += "\n ## Training Metrics\n"
model_card = _write_metrics(model, model_card)
model_card = _write_metrics(model, model_card, repo_dir)
if plot_model and os.path.exists(f"{repo_dir}/model.png"):
model_card += "\n ## Model Plot\n"
model_card += "\n<details>"
Expand Down
5 changes: 3 additions & 2 deletions tests/test_keras_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ def test_save_pretrained_model_card_fit(self):
self.assertIn("keras_metadata.pb", files)
self.assertIn("model.png", files)
self.assertIn("README.md", files)
self.assertEqual(len(files), 6)
self.assertIn("history.json", files)
self.assertEqual(len(files), 7)

def test_save_pretrained_optimizer_state(self):
REPO_NAME = repo_name("save")
Expand Down Expand Up @@ -490,4 +491,4 @@ def test_save_pretrained_fit(self):

self.assertIn("saved_model.pb", files)
self.assertIn("keras_metadata.pb", files)
self.assertEqual(len(files), 6)
self.assertEqual(len(files), 7)

0 comments on commit 0eac01b

Please # to comment.