Skip to content

Commit 25cf38e

Browse files
SSaishruthiSquadrick
authored andcommitted
R-square metric (#283)
1 parent 31ec2ec commit 25cf38e

File tree

5 files changed

+176
-1
lines changed

5 files changed

+176
-1
lines changed

tensorflow_addons/metrics/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ py_library(
77
srcs = [
88
"__init__.py",
99
"cohens_kappa.py",
10+
"r_square.py",
1011
],
1112
srcs_version = "PY2AND3",
1213
deps = [
@@ -26,3 +27,16 @@ py_test(
2627
":metrics",
2728
],
2829
)
30+
31+
py_test(
32+
name = "r_square_test",
33+
size = "small",
34+
srcs = [
35+
"r_square_test.py",
36+
],
37+
main = "r_square_test.py",
38+
srcs_version = "PY2AND3",
39+
deps = [
40+
":metrics",
41+
],
42+
)

tensorflow_addons/metrics/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
| Submodule | Maintainers | Contact Info |
55
|:---------- |:------------- |:--------------|
66
| cohens_kappa| Aakash Nain | aakashnain@outlook.com|
7+
| r_square| Saishruthi Swaminathan| saishruthi.tn@gmail.com|
78

89
## Contents
910
| Submodule | Metric | Reference |
1011
|:----------------------- |:-------------------|:---------------|
1112
| cohens_kappa| CohenKappa|[Cohen's Kappa](https://en.wikipedia.org/wiki/Cohen%27s_kappa)|
13+
| r_square| RSquare|[R-Sqaure](https://en.wikipedia.org/wiki/Coefficient_of_determination)|
1214

1315

1416
## Contribution Guidelines

tensorflow_addons/metrics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21-
from tensorflow_addons.metrics.cohens_kappa import CohenKappa
21+
from tensorflow_addons.metrics.cohens_kappa import CohenKappa
22+
from tensorflow_addons.metrics.r_square import RSquare

tensorflow_addons/metrics/r_square.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Implements R^2 scores."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import tensorflow as tf
22+
from tensorflow.keras.metrics import Metric
23+
24+
25+
class RSquare(Metric):
26+
"""Compute R^2 score.
27+
28+
This is also called as coefficient of determination.
29+
It tells how close are data to the fitted regression line.
30+
31+
- Highest score can be 1.0 and it indicates that the predictors
32+
perfectly accounts for variation in the target.
33+
- Score 0.0 indicates that the predictors do not
34+
account for variation in the target.
35+
- It can also be negative if the model is worse.
36+
37+
Usage:
38+
```python
39+
actuals = tf.constant([1, 4, 3], dtype=tf.float32)
40+
preds = tf.constant([2, 4, 4], dtype=tf.float32)
41+
result = tf.keras.metrics.RSquare()
42+
result.update_state(actuals, preds)
43+
print('R^2 score is: ', r1.result().numpy()) # 0.57142866
44+
```
45+
"""
46+
47+
def __init__(self, name='r_square', dtype=tf.float32):
48+
super(RSquare, self).__init__(name=name, dtype=dtype)
49+
self.squared_sum = self.add_weight("squared_sum", initializer="zeros")
50+
self.sum = self.add_weight("sum", initializer="zeros")
51+
self.res = self.add_weight("residual", initializer="zeros")
52+
self.count = self.add_weight("count", initializer="zeros")
53+
54+
def update_state(self, y_true, y_pred):
55+
y_true = tf.convert_to_tensor(y_true, tf.float32)
56+
y_pred = tf.convert_to_tensor(y_pred, tf.float32)
57+
self.squared_sum.assign_add(tf.reduce_sum(y_true**2))
58+
self.sum.assign_add(tf.reduce_sum(y_true))
59+
self.res.assign_add(
60+
tf.reduce_sum(tf.square(tf.subtract(y_true, y_pred))))
61+
self.count.assign_add(tf.cast(tf.shape(y_true)[0], tf.float32))
62+
63+
def result(self):
64+
mean = self.sum / self.count
65+
total = self.squared_sum - 2 * self.sum * mean + self.count * mean**2
66+
return 1 - (self.res / total)
67+
68+
def reset_states(self):
69+
# The state of the metric will be reset at the start of each epoch.
70+
self.squared_sum.assign(0.0)
71+
self.sum.assign(0.0)
72+
self.res.assign(0.0)
73+
self.count.assign(0.0)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for R-Square Metric."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import tensorflow as tf
22+
from tensorflow_addons.metrics import RSquare
23+
24+
25+
class RSquareTest(tf.test.TestCase):
26+
def test_config(self):
27+
r2_obj = RSquare(name='r_square')
28+
self.assertEqual(r2_obj.name, 'r_square')
29+
self.assertEqual(r2_obj.dtype, tf.float32)
30+
# Check save and restore config
31+
r2_obj2 = RSquare.from_config(r2_obj.get_config())
32+
self.assertEqual(r2_obj2.name, 'r_square')
33+
self.assertEqual(r2_obj2.dtype, tf.float32)
34+
35+
def initialize_vars(self):
36+
r2_obj = RSquare()
37+
self.evaluate(tf.compat.v1.variables_initializer(r2_obj.variables))
38+
return r2_obj
39+
40+
def update_obj_states(self, obj, actuals, preds):
41+
update_op = obj.update_state(actuals, preds)
42+
self.evaluate(update_op)
43+
44+
def check_results(self, obj, value):
45+
self.assertAllClose(value, self.evaluate(obj.result()), atol=1e-5)
46+
47+
def test_r2_perfect_score(self):
48+
actuals = tf.constant([100, 700, 40, 5.7], dtype=tf.float32)
49+
preds = tf.constant([100, 700, 40, 5.7], dtype=tf.float32)
50+
actuals = tf.constant(actuals, dtype=tf.float32)
51+
preds = tf.constant(preds, dtype=tf.float32)
52+
# Initialize
53+
r2_obj = self.initialize_vars()
54+
# Update
55+
self.update_obj_states(r2_obj, actuals, preds)
56+
# Check results
57+
self.check_results(r2_obj, 1.0)
58+
59+
def test_r2_worst_score(self):
60+
actuals = tf.constant([10, 600, 4, 9.77], dtype=tf.float32)
61+
preds = tf.constant([1, 70, 40, 5.7], dtype=tf.float32)
62+
actuals = tf.constant(actuals, dtype=tf.float32)
63+
preds = tf.constant(preds, dtype=tf.float32)
64+
# Initialize
65+
r2_obj = self.initialize_vars()
66+
# Update
67+
self.update_obj_states(r2_obj, actuals, preds)
68+
# Check results
69+
self.check_results(r2_obj, -0.073607)
70+
71+
def test_r2_random_score(self):
72+
actuals = tf.constant([10, 600, 3, 9.77], dtype=tf.float32)
73+
preds = tf.constant([1, 340, 40, 5.7], dtype=tf.float32)
74+
actuals = tf.constant(actuals, dtype=tf.float32)
75+
preds = tf.constant(preds, dtype=tf.float32)
76+
# Initialize
77+
r2_obj = self.initialize_vars()
78+
# Update
79+
self.update_obj_states(r2_obj, actuals, preds)
80+
# Check results
81+
self.check_results(r2_obj, 0.7376327)
82+
83+
84+
if __name__ == '__main__':
85+
tf.test.main()

0 commit comments

Comments
 (0)