Skip to content

Commit cc45d0c

Browse files
author
finmath.net
authored
Merge pull request #1 from stefansedlmair/stoch-aad-with-expectation
Stochastic AAD with Expectation Operator
2 parents 9f9bc2a + f0daee9 commit cc45d0c

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

src/main/java/net/finmath/montecarlo/automaticdifferentiation/backward/RandomVariableDifferentiableAAD.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ private void propagateDerivativesFromResultToArgument(Map<Long, RandomVariableIn
100100
RandomVariableInterface derivative = derivatives.get(id);
101101
RandomVariableInterface argumentDerivative = derivatives.get(argumentID);
102102

103+
// Implementation of AVERAGE (see paper for details).
104+
if(operator == OperatorType.AVERAGE) derivative = derivative.average();
105+
103106
argumentDerivative = argumentDerivative.addProduct(partialDerivative, derivative);
104107

105108
derivatives.put(argumentID, argumentDerivative);
@@ -139,7 +142,7 @@ private RandomVariableInterface getPartialDerivative(OperatorTreeNode differenti
139142
resultrandomvariable = X.sin().mult(-1.0);
140143
break;
141144
case AVERAGE:
142-
resultrandomvariable = new RandomVariable(X.size()).invert();
145+
resultrandomvariable = new RandomVariable(1.0);
143146
break;
144147
case VARIANCE:
145148
resultrandomvariable = X.sub(X.getAverage()*(2.0*X.size()-1.0)/X.size()).mult(2.0/X.size());

src/test/java/net/finmath/montecarlo/automaticdifferentiation/backward/RandomVariableDifferentiableInterfaceTest.java

+30-1
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
import org.junit.runners.Parameterized;
1515
import org.junit.runners.Parameterized.Parameters;
1616

17-
import net.finmath.montecarlo.AbstractRandomVariableFactory;
17+
import net.finmath.montecarlo.BrownianMotion;
18+
import net.finmath.montecarlo.BrownianMotionInterface;
1819
import net.finmath.montecarlo.RandomVariable;
1920
import net.finmath.montecarlo.RandomVariableFactory;
2021
import net.finmath.montecarlo.automaticdifferentiation.AbstractRandomVariableDifferentiableFactory;
2122
import net.finmath.montecarlo.automaticdifferentiation.RandomVariableDifferentiableInterface;
2223
import net.finmath.montecarlo.automaticdifferentiation.backward.alternative.RandomVariableAADv2Factory;
2324
import net.finmath.stochastic.RandomVariableInterface;
25+
import net.finmath.time.TimeDiscretization;
2426

2527
/**
2628
* Unit test for random variables implementing <code>RandomVariableDifferentiableInterface</code>.
@@ -346,6 +348,33 @@ public void testRandomVariableGradientBigSum2(){
346348

347349
}
348350

351+
@Test
352+
public void testRandomVariableExpectation(){
353+
354+
int numberOfPaths = 100000;
355+
int seed = 3141;
356+
BrownianMotionInterface brownianMotion = new BrownianMotion(new TimeDiscretization(0.0, 1.0), 1 /* numberOfFactors */, numberOfPaths, seed);
357+
RandomVariableInterface brownianIncrement = brownianMotion.getIncrement(0, 0);
358+
359+
RandomVariableDifferentiableInterface x = randomVariableFactory.createRandomVariable(1.0);
360+
361+
RandomVariableInterface y = x.mult(brownianIncrement.sub(brownianIncrement.average())).average().mult(brownianIncrement);
362+
363+
Map<Long, RandomVariableInterface> aadGradient = ((RandomVariableDifferentiableInterface) y).getGradient();
364+
365+
RandomVariableInterface derivative = aadGradient.get(x.getID());
366+
367+
System.out.println(randomVariableFactory.toString());
368+
System.out.println(y.getAverage());
369+
System.out.println(brownianIncrement.squared().getAverage());
370+
System.out.println((aadGradient.get(x.getID())).getAverage());
371+
372+
Assert.assertEquals(0.0, y.getAverage(), 1E-8);
373+
374+
// Test RandomVariableDifferentiableAADFactory (the others currently fail)
375+
if(randomVariableFactory instanceof RandomVariableDifferentiableAADFactory) Assert.assertEquals(0.0, derivative.getAverage(), 1E-8);
376+
}
377+
349378
@Test
350379
public void testRandomVariableGradientBigSumWithConstants(){
351380

0 commit comments

Comments
 (0)