# Cavity flow with Physics-Informed Neural Networks

Solve cavity flow governed by the 2D, steady Navier-Stokes and continuity equations using a physics-informed neural network (PINN).

The 2D, steady Navier-Stokes and continuity equations for an incompressible fluid are:

$$ \frac{\partial u}{\partial x}+\frac{\partial v}{\partial y}=0 $$

$$ u\frac{\partial u}{\partial x}+v\frac{\partial u}{\partial y}+\frac{\partial p}{\partial x}-\frac{1}{Re}\bigg(\frac{\partial^2 u}{\partial x^2 }+\frac{\partial^2 u}{\partial y^2 }\bigg)=0 $$

$$ u\frac{\partial v}{\partial x}+v\frac{\partial v}{\partial y}+\frac{\partial p}{\partial y}-\frac{1}{Re}\bigg(\frac{\partial^2 v}{\partial x^2 }+\frac{\partial^2 v}{\partial y^2 }\bigg)=0 $$

Here $(x,y)$ are the spatial coordinates, $(u,v)$ is the fluid velocity, $p$ is the pressure and $Re$ is the Reynolds number.

To satisfy the continuity equation automatically, we use the stream function $\psi$, such that $u=\partial \psi /\partial y$ and $v=-\partial \psi /\partial x$. The cavity is defined as the square domain $[0,1]\times [0,1]$. The boundary conditions are $(u,v)=(1,0)$ at the top boundary and $(u,v)=(0,0)$ at the other boundaries. Additionally, $\psi =0$ is assumed on all the boundaries. The Reynolds number is $Re=100$.
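
With this choice the continuity equation holds identically, since

$$ \frac{\partial u}{\partial x}+\frac{\partial v}{\partial y}=\frac{\partial^2 \psi }{\partial x\,\partial y}-\frac{\partial^2 \psi }{\partial y\,\partial x}=0. $$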

The PINN model takes the spatial coordinates $(x,y)$ as inputs and returns the stream function and pressure $(\psi ,p)$ as outputs.

This work is inspired by the following GitHub repo: [https://github.com/okada39/pinn_cavity](https://github.com/okada39/pinn_cavity)

## Set parameters

```matlab
% Reynolds number and top boundary velocity.
Re = 100;
u0 = 1;
```

## Create network

The core network architecture is a standard multi-layer perceptron (MLP) with `numHiddenUnits=32` and swish activations. We use separate inputs for `x` and `y` because this makes it easier to compute derivatives with respect to these inputs later, when imposing the PINN loss. In addition to the MLP, we use anchor functions to strictly enforce the $\psi =0$ boundary condition. For example, the anchor function in $x$ multiplies the unconstrained network estimate of $\psi$ by $4x(1-x)$, which is $0$ at the boundaries (i.e. when $x=0$ or $x=1$). The factor $4$ is chosen so that the anchor function has a maximum of one. We include two anchor functions, one for the $x$-coordinate and one for the $y$-coordinate, and multiply them with the "free" $\psi$ estimate to produce the final output for $\psi$; the resulting formula is written out below.
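
Writing $\psi_{\mathrm{free}}$ for the output of the `psiFree` fully connected layer, the network's stream function output is

$$ \psi (x,y)=4x(1-x)\cdot 4y(1-y)\cdot \psi_{\mathrm{free}}(x,y), $$

which is zero on all four boundaries by construction.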

```matlab
% Create basic MLP network architecture with two inputs (x,y) and two
% outputs (psi,p).
numHiddenUnits = 32;
net = dlnetwork();
layers = [ featureInputLayer(1, Name="x")
    concatenationLayer(1, 2)
    fullyConnectedLayer(numHiddenUnits)
    swishLayer()
    fullyConnectedLayer(numHiddenUnits)
    swishLayer()
    fullyConnectedLayer(numHiddenUnits)
    swishLayer()
    fullyConnectedLayer(numHiddenUnits)
    swishLayer(Name="swishout")
    fullyConnectedLayer(1, Name="psiFree") ];
net = addLayers(net, layers);
net = addLayers(net, fullyConnectedLayer(1, Name="p"));
net = connectLayers(net, "swishout", "p");
net = addInputLayer(net, featureInputLayer(1, Name="y"), Initialize=false);

% Add anchor functions to strictly enforce boundary conditions on the
% streamfunction.
net = addLayers(net, [functionLayer(@(x)4.*x.*(1-x), Name="anchorX", Acceleratable=true); multiplicationLayer(3, Name="psi")]);
net = addLayers(net, functionLayer(@(y)4.*y.*(1-y), Name="anchorY", Acceleratable=true));
net = connectLayers(net, "x", "anchorX");
net = connectLayers(net, "y", "anchorY");
net = connectLayers(net, "anchorY", "psi/in2");
net = connectLayers(net, "psiFree", "psi/in3");

% Make sure outputs are ordered (psi,p).
net.OutputNames = ["psi", "p"];

% Initialize the network and cast to double precision.
net = initialize(net);
net = dlupdate(@double, net);

% Visually inspect the network.
analyzeNetwork(net)
```

![figure_2.png](./images/figure_2.png)

## Create training input

Sample random collocation points inside the cavity for the PDE residuals, and random points on the four boundaries for the boundary conditions.

```matlab
numTrainSamples = 1e4;
xyEquation = rand([numTrainSamples 2]);

numBoundarySamples = floor(numTrainSamples/2);
xyTopBottom = rand([numBoundarySamples 2]); % top-bottom boundaries.
xyTopBottom(:, 2) = round(xyTopBottom(:, 2)); % y-position is 0 or 1.

xyLeftRight = rand([numBoundarySamples 2]); % left-right boundaries.
xyLeftRight(:, 1) = round(xyLeftRight(:, 1)); % x-position is 0 or 1.

xyBoundary = cat(1, xyTopBottom, xyLeftRight);
idxPerm = randperm(size(xyBoundary, 1));
xyBoundary = xyBoundary(idxPerm, :);
```
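
Optionally, a quick scatter plot (a minimal sketch using standard MATLAB plotting, not part of the workflow above) can be used to check the sampling before the points are converted to `dlarray`:

```matlab
% Optional sanity check: plot the interior (equation) and boundary samples.
figure
scatter(xyEquation(:, 1), xyEquation(:, 2), 4, "filled")
hold on
scatter(xyBoundary(:, 1), xyBoundary(:, 2), 4, "filled")
hold off
axis equal
xlabel("x")
ylabel("y")
legend("Interior points", "Boundary points")
```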

## Create training output

The target velocity is $(u,v)=(u_0,0)$ on the top boundary and $(0,0)$ on the other boundaries.

```matlab
zeroVector = zeros([numTrainSamples 1]);
uvBoundary = [zeroVector zeroVector];
% floor(y) equals 1 only on the top boundary (y = 1), so this sets u = u0
% at the lid and leaves u = 0 on the other boundaries.
uvBoundary(:, 1) = u0.*floor( xyBoundary(:, 2) );
```

## Train the model

Train with the L-BFGS optimizer, using a GPU if one is available.
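
The total loss minimized by L-BFGS is the sum of the two momentum-equation residual losses and the boundary-condition loss, each measured with a log-cosh penalty (see `pinnsLossFunction` below):

$$ \mathcal{L}=\mathcal{L}_{\mathrm{eqn},x}+\mathcal{L}_{\mathrm{eqn},y}+\mathcal{L}_{\mathrm{BC}},\qquad \mathcal{L}_{\bullet }=\mathrm{mean}\big(\log \cosh (r_{\bullet })\big), $$

where $r_{\bullet}$ denotes the corresponding residual.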

```matlab
% Prepare training data.
xyEquation = dlarray(xyEquation);
xyBoundary = dlarray(xyBoundary);
if canUseGPU
    xyEquation = gpuArray(xyEquation);
    xyBoundary = gpuArray(xyBoundary);
end

% Create training progress plot.
monitor = trainingProgressMonitor();
monitor.XLabel = "Iteration";
monitor.Metrics = ["TotalLoss", "LossEqnX", "LossEqnY", "LossBC"];
groupSubPlot(monitor, "Loss", ["TotalLoss", "LossEqnX", "LossEqnY", "LossBC"])
yscale(monitor, "Loss", "log");

% Train with L-BFGS.
maxIterations = 1e4;
solverState = [];
lossFcn = dlaccelerate(@pinnsLossFunction);
lbfgsLossFcn = @(n)dlfeval(lossFcn, n, xyEquation, xyBoundary, zeroVector, uvBoundary, Re);
for iteration = 1:maxIterations
    [net, solverState] = lbfgsupdate(net, lbfgsLossFcn, solverState, NumLossFunctionOutputs=5);

    additionalLosses = solverState.AdditionalLossFunctionOutputs;
    recordMetrics(monitor, ...
        iteration, ...
        TotalLoss=solverState.Loss, ...
        LossEqnX=additionalLosses{1}, ...
        LossEqnY=additionalLosses{2}, ...
        LossBC=additionalLosses{3});
end
```

![figure_0.png](./images/figure_0.png)

## Plot predictions

Evaluate the trained network on a regular grid and plot the stream function, pressure and velocity components.

```matlab
% Create test set using meshgrid.
numTestSamples = 100;
x = linspace(0, 1, numTestSamples)';
y = x;
[xt, yt] = meshgrid(x, y);

% Flatten gridpoints and prepare data.
xTest = dlarray(xt(:));
yTest = dlarray(yt(:));
if canUseGPU
    xTest = gpuArray(xTest);
    yTest = gpuArray(yTest);
end

% Evaluate the network.
[psiTest, pTest, uTest, vTest] = dlfeval(@calculateStreamfunctionPressureAndVelocity, net, xTest, yTest);

% Return predictions to grid and plot.
ut = unflattenAndExtract(uTest, numTestSamples);
vt = unflattenAndExtract(vTest, numTestSamples);
pt = unflattenAndExtract(pTest, numTestSamples);
psit = unflattenAndExtract(psiTest, numTestSamples);

figure;
subplot(2,2,1)
contourf(xt, yt, psit)
colorbar
axis equal
title('psi')

subplot(2,2,2)
contourf(xt, yt, pt)
colorbar
axis equal
title('p')

subplot(2,2,3)
contourf(xt, yt, ut)
colorbar
axis equal
title('u')

subplot(2,2,4)
contourf(xt, yt, vt)
colorbar
axis equal
title('v')
```

![figure_1.png](./images/figure_1.png)

## Loss function and helper functions

```matlab
function [loss, grads, lossEqnX, lossEqnY, lossBC] = pinnsLossFunction(net, xyEquation, xyBoundary, zeroVector, uvBoundary, Re)

% Get model outputs at interior points.
xeq = xyEquation(:, 1);
yeq = xyEquation(:, 2);
[psi, p] = forward(net, xeq, yeq);

% Compute the velocity from the streamfunction (u = dpsi/dy, v = -dpsi/dx)
% and the derivatives needed for the momentum equations.
u = dljacobian(psi', yeq, 1);
v = -1.*dljacobian(psi', xeq, 1);

ux = dljacobian(u', xeq, 1);
uy = dljacobian(u', yeq, 1);
uxx = dljacobian(ux', xeq, 1);
uyy = dljacobian(uy', yeq, 1);

vx = dljacobian(v', xeq, 1);
vy = dljacobian(v', yeq, 1);
vxx = dljacobian(vx', xeq, 1);
vyy = dljacobian(vy', yeq, 1);

px = dljacobian(p', xeq, 1);
py = dljacobian(p', yeq, 1);

% Momentum equation residuals.
lx = u.*ux + v.*uy + px - (1/Re).*(uxx + uyy);
ly = u.*vx + v.*vy + py - (1/Re).*(vxx + vyy);

% Combine for equation loss.
lossEqnX = logCoshLoss(lx, zeroVector);
lossEqnY = logCoshLoss(ly, zeroVector);

% Get model outputs at boundary points.
xbd = xyBoundary(:, 1);
ybd = xyBoundary(:, 2);
psibd = forward(net, xbd, ybd);

ubd = dljacobian(psibd', ybd, 1);
vbd = -1.*dljacobian(psibd', xbd, 1);

uvbd = cat(2, ubd, vbd);
lossBC = logCoshLoss(uvbd, uvBoundary);

% Total loss and model gradients.
loss = lossEqnX + lossEqnY + lossBC;
grads = dlgradient(loss, net.Learnables);
end

function loss = logCoshLoss(y, t)
% Log-cosh loss function.
e = y - t;
loss = mean( log(cosh(e)), 'all' );
end

function [psi, p, u, v] = calculateStreamfunctionPressureAndVelocity(net, x, y)
% Compute the streamfunction psi, pressure p and velocity (u,v) given
% input positions (x,y).
[psi, p] = forward(net, x, y);
u = dljacobian(psi', y, 1);
v = -1.*dljacobian(psi', x, 1);
end

function x = unflattenAndExtract(xflat, sz)
x = reshape(xflat, [sz sz]);
x = extractdata(x);
end
```

#### Requirements

- [MATLAB®](https://mathworks.com/products/matlab.html) (R2025a or newer)
- [Deep Learning Toolbox<sup>TM</sup>](https://mathworks.com/products/deep-learning.html)

#### References

[1] [https://github.com/okada39/pinn_cavity](https://github.com/okada39/pinn_cavity)

#### Community Support

[MATLAB Central](https://www.mathworks.com/matlabcentral)

Copyright 2025 The MathWorks, Inc.
