@@ -19,10 +19,10 @@
  */
 package org.neo4j.gds.embeddings.node2vec;

-import org.neo4j.gds.collections.ha.HugeLongArray;
 import org.neo4j.gds.collections.ha.HugeObjectArray;
 import org.neo4j.gds.core.concurrency.Concurrency;
 import org.neo4j.gds.core.concurrency.RunWithConcurrency;
+import org.neo4j.gds.core.utils.partition.DegreePartition;
 import org.neo4j.gds.core.utils.partition.PartitionUtils;
 import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
 import org.neo4j.gds.mem.BitUtil;
@@ -154,10 +154,15 @@ Node2VecResult train() {
         }
         progressTracker.endSubTask();

-        return new Node2VecResult(centerEmbeddings, lossPerIteration);
+        return new Node2VecResult(centerEmbeddings, lossPerIteration);
     }

-    private HugeObjectArray<FloatVector> initializeEmbeddings(LongUnaryOperator toOriginalNodeId, long nodeCount, int embeddingDimensions, Random random) {
+    private HugeObjectArray<FloatVector> initializeEmbeddings(
+        LongUnaryOperator toOriginalNodeId,
+        long nodeCount,
+        int embeddingDimensions,
+        Random random
+    ) {
         HugeObjectArray<FloatVector> embeddings = HugeObjectArray.newArray(
             FloatVector.class,
             nodeCount
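
The reflowed signature makes the initializer's inputs explicit: a mapping back to original node ids, the node count, the embedding dimension, and a seeded Random. For orientation, here is a minimal sketch of what such an initializer typically does, using plain arrays in place of HugeObjectArray/FloatVector; the word2vec-style scaling is an assumption, and the sketch ignores toOriginalNodeId, so it illustrates the idea rather than the actual GDS logic.

import java.util.Random;

// Hypothetical stand-in for initializeEmbeddings; not the GDS implementation.
class InitEmbeddingsSketch {
    static float[][] initializeEmbeddings(long nodeCount, int embeddingDimensions, Random random) {
        float[][] embeddings = new float[(int) nodeCount][];
        for (int nodeId = 0; nodeId < nodeCount; nodeId++) {
            float[] vector = new float[embeddingDimensions];
            for (int d = 0; d < embeddingDimensions; d++) {
                // small values centered on zero, as word2vec-style initializers use
                vector[d] = (random.nextFloat() - 0.5f) / embeddingDimensions;
            }
            embeddings[nodeId] = vector;
        }
        return embeddings;
    }

    public static void main(String[] args) {
        var embeddings = initializeEmbeddings(3, 4, new Random(42));
        System.out.println(java.util.Arrays.toString(embeddings[0])); // deterministic for a fixed seed
    }
}
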
@@ -206,18 +211,16 @@ private TrainingTask(
            HugeObjectArray<FloatVector> centerEmbeddings,
            HugeObjectArray<FloatVector> contextEmbeddings,
            PositiveSampleProducer positiveSampleProducer,
-            HugeLongArray negativeSamples,
+            NegativeSampleProducer negativeSampleProducer,
            float learningRate,
            int negativeSamplingRate,
            int embeddingDimensions,
-            ProgressTracker progressTracker,
-            long randomSeed,
-            int taskId
+            ProgressTracker progressTracker
        ) {
            this.centerEmbeddings = centerEmbeddings;
            this.contextEmbeddings = contextEmbeddings;
            this.positiveSampleProducer = positiveSampleProducer;
-            this.negativeSampleProducer = new NegativeSampleProducer(negativeSamples, randomSeed + taskId);
+            this.negativeSampleProducer = negativeSampleProducer;
            this.learningRate = learningRate;
            this.negativeSamplingRate = negativeSamplingRate;

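
The constructor now receives a ready-built NegativeSampleProducer instead of the raw HugeLongArray distribution plus randomSeed and taskId, so TrainingTask no longer knows how negative samples are produced or seeded. A hedged sketch of why that helps, using a PrimitiveIterator.OfLong as a hypothetical stand-in for the producer (the real NegativeSampleProducer API may differ):

import java.util.PrimitiveIterator;
import java.util.stream.LongStream;

// With the sample source injected, a test can feed a canned sequence instead
// of wiring up a real sampling distribution and seed arithmetic.
class InjectionSketch {
    // stands in for TrainingTask's use of its injected producer
    static long consumeOne(PrimitiveIterator.OfLong negativeSamples) {
        return negativeSamples.nextLong();
    }

    public static void main(String[] args) {
        PrimitiveIterator.OfLong stub = LongStream.of(7L, 3L, 3L, 9L).iterator();
        System.out.println(consumeOne(stub)); // prints 7
    }
}

Construction and seeding now live in one place, createTrainingTasks below, instead of being split between the task and its caller.
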
@@ -254,7 +257,7 @@ private void trainSample(long center, long context, boolean positive) {
            double positiveSigmoid = Sigmoid.sigmoid(affinity);
            double negativeSigmoid = 1 - positiveSigmoid;

-            lossSum -= positive ? Math.log(positiveSigmoid+EPSILON) : Math.log(negativeSigmoid+EPSILON);
+            lossSum -= positive ? Math.log(positiveSigmoid + EPSILON) : Math.log(negativeSigmoid + EPSILON);

            float gradient = positive ? (float) -negativeSigmoid : (float) positiveSigmoid;
            // we are doing gradient descent, so we go in the negative direction of the gradient here
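
The change here is only spacing around the + operators, but the line it touches carries the objective: each sample contributes the standard negative-sampling loss, -log(sigmoid(center · context)) for a positive pair and -log(1 - sigmoid(center · context)) for a negative sample, with EPSILON guarding against log(0). A self-contained sketch of that computation; the EPSILON value and vector types are assumptions, not the GDS constants:

// Illustrative only: mirrors the loss term in trainSample with plain arrays.
final class SampleLossSketch {
    private static final double EPSILON = 1e-10; // assumed; guards against log(0)

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    static double sampleLoss(float[] center, float[] context, boolean positive) {
        double affinity = 0;
        for (int i = 0; i < center.length; i++) {
            affinity += center[i] * context[i];
        }
        double positiveSigmoid = sigmoid(affinity);
        double negativeSigmoid = 1 - positiveSigmoid;
        // positive pairs are rewarded for high affinity, negative samples for low affinity
        return -(positive ? Math.log(positiveSigmoid + EPSILON) : Math.log(negativeSigmoid + EPSILON));
    }

    public static void main(String[] args) {
        float[] center = {0.1f, 0.2f};
        float[] context = {0.3f, -0.1f};
        System.out.println(sampleLoss(center, context, true));  // loss for a positive pair
        System.out.println(sampleLoss(center, context, false)); // loss for a negative sample
    }
}
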
@@ -290,37 +293,47 @@ void addAll(FloatConsumer other) {
        }
    }

-    List<TrainingTask> createTrainingTasks(float learningRate, AtomicInteger taskIndex){
+    List<TrainingTask> createTrainingTasks(float learningRate, AtomicInteger taskIndex) {
        return PartitionUtils.degreePartitionWithBatchSize(
            walks.size(),
            walks::walkLength,
            BitUtil.ceilDiv(randomWalkProbabilities.sampleCount(), concurrency.value()),
            partition -> {
-
                var taskId = taskIndex.getAndIncrement();
-                var positiveSampleProducer = new PositiveSampleProducer(
-                    walks.iterator(partition.startNode(), partition.nodeCount()),
-                    randomWalkProbabilities.positiveSamplingProbabilities(),
-                    windowSize,
-                    Optional.of(randomSeed),
-                    taskId
-                );
-
+                var taskRandomSeed = randomSeed + taskId;
+                var positiveSampleProducer = createPositiveSampleProducer(partition, taskRandomSeed);
+                var negativeSampleProducer = createNegativeSampleProducer(taskRandomSeed);
                return new TrainingTask(
                    centerEmbeddings,
                    contextEmbeddings,
                    positiveSampleProducer,
-                    randomWalkProbabilities.negativeSamplingDistribution(),
+                    negativeSampleProducer,
                    learningRate,
                    negativeSamplingRate,
                    embeddingDimension,
-                    progressTracker,
-                    randomSeed,
-                    taskId
+                    progressTracker
                );
            }
        );
+    }

+    NegativeSampleProducer createNegativeSampleProducer(long randomSeed) {
+        return new NegativeSampleProducer(
+            randomWalkProbabilities.negativeSamplingDistribution(),
+            randomSeed
+        );
+    }
+
+    PositiveSampleProducer createPositiveSampleProducer(
+        DegreePartition partition,
+        long randomSeed
+    ) {
+        return new PositiveSampleProducer(
+            walks.iterator(partition.startNode(), partition.nodeCount()),
+            randomWalkProbabilities.positiveSamplingProbabilities(),
+            windowSize,
+            randomSeed
+        );
     }

 }
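
The two extracted factory methods centralize the seeding scheme: each task derives taskRandomSeed = randomSeed + taskId and hands the same derived seed to both producers, so a fixed base seed reproduces the same sample streams. A minimal sketch of the idea, with java.util.Random as a stand-in for whatever generator the producers actually use:

import java.util.Random;

// Each task gets its own deterministic stream from base seed + task id.
public class PerTaskSeedSketch {
    public static void main(String[] args) {
        long baseSeed = 42L;
        int taskCount = 4;
        for (int taskId = 0; taskId < taskCount; taskId++) {
            long taskRandomSeed = baseSeed + taskId; // as in createTrainingTasks
            Random random = new Random(taskRandomSeed);
            // same base seed and task count always print the same values
            System.out.printf("task %d first draw: %d%n", taskId, random.nextLong());
        }
    }
}

Note that the partition layout itself depends on concurrency (the batch size divides the sample count by concurrency.value()), so identical output across runs also assumes the same concurrency setting.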