Skip to content

Commit 0ce7df0

Browse files
Enhance Node2Vec TrainTask creation tests
Co-authored-by: Ioannis Panagiotas <ioannis.panagiotas@neotechnology.com>
1 parent cd9bc4a commit 0ce7df0

File tree

4 files changed

+98

-39

lines changed

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecModel.java

+36-23
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
*/
2020
package org.neo4j.gds.embeddings.node2vec;
2121

22-
import org.neo4j.gds.collections.ha.HugeLongArray;
2322
import org.neo4j.gds.collections.ha.HugeObjectArray;
2423
import org.neo4j.gds.core.concurrency.Concurrency;
2524
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
25+
import org.neo4j.gds.core.utils.partition.DegreePartition;
2626
import org.neo4j.gds.core.utils.partition.PartitionUtils;
2727
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2828
import org.neo4j.gds.mem.BitUtil;
@@ -154,10 +154,15 @@ Node2VecResult train() {
154154
}
155155
progressTracker.endSubTask();
156156

157-
return new Node2VecResult(centerEmbeddings, lossPerIteration);
157+
return new Node2VecResult(centerEmbeddings, lossPerIteration);
158158
}
159159

160-
private HugeObjectArray<FloatVector> initializeEmbeddings(LongUnaryOperator toOriginalNodeId, long nodeCount, int embeddingDimensions, Random random) {
160+
private HugeObjectArray<FloatVector> initializeEmbeddings(
161+
LongUnaryOperator toOriginalNodeId,
162+
long nodeCount,
163+
int embeddingDimensions,
164+
Random random
165+
) {
161166
HugeObjectArray<FloatVector> embeddings = HugeObjectArray.newArray(
162167
FloatVector.class,
163168
nodeCount
@@ -206,18 +211,16 @@ private TrainingTask(
206211
HugeObjectArray<FloatVector> centerEmbeddings,
207212
HugeObjectArray<FloatVector> contextEmbeddings,
208213
PositiveSampleProducer positiveSampleProducer,
209-
HugeLongArray negativeSamples,
214+
NegativeSampleProducer negativeSampleProducer,
210215
float learningRate,
211216
int negativeSamplingRate,
212217
int embeddingDimensions,
213-
ProgressTracker progressTracker,
214-
long randomSeed,
215-
int taskId
218+
ProgressTracker progressTracker
216219
) {
217220
this.centerEmbeddings = centerEmbeddings;
218221
this.contextEmbeddings = contextEmbeddings;
219222
this.positiveSampleProducer = positiveSampleProducer;
220-
this.negativeSampleProducer = new NegativeSampleProducer(negativeSamples, randomSeed + taskId);
223+
this.negativeSampleProducer = negativeSampleProducer;
221224
this.learningRate = learningRate;
222225
this.negativeSamplingRate = negativeSamplingRate;
223226

@@ -254,7 +257,7 @@ private void trainSample(long center, long context, boolean positive) {
254257
double positiveSigmoid = Sigmoid.sigmoid(affinity);
255258
double negativeSigmoid = 1 - positiveSigmoid;
256259

257-
lossSum -= positive ? Math.log(positiveSigmoid+EPSILON) : Math.log(negativeSigmoid+EPSILON);
260+
lossSum -= positive ? Math.log(positiveSigmoid + EPSILON) : Math.log(negativeSigmoid + EPSILON);
258261

259262
float gradient = positive ? (float) -negativeSigmoid : (float) positiveSigmoid;
260263
// we are doing gradient descent, so we go in the negative direction of the gradient here
@@ -290,37 +293,47 @@ void addAll(FloatConsumer other) {
290293
}
291294
}
292295

293-
List<TrainingTask> createTrainingTasks(float learningRate, AtomicInteger taskIndex){
296+
List<TrainingTask> createTrainingTasks(float learningRate, AtomicInteger taskIndex) {
294297
return PartitionUtils.degreePartitionWithBatchSize(
295298
walks.size(),
296299
walks::walkLength,
297300
BitUtil.ceilDiv(randomWalkProbabilities.sampleCount(), concurrency.value()),
298301
partition -> {
299-
300302
var taskId = taskIndex.getAndIncrement();
301-
var positiveSampleProducer = new PositiveSampleProducer(
302-
walks.iterator(partition.startNode(), partition.nodeCount()),
303-
randomWalkProbabilities.positiveSamplingProbabilities(),
304-
windowSize,
305-
Optional.of(randomSeed),
306-
taskId
307-
);
308-
303+
var taskRandomSeed = randomSeed + taskId;
304+
var positiveSampleProducer = createPositiveSampleProducer(partition, taskRandomSeed);
305+
var negativeSampleProducer = createNegativeSampleProducer(taskRandomSeed);
309306
return new TrainingTask(
310307
centerEmbeddings,
311308
contextEmbeddings,
312309
positiveSampleProducer,
313-
randomWalkProbabilities.negativeSamplingDistribution(),
310+
negativeSampleProducer,
314311
learningRate,
315312
negativeSamplingRate,
316313
embeddingDimension,
317-
progressTracker,
318-
randomSeed,
319-
taskId
314+
progressTracker
320315
);
321316
}
322317
);
318+
}
323319

320+
NegativeSampleProducer createNegativeSampleProducer(long randomSeed) {
321+
return new NegativeSampleProducer(
322+
randomWalkProbabilities.negativeSamplingDistribution(),
323+
randomSeed
324+
);
325+
}
326+
327+
PositiveSampleProducer createPositiveSampleProducer(
328+
DegreePartition partition,
329+
long randomSeed
330+
) {
331+
return new PositiveSampleProducer(
332+
walks.iterator(partition.startNode(), partition.nodeCount()),
333+
randomWalkProbabilities.positiveSamplingProbabilities(),
334+
windowSize,
335+
randomSeed
336+
);
324337
}
325338

326339
}

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/PositiveSampleProducer.java

+3-9
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222
import org.neo4j.gds.collections.ha.HugeDoubleArray;
2323

2424
import java.util.Iterator;
25-
import java.util.Optional;
2625
import java.util.SplittableRandom;
27-
import java.util.concurrent.ThreadLocalRandom;
2826

2927
import static org.neo4j.gds.mem.BitUtil.ceilDiv;
3028

@@ -42,14 +40,13 @@ public class PositiveSampleProducer {
4240
private int contextWordIndex;
4341
private int currentWindowStart;
4442
private int currentWindowEnd;
45-
private SplittableRandom probabilitySupplier;
43+
private final SplittableRandom probabilitySupplier;
4644

4745
PositiveSampleProducer(
4846
Iterator<long[]> walks,
4947
HugeDoubleArray samplingProbabilities,
5048
int windowSize,
51-
Optional<Long> maybeRandomSeed,
52-
int taskId
49+
long randomSeed
5350
) {
5451
this.walks = walks;
5552
this.samplingProbabilities = samplingProbabilities;
@@ -60,10 +57,7 @@ public class PositiveSampleProducer {
6057
this.currentWalk = new long[0];
6158
this.centerWordIndex = -1;
6259
this.contextWordIndex = 1;
63-
probabilitySupplier = maybeRandomSeed
64-
.map(seed -> new SplittableRandom(taskId + seed))
65-
.orElseGet(() -> new SplittableRandom(ThreadLocalRandom.current().nextLong()));
66-
60+
probabilitySupplier = new SplittableRandom(randomSeed);
6761
}
6862

6963
public boolean next(long[] buffer) {

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecModelTest.java

+59
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,31 @@
1919
*/
2020
package org.neo4j.gds.embeddings.node2vec;
2121

22+
import org.junit.jupiter.api.DisplayName;
2223
import org.junit.jupiter.api.Test;
2324
import org.junit.jupiter.params.ParameterizedTest;
2425
import org.junit.jupiter.params.provider.ValueSource;
26+
import org.neo4j.gds.collections.ha.HugeLongArray;
2527
import org.neo4j.gds.core.concurrency.Concurrency;
2628
import org.neo4j.gds.core.utils.Intersections;
2729
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2830

2931
import java.util.Optional;
3032
import java.util.Random;
33+
import java.util.concurrent.atomic.AtomicInteger;
34+
import java.util.function.LongUnaryOperator;
3135
import java.util.stream.LongStream;
3236

3337
import static org.assertj.core.api.Assertions.assertThat;
3438
import static org.junit.jupiter.api.Assertions.assertEquals;
39+
import static org.mockito.ArgumentMatchers.any;
40+
import static org.mockito.ArgumentMatchers.anyLong;
41+
import static org.mockito.ArgumentMatchers.eq;
42+
import static org.mockito.Mockito.mock;
43+
import static org.mockito.Mockito.spy;
44+
import static org.mockito.Mockito.times;
45+
import static org.mockito.Mockito.verify;
46+
import static org.mockito.Mockito.when;
3547

3648
class Node2VecModelTest {
3749

@@ -196,6 +208,53 @@ void randomSeed(int iterations) {
196208
}
197209
}
198210

211+
@Test
212+
@DisplayName("When creating multiple tasks with random seed the actual seed for the task should be `randomSeed + taskId`.")
213+
void shouldCreateTrainingTasksWithCorrectRandomSeed() {
214+
var randomWalksMock = mock(CompressedRandomWalks.class);
215+
when(randomWalksMock.size()).thenReturn(10L);
216+
when(randomWalksMock.walkLength(anyLong())).thenReturn(3);
217+
218+
var randomWalkProbabilitiesMock = mock(RandomWalkProbabilities.class);
219+
when(randomWalkProbabilitiesMock.sampleCount()).thenReturn(30L);
220+
when(randomWalkProbabilitiesMock.negativeSamplingDistribution()).thenReturn(HugeLongArray.newArray(10));
221+
222+
var trainParametersMock = mock(TrainParameters.class);
223+
when(trainParametersMock.embeddingInitializer()).thenReturn(EmbeddingInitializer.UNIFORM);
224+
225+
var node2VecModel = spy(
226+
new Node2VecModel(
227+
LongUnaryOperator.identity(),
228+
1000,
229+
trainParametersMock,
230+
new Concurrency(4),
231+
Optional.of(1L), // Random Seed
232+
randomWalksMock,
233+
randomWalkProbabilitiesMock,
234+
ProgressTracker.NULL_TRACKER
235+
)
236+
);
237+
238+
var taskIdTracker = new AtomicInteger(0);
239+
var trainingTasks = node2VecModel.createTrainingTasks(0.2f, taskIdTracker);
240+
241+
assertThat(trainingTasks).hasSize(5);
242+
243+
verify(node2VecModel, times(5)).createPositiveSampleProducer(any(), anyLong());
244+
verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(1L));
245+
verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(2L));
246+
verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(3L));
247+
verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(4L));
248+
verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(5L));
249+
250+
verify(node2VecModel, times(5)).createNegativeSampleProducer(anyLong());
251+
verify(node2VecModel, times(1)).createNegativeSampleProducer(1L);
252+
verify(node2VecModel, times(1)).createNegativeSampleProducer(2L);
253+
verify(node2VecModel, times(1)).createNegativeSampleProducer(3L);
254+
verify(node2VecModel, times(1)).createNegativeSampleProducer(4L);
255+
verify(node2VecModel, times(1)).createNegativeSampleProducer(5L);
256+
}
257+
199258
private static CompressedRandomWalks generateRandomWalks(
200259
RandomWalkProbabilities.Builder probabilitiesBuilder,
201260
long numberOfClusters,

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/PositiveSampleProducerTest.java

-7
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import java.util.ArrayList;
3030
import java.util.Collection;
3131
import java.util.List;
32-
import java.util.Optional;
3332
import java.util.stream.LongStream;
3433
import java.util.stream.Stream;
3534

@@ -61,7 +60,6 @@ void doesNotCauseStackOverflow() {
6160
walks.iterator(0, nbrOfWalks),
6261
HugeDoubleArray.of(LongStream.range(0, nbrOfWalks).mapToDouble((l) -> 1.0).toArray()),
6362
10,
64-
Optional.empty(),
6563
0
6664
);
6765

@@ -90,7 +88,6 @@ void doesNotCauseStackOverflowDueToBadLuck() {
9088
walks.iterator(0, nbrOfWalks),
9189
probabilities,
9290
10,
93-
Optional.empty(),
9491
0
9592
);
9693
// does not overflow the stack = passes test
@@ -115,7 +112,6 @@ void doesNotAttemptToFetchOutsideBatch() {
115112
walks.iterator(0, nbrOfWalks / 2),
116113
HugeDoubleArray.of(LongStream.range(0, nbrOfWalks).mapToDouble((l) -> 1.0).toArray()),
117114
10,
118-
Optional.empty(),
119115
0
120116
);
121117

@@ -141,7 +137,6 @@ void shouldProducePairsWith(
141137
walks.iterator(0, walks.size()),
142138
centerNodeProbabilities,
143139
windowSize,
144-
Optional.empty(),
145140
0
146141
);
147142
while (producer.next(buffer)) {
@@ -165,7 +160,6 @@ void shouldProducePairsWithBounds() {
165160
walks.iterator(0, 2),
166161
centerNodeProbabilities,
167162
3,
168-
Optional.empty(),
169163
0
170164
);
171165
while (producer.next(buffer)) {
@@ -212,7 +206,6 @@ void shouldRemoveDownsampledWordFromWalk() {
212206
walks.iterator(0, walks.size()),
213207
centerNodeProbabilities,
214208
3,
215-
Optional.empty(),
216209
0
217210
);
218211

0 commit comments

Comments (0)