Keras: Saving history in a JSON file #861

Merged 2 commits on May 9, 2022
11 changes: 7 additions & 4 deletions src/huggingface_hub/keras_mixin.py
@@ -39,10 +39,13 @@ def _extract_hyperparameters_from_keras(model):
     return hyperparameters


-def _parse_model_history(model):
+def _parse_model_history(model, save_directory):
Contributor

I'm not in love with the fact that this history writing happens so deep in the logic here. save_pretrained_keras -> _create_model_card -> _write_metrics -> _parse_model_history is a long call chain just to save model.history.history as a JSON file.

I don't want to block on this, but just wanted to point that out so we remember to put this in a nicer spot in future refactors.

Contributor Author

Yeah, that's what I wanted to refactor as well (it disturbed me too!). I'm planning to do that after the next Keras sprint.
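As an illustration of the refactor discussed in this thread, here is a minimal sketch of how the history dump could live in its own helper, called directly from the top-level save function instead of from the model card code. The name `save_history_json` and its signature are hypothetical and are not part of huggingface_hub; the `json.dump` arguments are copied from this diff.

```python
import json
import os
from typing import Dict, List, Optional


def save_history_json(
    history: Optional[Dict[str, List[float]]], save_directory: str
) -> Optional[str]:
    """Write a Keras-style history dict to <save_directory>/history.json.

    Hypothetical helper, not part of huggingface_hub. Returns the path
    that was written, or None when there is no history to save.
    """
    if not history:  # covers both None and an empty dict
        return None
    path = os.path.join(save_directory, "history.json")
    with open(path, "w", encoding="utf-8") as f:
        # Same serialization options as the line added in this PR.
        json.dump(history, f, indent=2, sort_keys=True)
    return path
```

A caller would pass `model.history.history` when `model.history` is not None, and `None` otherwise, so the model card code no longer needs to know about `save_directory`.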

     lines = None
     if model.history is not None:
         if model.history.history != {}:
+            path = os.path.join(save_directory, "history.json")
+            with open(path, "w", encoding="utf-8") as f:
+                json.dump(model.history.history, f, indent=2, sort_keys=True)
             lines = []
             logs = model.history.history
             num_epochs = len(logs["loss"])
@@ -79,8 +82,8 @@ def _plot_network(model, save_directory):
     )


-def _write_metrics(model, model_card):
-    lines = _parse_model_history(model)
+def _write_metrics(model, model_card, save_directory):
+    lines = _parse_model_history(model, save_directory)
     if lines is not None:
         model_card += "\n| Epochs |"

@@ -128,7 +131,7 @@ def _create_model_card(
     )
     model_card += "\n"
     model_card += "\n ## Training Metrics\n"
-    model_card = _write_metrics(model, model_card)
+    model_card = _write_metrics(model, model_card, repo_dir)
     if plot_model and os.path.exists(f"{repo_dir}/model.png"):
         model_card += "\n ## Model Plot\n"
         model_card += "\n<details>"
5 changes: 3 additions & 2 deletions tests/test_keras_integration.py
@@ -225,7 +225,8 @@ def test_save_pretrained_model_card_fit(self):
         self.assertIn("keras_metadata.pb", files)
         self.assertIn("model.png", files)
         self.assertIn("README.md", files)
-        self.assertEqual(len(files), 6)
+        self.assertIn("history.json", files)
+        self.assertEqual(len(files), 7)

     def test_save_pretrained_optimizer_state(self):
         REPO_NAME = repo_name("save")
@@ -490,4 +491,4 @@ def test_save_pretrained_fit(self):

         self.assertIn("saved_model.pb", files)
         self.assertIn("keras_metadata.pb", files)
-        self.assertEqual(len(files), 6)
+        self.assertEqual(len(files), 7)
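
The tests above only check that history.json exists in the saved directory. As a rough sketch of how a consumer might read the file back, the snippet below round-trips a history dict through the same `json.dump` call used in this PR. The `load_history` helper and the metric values are made up for illustration and are not part of huggingface_hub or its test suite.

```python
import json
import os
import tempfile


def load_history(save_directory):
    """Hypothetical reader for the history.json file written by this PR."""
    path = os.path.join(save_directory, "history.json")
    with open(path, encoding="utf-8") as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as save_directory:
    # Made-up values shaped like model.history.history after three epochs.
    logs = {"loss": [0.9, 0.6, 0.4], "accuracy": [0.5, 0.7, 0.8]}
    path = os.path.join(save_directory, "history.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(logs, f, indent=2, sort_keys=True)

    restored = load_history(save_directory)
    # Mirrors how _parse_model_history derives the epoch count.
    num_epochs = len(restored["loss"])
```

Because the file is plain JSON, the per-epoch metrics can be inspected or plotted without loading the Keras model itself.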