
Simplified conv2d gradient calc + optimized immutable array migration
(the gradient calc might be slower, but is a bit easier to read since the summation is now done separately)
pingng committed Dec 18, 2020
1 parent 583cd0b commit 2ffd4ad
Showing 4 changed files with 42 additions and 35 deletions.
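The "optimized immutable array migration" in the title replaces a defensive array copy with an ownership hand-off: the mutable wrapper gives its backing array to the immutable result and nulls its own reference. A minimal sketch of the before/after tradeoff, using simplified stand-in classes rather than the library's real TArray/TMutableArray:

    import java.util.Arrays;

    // Simplified stand-ins to illustrate copy-vs-migrate; not the actual tadlib API.
    class ImmutableBuf {
        final double[] data;
        ImmutableBuf(double[] data) { this.data = data; }
    }

    class MutableBuf {
        private double[] data = new double[16];

        // Before: every conversion pays for an O(n) defensive copy.
        ImmutableBuf toImmutableByCopy() {
            return new ImmutableBuf(Arrays.copyOf(data, data.length));
        }

        // After: hand over the backing array and null the field, so later
        // writes through this instance fail fast instead of silently
        // mutating the "immutable" result.
        ImmutableBuf migrateToImmutable() {
            ImmutableBuf result = new ImmutableBuf(data);
            data = null;
            return result;
        }
    }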
2 changes: 1 addition & 1 deletion src/main/java/com/codeberry/tadlib/array/TArray.java
@@ -50,7 +50,7 @@ public TArray softmax() {

fillSoftMax(this, output, output.shape.newIndexArray(), 0);

-        return output.toImmutable();
+        return output.migrateToImmutable();
}

private static void fillSoftMax(TArray src, TMutableArray tgt, int[] indices, int dim) {
19 changes: 10 additions & 9 deletions src/main/java/com/codeberry/tadlib/array/TMutableArray.java
@@ -3,7 +3,7 @@
import java.util.Arrays;

public class TMutableArray {
-    private final double[] data;
+    private volatile double[] data;
public final Shape shape;

public TMutableArray(Shape shape) {
@@ -31,18 +31,19 @@ public double dataAt(int... indices) {
return 0;
}

-    public void addAt(int[] indices, double v) {
-        int offset = shape.calcDataIndex(indices);
-        data[offset] += v;
-    }
-
public void setAt(int[] indices, double v) {
int offset = shape.calcDataIndex(indices);
data[offset] = v;
}

-    // TODO: implement as "migrateToImmutable() that nulls the array and hands it to TArray? Avoid array copy.
-    public TArray toImmutable() {
-        return new TArray(Arrays.copyOf(data, data.length), shape.copy());
+    /**
+     * The current instance cannot be used after this call.
+     */
+    public synchronized TArray migrateToImmutable() {
+        TArray immutable = new TArray(this.data, shape.copy());
+
+        this.data = null;
+
+        return immutable;
}
}
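A hedged usage sketch of the new contract (hypothetical demo, not code from the commit): fill the mutable array, migrate once, and never touch the wrapper again. Because data is nulled on migration, a stray write afterwards surfaces as a NullPointerException rather than silently mutating the handed-off buffer; the volatile field plus the synchronized migration presumably keep that fail-fast behavior visible across the worker threads used in Ops.

    import com.codeberry.tadlib.array.Shape;
    import com.codeberry.tadlib.array.TArray;
    import com.codeberry.tadlib.array.TMutableArray;

    public class MigrateDemo {
        public static void main(String[] args) {
            // Assumes Shape accepts int dimensions, as suggested by
            // "new Shape(dims)" in the Ops.java diff below.
            Shape shape = new Shape(2, 2);
            TMutableArray tmp = new TMutableArray(shape);

            tmp.setAt(new int[] {0, 1}, 1.0);          // write while still mutable
            TArray frozen = tmp.migrateToImmutable();  // buffer handed off, field nulled

            // tmp.setAt(new int[] {0, 0}, 2.0);       // would now throw NullPointerException
            System.out.println(frozen.shape);
        }
    }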
51 changes: 30 additions & 21 deletions src/main/java/com/codeberry/tadlib/tensor/Ops.java
@@ -188,37 +188,45 @@ public static Tensor conv2d(Tensor input, Tensor filter) {
}

private static TArray calcFilterGradient(TArray grad, TArray input, TArray filter) {
-        Shape tgtShape = filter.shape.normalOrderedCopy();
-
-        return multiThreadingSupportRun(taskRange(0, grad.shape.at(0)),
-                range -> accumulateFilterGradientAtFirstDim(range, grad, input, filter, tgtShape),
+        Shape filterShape = filter.shape;
+        int[] dims = new int[filterShape.dimCount + 1];
+        System.arraycopy(filterShape.toDimArray(), 0, dims, 1, filterShape.dimCount);
+        dims[0] = input.shape.at(0);
+        Shape tgtShape = new Shape(dims);
+
+        TArray gradPerInputExample = multiThreadingSupportRun(taskRange(0, grad.shape.at(0)),
+                range -> accumulateFilterGradientAtFirstDim(range, grad, input, tgtShape),
(left, right) -> left.add(right));

+        return gradPerInputExample.sumFirstDims(1, REMOVE_DIM);
}

-    private static TArray accumulateFilterGradientAtFirstDim(TaskRange range, TArray grad, TArray input, TArray filter, Shape tgtShape) {
-        TMutableArray tgtGrad = new TMutableArray(new double[filter.shape.size], tgtShape);
+    private static TArray accumulateFilterGradientAtFirstDim(TaskRange range, TArray grad, TArray input, Shape tgtShape) {
+        TMutableArray tgtGrad = new TMutableArray(new double[tgtShape.size], tgtShape);

int[] gradIndices = grad.shape.newIndexArray();
int[] inIndices = input.shape.newIndexArray();
-        int[] filterIndices = filter.shape.newIndexArray();
+        int[] tgtIndices = tgtShape.newIndexArray();

for (int i = range.start; i < range.end; i++) {
gradIndices[0] = i;
inIndices[0] = i;
-            accumulateFilterGradient(grad, gradIndices, 1, input, inIndices, tgtGrad, filterIndices);
+            tgtIndices[0] = i;
+            accumulateFilterGradient(grad, gradIndices, 1, input, inIndices, tgtGrad, tgtIndices);
}

-        return tgtGrad.toImmutable();
+        return tgtGrad.migrateToImmutable();
}

private static void accumulateFilterGradient(TArray grad, int[] gradIndices, int dim,
TArray input, int[] inIndices,
-                                                 TMutableArray tgtGrad, int[] filterIndices) {
+                                                 TMutableArray tgtGrad, int[] tgtIndices) {
if (gradIndices.length - dim == 3) {
-            int filterH = tgtGrad.shape.at(0);
-            int filterW = tgtGrad.shape.at(1);
+            int filterH = tgtGrad.shape.at(1);
+            int filterW = tgtGrad.shape.at(2);
int inputChannels = input.shape.at(-1);
int outChannels = tgtGrad.shape.at(-1);
+            int tgtDims = tgtIndices.length;
for (int inIdx = 0; inIdx < inputChannels; inIdx++) {
for (int outIdx = 0; outIdx < outChannels; outIdx++) {
for (int y = 0; y < filterH; y++) {
@@ -229,11 +237,11 @@ private static void accumulateFilterGradient(TArray grad, int[] gradIndices, int
inIdx, outIdx,
y, x);

-                            filterIndices[0] = y;
-                            filterIndices[1] = x;
-                            filterIndices[2] = inIdx;
-                            filterIndices[3] = outIdx;
-                            tgtGrad.addAt(filterIndices, g);
+                            tgtIndices[tgtDims - 4] = y;
+                            tgtIndices[tgtDims - 3] = x;
+                            tgtIndices[tgtDims - 2] = inIdx;
+                            tgtIndices[tgtDims - 1] = outIdx;
+                            tgtGrad.setAt(tgtIndices, g);
}
}
}
@@ -243,7 +251,8 @@ private static void accumulateFilterGradient(TArray grad, int[] gradIndices, int
for (int i = 0; i < len; i++) {
gradIndices[dim] = i;
inIndices[dim] = i;
-            accumulateFilterGradient(grad, gradIndices, dim + 1, input, inIndices, tgtGrad, filterIndices);
+            tgtIndices[dim] = i;
+            accumulateFilterGradient(grad, gradIndices, dim + 1, input, inIndices, tgtGrad, tgtIndices);
}
}
}
@@ -294,7 +303,7 @@ public static Tensor maxpool2d(Tensor input, int size) {

GradFunc gF = grad -> distribute2dMaxGrad(grad, input.vals.shape, maxIndexShape, maxIndexData);

-        return new Tensor(tgt.toImmutable(), singletonList(parentLink(input, gF)));
+        return new Tensor(tgt.migrateToImmutable(), singletonList(parentLink(input, gF)));
}

public static Shape getMaxPool2dOutputSize(Shape inputShape, int size) {
@@ -317,7 +326,7 @@ private static TArray distribute2dMaxGrad(TArray grad, Shape inputShape, Shape m
fillMax2dGradInto(outputGrad, maxIndexShape, maxIndexData, grad, 0,
tmpOutputGradIndices, tmpGradIndices, tmpMaxIndices);

-        return outputGrad.toImmutable();
+        return outputGrad.migrateToImmutable();
}

private static void fillMax2dGradInto(TMutableArray outputGrad, Shape maxIndexShape, int[] maxIndexData, TArray grad, int dim,
@@ -539,7 +548,7 @@ public static Tensor sumSoftmaxCrossEntropy(Tensor labelsOneHot, Tensor predicti
TMutableArray tgt = TMutableArray.copyOf(softmax);
toSoftmaxGradient(tgt, softmax, softmax.shape.newIndexArray(),
labelsOneHot.vals, 0);
-            return tgt.toImmutable().mul(grad);
+            return tgt.migrateToImmutable().mul(grad);
};

return new Tensor(new TArray(cost), singletonList(parentLink(prediction, gF)));
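The restructured calcFilterGradient above is the "summation done separately" from the commit message: each input example writes its own slice of a gradient array whose leading dimension is the example index (hence setAt replacing addAt — slices never collide), and that leading dimension is collapsed in a single pass afterwards with sumFirstDims(1, REMOVE_DIM). A self-contained sketch of the two-phase pattern with plain arrays (illustrative only, not the library's types):

    public class PerExampleSumDemo {
        public static void main(String[] args) {
            int examples = 3, filterSize = 4;

            // Phase 1: one gradient slice per example; plain writes suffice
            // and the per-example work parallelizes cleanly.
            double[][] perExample = new double[examples][filterSize];
            for (int e = 0; e < examples; e++)
                for (int f = 0; f < filterSize; f++)
                    perExample[e][f] = e + 0.1 * f;   // stand-in for the conv2d gradient

            // Phase 2: separate summation over the leading example dimension,
            // mirroring what sumFirstDims(1, REMOVE_DIM) does to the real TArray.
            double[] summed = new double[filterSize];
            for (int e = 0; e < examples; e++)
                for (int f = 0; f < filterSize; f++)
                    summed[f] += perExample[e][f];

            System.out.println(java.util.Arrays.toString(summed));
        }
    }

The cost is the extra per-example buffer, which is presumably why the commit message concedes the gradient calc "might be slower".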
5 changes: 1 addition & 4 deletions
@@ -1,8 +1,5 @@
package com.codeberry.tadlib.util;

-import com.codeberry.tadlib.array.Shape;
-import com.codeberry.tadlib.array.TArray;
-import com.codeberry.tadlib.array.TArrayFactory;
import com.codeberry.tadlib.array.TMutableArray;
import com.codeberry.tadlib.tensor.Tensor;

@@ -18,6 +15,6 @@ public static Tensor toOneHot(Tensor yTrain, int outputUnits) {
indices[1] = (int) yTrain.dataAt(i, 0);
out.setAt(indices, 1.0);
}
-        return new Tensor(out.toImmutable(), Tensor.GradientMode.NONE);
+        return new Tensor(out.migrateToImmutable(), Tensor.GradientMode.NONE);
}
}
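For context, the utility above builds one-hot label rows with the same fill-then-migrate pattern. A plain-array equivalent of the toOneHot loop (hypothetical demo, not repo code):

    public class OneHotDemo {
        public static void main(String[] args) {
            int[] labels = {2, 0, 1};   // class index per training row
            int outputUnits = 3;

            double[][] oneHot = new double[labels.length][outputUnits];
            for (int i = 0; i < labels.length; i++)
                oneHot[i][labels[i]] = 1.0;   // single 1.0 at the label's column

            System.out.println(java.util.Arrays.deepToString(oneHot));
        }
    }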
