From 2ffd4ad9d78aa7e88ef1f6a7c5906c72e5bc9737 Mon Sep 17 00:00:00 2001 From: pingn Date: Fri, 18 Dec 2020 13:52:15 +0100 Subject: [PATCH] Simplified conv2d gradient calc + optimized immutable array migration (the gradient calc might be slower, but is a bit easier to read since the summation is now done separately) --- .../com/codeberry/tadlib/array/TArray.java | 2 +- .../codeberry/tadlib/array/TMutableArray.java | 19 +++---- .../java/com/codeberry/tadlib/tensor/Ops.java | 51 +++++++++++-------- .../tadlib/util/TrainingDataUtils.java | 5 +- 4 files changed, 42 insertions(+), 35 deletions(-) diff --git a/src/main/java/com/codeberry/tadlib/array/TArray.java b/src/main/java/com/codeberry/tadlib/array/TArray.java index 9038823..64e8234 100644 --- a/src/main/java/com/codeberry/tadlib/array/TArray.java +++ b/src/main/java/com/codeberry/tadlib/array/TArray.java @@ -50,7 +50,7 @@ public TArray softmax() { fillSoftMax(this, output, output.shape.newIndexArray(), 0); - return output.toImmutable(); + return output.migrateToImmutable(); } private static void fillSoftMax(TArray src, TMutableArray tgt, int[] indices, int dim) { diff --git a/src/main/java/com/codeberry/tadlib/array/TMutableArray.java b/src/main/java/com/codeberry/tadlib/array/TMutableArray.java index 1423f0d..2d4ebdd 100644 --- a/src/main/java/com/codeberry/tadlib/array/TMutableArray.java +++ b/src/main/java/com/codeberry/tadlib/array/TMutableArray.java @@ -3,7 +3,7 @@ import java.util.Arrays; public class TMutableArray { - private final double[] data; + private volatile double[] data; public final Shape shape; public TMutableArray(Shape shape) { @@ -31,18 +31,19 @@ public double dataAt(int... indices) { return 0; } - public void addAt(int[] indices, double v) { - int offset = shape.calcDataIndex(indices); - data[offset] += v; - } - public void setAt(int[] indices, double v) { int offset = shape.calcDataIndex(indices); data[offset] = v; } - // TODO: implement as "migrateToImmutable() that nulls the array and hands it to TArray? Avoid array copy. - public TArray toImmutable() { - return new TArray(Arrays.copyOf(data, data.length), shape.copy()); + /** + * The current instance cannot be used after this call. + */ + public synchronized TArray migrateToImmutable() { + TArray immutable = new TArray(this.data, shape.copy()); + + this.data = null; + + return immutable; } } diff --git a/src/main/java/com/codeberry/tadlib/tensor/Ops.java b/src/main/java/com/codeberry/tadlib/tensor/Ops.java index d40d2e0..d7f233e 100644 --- a/src/main/java/com/codeberry/tadlib/tensor/Ops.java +++ b/src/main/java/com/codeberry/tadlib/tensor/Ops.java @@ -188,37 +188,45 @@ public static Tensor conv2d(Tensor input, Tensor filter) { } private static TArray calcFilterGradient(TArray grad, TArray input, TArray filter) { - Shape tgtShape = filter.shape.normalOrderedCopy(); - - return multiThreadingSupportRun(taskRange(0, grad.shape.at(0)), - range -> accumulateFilterGradientAtFirstDim(range, grad, input, filter, tgtShape), + Shape filterShape = filter.shape; + int[] dims = new int[filterShape.dimCount + 1]; + System.arraycopy(filterShape.toDimArray(), 0, dims, 1, filterShape.dimCount); + dims[0] = input.shape.at(0); + Shape tgtShape = new Shape(dims); + + TArray gradPerInputExample = multiThreadingSupportRun(taskRange(0, grad.shape.at(0)), + range -> accumulateFilterGradientAtFirstDim(range, grad, input, tgtShape), (left, right) -> left.add(right)); + + return gradPerInputExample.sumFirstDims(1, REMOVE_DIM); } - private static TArray accumulateFilterGradientAtFirstDim(TaskRange range, TArray grad, TArray input, TArray filter, Shape tgtShape) { - TMutableArray tgtGrad = new TMutableArray(new double[filter.shape.size], tgtShape); + private static TArray accumulateFilterGradientAtFirstDim(TaskRange range, TArray grad, TArray input, Shape tgtShape) { + TMutableArray tgtGrad = new TMutableArray(new double[tgtShape.size], tgtShape); int[] gradIndices = grad.shape.newIndexArray(); int[] inIndices = input.shape.newIndexArray(); - int[] filterIndices = filter.shape.newIndexArray(); + int[] tgtIndices = tgtShape.newIndexArray(); for (int i = range.start; i < range.end; i++) { gradIndices[0] = i; inIndices[0] = i; - accumulateFilterGradient(grad, gradIndices, 1, input, inIndices, tgtGrad, filterIndices); + tgtIndices[0] = i; + accumulateFilterGradient(grad, gradIndices, 1, input, inIndices, tgtGrad, tgtIndices); } - return tgtGrad.toImmutable(); + return tgtGrad.migrateToImmutable(); } private static void accumulateFilterGradient(TArray grad, int[] gradIndices, int dim, TArray input, int[] inIndices, - TMutableArray tgtGrad, int[] filterIndices) { + TMutableArray tgtGrad, int[] tgtIndices) { if (gradIndices.length - dim == 3) { - int filterH = tgtGrad.shape.at(0); - int filterW = tgtGrad.shape.at(1); + int filterH = tgtGrad.shape.at(1); + int filterW = tgtGrad.shape.at(2); int inputChannels = input.shape.at(-1); int outChannels = tgtGrad.shape.at(-1); + int tgtDims = tgtIndices.length; for (int inIdx = 0; inIdx < inputChannels; inIdx++) { for (int outIdx = 0; outIdx < outChannels; outIdx++) { for (int y = 0; y < filterH; y++) { @@ -229,11 +237,11 @@ private static void accumulateFilterGradient(TArray grad, int[] gradIndices, int inIdx, outIdx, y, x); - filterIndices[0] = y; - filterIndices[1] = x; - filterIndices[2] = inIdx; - filterIndices[3] = outIdx; - tgtGrad.addAt(filterIndices, g); + tgtIndices[tgtDims - 4] = y; + tgtIndices[tgtDims - 3] = x; + tgtIndices[tgtDims - 2] = inIdx; + tgtIndices[tgtDims - 1] = outIdx; + tgtGrad.setAt(tgtIndices, g); } } } @@ -243,7 +251,8 @@ private static void accumulateFilterGradient(TArray grad, int[] gradIndices, int for (int i = 0; i < len; i++) { gradIndices[dim] = i; inIndices[dim] = i; - accumulateFilterGradient(grad, gradIndices, dim + 1, input, inIndices, tgtGrad, filterIndices); + tgtIndices[dim] = i; + accumulateFilterGradient(grad, gradIndices, dim + 1, input, inIndices, tgtGrad, tgtIndices); } } } @@ -294,7 +303,7 @@ public static Tensor maxpool2d(Tensor input, int size) { GradFunc gF = grad -> distribute2dMaxGrad(grad, input.vals.shape, maxIndexShape, maxIndexData); - return new Tensor(tgt.toImmutable(), singletonList(parentLink(input, gF))); + return new Tensor(tgt.migrateToImmutable(), singletonList(parentLink(input, gF))); } public static Shape getMaxPool2dOutputSize(Shape inputShape, int size) { @@ -317,7 +326,7 @@ private static TArray distribute2dMaxGrad(TArray grad, Shape inputShape, Shape m fillMax2dGradInto(outputGrad, maxIndexShape, maxIndexData, grad, 0, tmpOutputGradIndices, tmpGradIndices, tmpMaxIndices); - return outputGrad.toImmutable(); + return outputGrad.migrateToImmutable(); } private static void fillMax2dGradInto(TMutableArray outputGrad, Shape maxIndexShape, int[] maxIndexData, TArray grad, int dim, @@ -539,7 +548,7 @@ public static Tensor sumSoftmaxCrossEntropy(Tensor labelsOneHot, Tensor predicti TMutableArray tgt = TMutableArray.copyOf(softmax); toSoftmaxGradient(tgt, softmax, softmax.shape.newIndexArray(), labelsOneHot.vals, 0); - return tgt.toImmutable().mul(grad); + return tgt.migrateToImmutable().mul(grad); }; return new Tensor(new TArray(cost), singletonList(parentLink(prediction, gF))); diff --git a/src/main/java/com/codeberry/tadlib/util/TrainingDataUtils.java b/src/main/java/com/codeberry/tadlib/util/TrainingDataUtils.java index 2ff52cf..d97a543 100644 --- a/src/main/java/com/codeberry/tadlib/util/TrainingDataUtils.java +++ b/src/main/java/com/codeberry/tadlib/util/TrainingDataUtils.java @@ -1,8 +1,5 @@ package com.codeberry.tadlib.util; -import com.codeberry.tadlib.array.Shape; -import com.codeberry.tadlib.array.TArray; -import com.codeberry.tadlib.array.TArrayFactory; import com.codeberry.tadlib.array.TMutableArray; import com.codeberry.tadlib.tensor.Tensor; @@ -18,6 +15,6 @@ public static Tensor toOneHot(Tensor yTrain, int outputUnits) { indices[1] = (int) yTrain.dataAt(i, 0); out.setAt(indices, 1.0); } - return new Tensor(out.toImmutable(), Tensor.GradientMode.NONE); + return new Tensor(out.migrateToImmutable(), Tensor.GradientMode.NONE); } }