Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Integrated Gradients Notebook - np.ndarray kwargs but only tf.Tensor supported #528

Merged
merged 5 commits into from
Nov 16, 2021

Conversation

RobertSamoilescu
Copy link
Collaborator

This PR fixes issue #527.
The newer versions of transformers do not support np.ndarray for optional arguments (i.e., attention_mask).
The error is fixed by casting np.darray to tf.Tensor before passing to explain or forward methods.

# the values of the kwargs have to be `tf.Tensor`. 
# see transformers issue #14404: https://github.com/huggingface/transformers/issues/14404
kwargs = {k: tf.constant(v) for k,v in z_test_sample.items() if k == 'attention_mask'}

In addition, I included import matplotlib.cm as it is required for matplotlib >= 3.4.2.

Also used a smaller validation dataset as when using the full testing dataset I ran into GPU memory issues.

    # using the entire testing dataset might result in memory issues when running on GPU
    model_out.fit(train_embbedings, y_train, 
                  validation_data=(test_embbedings[:100], y_test[:100]),
                  epochs=epochs, 
                  batch_size=batch_size,
                  callbacks=[cp_callback],
                  verbose=1)

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@codecov
Copy link

codecov bot commented Nov 15, 2021

Codecov Report

Merging #528 (cc9522c) into master (6100135) will not change coverage.
The diff coverage is n/a.

Impacted file tree graph

@@           Coverage Diff           @@
##           master     #528   +/-   ##
=======================================
  Coverage   82.34%   82.34%           
=======================================
  Files          76       76           
  Lines       10334    10334           
=======================================
  Hits         8510     8510           
  Misses       1824     1824           

# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants