From 19ff52c7a284af426114479d393a9d2e501d59cc Mon Sep 17 00:00:00 2001 From: jarno-r Date: Wed, 2 Jan 2019 09:39:32 +0200 Subject: [PATCH 1/2] Fixed division by zero in QR decomposition. Issue #1058 --- src/ops/linalg_ops.ts | 7 ++++++- src/ops/linalg_ops_test.ts | 23 +++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts index 7a15b94e16..40fe0463ff 100644 --- a/src/ops/linalg_ops.ts +++ b/src/ops/linalg_ops.ts @@ -211,7 +211,12 @@ function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] { const rjEnd1 = r.slice([j, j], [m - j, 1]); const normX = rjEnd1.norm(); const rjj = r.slice([j, j], [1, 1]); - const s = rjj.sign().neg() as Tensor2D; + + // The sign() function returns 0 on 0, which causes division by zero. + const s = tensor2d([[-1]]).where( + rjj.greater(tensor2d([[0]])), + tensor2d([[1]])); + const u1 = rjj.sub(s.mul(normX)) as Tensor2D; const wPre = rjEnd1.div(u1); if (wPre.shape[0] === 1) { diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts index bfbb5ef62b..96ab17fb08 100644 --- a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -171,6 +171,29 @@ describeWithFlags('qr', ALL_ENVS, () => { [3, 3])); }); + it('3x3, zero on diagonal', () => { + const x = tensor2d([[0, 2, 2], [1, 1, 1], [0, 1, 2]], [3, 3]); + const [q, r] = tf.linalg.qr(x); + expectArraysClose( + q, + tensor2d( + [ + [0., -0.89442719, 0.4472136], + [1., 0., 0.], + [0., -0.4472136, -0.89442719] + ], + [3, 3])); + expectArraysClose( + r, + tensor2d( + [ + [1., 1., 1.], + [0., -2.23606798, -2.68328157], + [0., 0., -0.89442719] + ], + [3, 3])); + }); + it('3x2, fullMatrices = default false', () => { const x = tensor2d([[1, 2], [3, -3], [-2, 1]], [3, 2]); const [q, r] = tf.linalg.qr(x); From f2f6da2b811e6805317f5ed6687044d857804914 Mon Sep 17 00:00:00 2001 From: jarno-r Date: Thu, 3 Jan 2019 08:08:12 +0200 Subject: [PATCH 2/2] Minor code simplification. --- src/ops/linalg_ops.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts index 40fe0463ff..e23df3b1de 100644 --- a/src/ops/linalg_ops.ts +++ b/src/ops/linalg_ops.ts @@ -214,7 +214,7 @@ function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] { // The sign() function returns 0 on 0, which causes division by zero. const s = tensor2d([[-1]]).where( - rjj.greater(tensor2d([[0]])), + rjj.greater(0), tensor2d([[1]])); const u1 = rjj.sub(s.mul(normX)) as Tensor2D;