From e8d6e2ec530e7c2c7ecf63bc77bab9d0c26b74fb Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 20 Oct 2023 11:15:25 +0800 Subject: [PATCH 1/5] adapt DeepCFD to hydra --- docs/zh/examples/deepcfd.md | 52 +++++--- examples/deepcfd/conf/deepcfd.yaml | 61 +++++++++ examples/deepcfd/deepcfd.py | 200 +++++++++++++++++++++-------- 3 files changed, 241 insertions(+), 72 deletions(-) create mode 100644 examples/deepcfd/conf/deepcfd.yaml diff --git a/docs/zh/examples/deepcfd.md b/docs/zh/examples/deepcfd.md index cf1047885..1787551a9 100644 --- a/docs/zh/examples/deepcfd.md +++ b/docs/zh/examples/deepcfd.md @@ -1,5 +1,17 @@ # DeepCFD(Deep Computational Fluid Dynamics) +=== "模型训练命令" + + ``` sh + # linux + wget wget -P ./datasets/ https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataX.pkl + wget wget -P ./datasets/ https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataY.pkl + # windows + # curl wget -P ./datasets/ https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataX.pkl --output dataX.pkl + # curl wget -P ./datasets/ https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataY.pkl --output dataY.pkl + python deepcfd.py + ``` + ## 1. 背景简介 计算流体力学(Computational fluid dynamics, CFD)模拟通过求解 Navier-Stokes 方程(N-S 方程),可以获得流体的各种物理量的分布,如密度、压力和速度等。在微电子系统、土木工程和航空航天等领域应用广泛。 @@ -55,9 +67,9 @@ dataX 和 dataY 都具有相同的维度(Ns,Nc,Nx,Ny),其中第一 我们将数据集以 7:3 的比例划分为训练集和验证集,代码如下: -``` py linenums="201" title="examples/deepcfd/deepcfd.py" +``` py linenums="202" title="examples/deepcfd/deepcfd.py" --8<-- -examples/deepcfd/deepcfd.py:201:220 +examples/deepcfd/deepcfd.py:202:216 --8<-- ``` @@ -75,18 +87,18 @@ examples/deepcfd/deepcfd.py:201:220 模型创建用 PaddleScience 代码表示如下: -``` py linenums="222" title="examples/deepcfd/deepcfd.py" +``` py linenums="218" title="examples/deepcfd/deepcfd.py" --8<-- -examples/deepcfd/deepcfd.py:222:243 +examples/deepcfd/deepcfd.py:218:219 --8<-- ``` ### 3.3 约束构建 本案例基于数据驱动的方法求解问题,因此需要使用 PaddleScience 内置的 `SupervisedConstraint` 构建监督约束。在定义约束之前,需要首先指定监督约束中用于数据加载的各个参数,代码如下: -``` py linenums="244" title="examples/deepcfd/deepcfd.py" +``` py linenums="234" title="examples/deepcfd/deepcfd.py" --8<-- -examples/deepcfd/deepcfd.py:244:294 +examples/deepcfd/deepcfd.py:234:264 --8<-- ``` `SupervisedConstraint` 的第一个参数是数据的加载方式,这里填入相关数据的变量名。 @@ -97,36 +109,36 @@ examples/deepcfd/deepcfd.py:244:294 在监督约束构建完毕之后,以我们刚才的命名为关键字,封装到一个字典中,方便后续访问。 -``` py linenums="295" title="examples/deepcfd/deepcfd.py" +``` py linenums="266" title="examples/deepcfd/deepcfd.py" --8<-- -examples/deepcfd/deepcfd.py:295:297 +examples/deepcfd/deepcfd.py:266:267 --8<-- ``` ### 3.4 超参数设定 -接下来我们需要指定训练轮数和学习率,此处我们按实验经验,使用一千轮训练轮数。 +接下来需要在配置文件中指定训练轮数,此处我们按实验经验,使用一千轮训练轮数。 -``` py linenums="298" title="examples/deepcfd/deepcfd.py" +``` py linenums="47" title="examples/deepcfd/conf/deepcfd.yaml" --8<-- -examples/deepcfd/deepcfd.py:298:301 +examples/deepcfd/conf/deepcfd.yaml:47:51 --8<-- ``` ### 3.5 优化器构建 -训练过程会调用优化器来更新模型参数,此处选择较为常用的 `Adam` 优化器,学习率设置为 0.001。 +训练过程会调用优化器来更新模型参数,此处选择较为常用的 `Adam` 优化器,学习率设置为 0.001,权值衰减设置为 0.005。 -``` py linenums="302" title="examples/deepcfd/deepcfd.py" +``` py linenums="269" title="examples/deepcfd/deepcfd.py" --8<-- -examples/deepcfd/deepcfd.py:302:304 +examples/deepcfd/deepcfd.py:269:272 --8<-- ``` ### 3.6 评估器构建 在训练过程中通常会按一定轮数间隔,用验证集评估当前模型的训练情况,我们使用 `ppsci.validate.SupervisedValidator` 构建评估器。 -``` py linenums="305" title="examples/deepcfd/deepcfd.py" +``` py linenums="274" title="examples/deepcfd/deepcfd.py" --8<-- -examples/deepcfd/deepcfd.py:305:346 
+examples/deepcfd/deepcfd.py:274:314 --8<-- ``` @@ -137,18 +149,18 @@ examples/deepcfd/deepcfd.py:305:346 ### 3.7 模型训练、评估 完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`,然后启动训练、评估。 -``` py linenums="347" title="examples/deepcfd/deepcfd.py" +``` py linenums="316" title="examples/deepcfd/deepcfd.py" --8<-- -examples/deepcfd/deepcfd.py:347:364 +examples/deepcfd/deepcfd.py:316:335 --8<-- ``` ### 3.8 结果可视化 使用 matplotlib 绘制相同输入参数时的 OpenFOAM 和 DeepCFD 的计算结果,进行对比。这里绘制了验证集第 0 个数据的计算结果。 -``` py linenums="365" title="examples/deepcfd/deepcfd.py" +``` py linenums="337" title="examples/deepcfd/deepcfd.py" --8<-- -examples/deepcfd/deepcfd.py:365:371 +examples/deepcfd/deepcfd.py:337:342 --8<-- ``` diff --git a/examples/deepcfd/conf/deepcfd.yaml b/examples/deepcfd/conf/deepcfd.yaml new file mode 100644 index 000000000..ad932d13b --- /dev/null +++ b/examples/deepcfd/conf/deepcfd.yaml @@ -0,0 +1,61 @@ +hydra: + run: + # dynamic output directory according to running time and override name + dir: outputs_deepcfd/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + job: + name: ${mode} # name of logfile + chdir: false # keep current working direcotry unchaned + config: + override_dirname: + exclude_keys: + - TRAIN.checkpoint_path + - TRAIN.pretrained_model_path + - EVAL.pretrained_model_path + - mode + - output_dir + - log_freq + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: train # running mode: train/eval +seed: 2023 +output_dir: ${hydra:run.dir} +log_freq: 20 + +# set data file path +DATAX_PATH: ./data/dataX.pkl +DATAY_PATH: ./data/dataY.pkl +SLIPT_RATIO: 0.7 # slipt dataset to train dataset and test datatset +SAMPLE_SIZE: 981 # the shape of dataX and dataY is [SAMPLE_SIZE, CHANNEL_SIZE, X_SIZE, Y_SIZE] +CHANNEL_SIZE: 3 +X_SIZE: 172 +Y_SIZE: 79 + +# model settings +MODEL: + unetex: + in_channel: 3 + out_channel: 3 + kernel_size: 5 + filters: [8, 16, 32, 32] + weight_norm: false + batch_norm: false + +# training settings +TRAIN: + epochs: 1000 + learning_rate: 0.001 + weight_decay: 0.005 + eval_during_train: true + eval_freq: 50 + batch_size: 64 + pretrained_model_path: null + checkpoint_path: null + +EVAL: + pretrained_model_path: null + eval_with_no_grad: true + batch_size: 8 diff --git a/examples/deepcfd/deepcfd.py b/examples/deepcfd/deepcfd.py index f2387151b..2ec834a45 100644 --- a/examples/deepcfd/deepcfd.py +++ b/examples/deepcfd/deepcfd.py @@ -18,11 +18,12 @@ from typing import List from typing import Tuple +import hydra import numpy as np from matplotlib import pyplot as plt +from omegaconf import DictConfig import ppsci -from ppsci.utils import config from ppsci.utils import logger @@ -192,66 +193,36 @@ def predict_and_save_plot( plt.tight_layout() plt.show() plt.savefig( - os.path.join(PLOT_DIR, f"cfd_{index}.png"), + os.path.join(plot_dir, f"cfd_{index}.png"), bbox_inches="tight", ) -if __name__ == "__main__": - args = config.parse_args() - ppsci.utils.misc.set_random_seed(42) - - OUTPUT_DIR = "./output_deepCFD/" if args.output_dir is None else args.output_dir - +def train(cfg: DictConfig): + # set random seed for reproducibility + ppsci.utils.misc.set_random_seed(cfg.seed) # initialize logger - logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") + logger.init_logger("ppsci", os.path.join(cfg.output_dir, "train.log"), "info") # initialize datasets - DATASET_PATH = "./datasets/" - with open(os.path.join(DATASET_PATH, "dataX.pkl"), "rb") as file: + with open(cfg.DATAX_PATH, "rb") as file: x = pickle.load(file) - 
with open(os.path.join(DATASET_PATH, "dataY.pkl"), "rb") as file: + with open(cfg.DATAY_PATH, "rb") as file: y = pickle.load(file) # slipt dataset to train dataset and test datatset - SLIPT_RATIO = 0.7 - train_dataset, test_dataset = split_tensors(x, y, ratio=SLIPT_RATIO) + train_dataset, test_dataset = split_tensors(x, y, ratio=cfg.SLIPT_RATIO) train_x, train_y = train_dataset test_x, test_y = test_dataset - # initialize parameters - IN_CHANNELS = 3 - OUT_CHANNELS = 3 - KERNEL_SIZE = 5 - FILTERS = (8, 16, 32, 32) - BATCH_NORM = False - WEIGHT_NORM = False - WEIGHT_DECAY = 0.005 - BATCH_SIZE = 64 - # initialize model - model = ppsci.arch.UNetEx( - "input", - "output", - IN_CHANNELS, - OUT_CHANNELS, - KERNEL_SIZE, - FILTERS, - weight_norm=WEIGHT_NORM, - batch_norm=BATCH_NORM, - ) - - # the shape of x and y is [SAMPLE_SIZE, CHANNEL_SIZE, X_SIZE, Y_SIZE] - SAMPLE_SIZE = 981 - CHANNEL_SIZE = 3 - X_SIZE = 172 - Y_SIZE = 79 + model = ppsci.arch.UNetEx("input", "output", **cfg.MODEL.unetex) CHANNELS_WEIGHTS = np.reshape( np.sqrt( np.mean( np.transpose(y, (0, 2, 3, 1)).reshape( - (SAMPLE_SIZE * X_SIZE * Y_SIZE, CHANNEL_SIZE) + (cfg.SAMPLE_SIZE * cfg.X_SIZE * cfg.Y_SIZE, cfg.CHANNEL_SIZE) ) ** 2, axis=0, @@ -281,7 +252,7 @@ def loss_expr( "input": {"input": train_x}, "label": {"output": train_y}, }, - "batch_size": BATCH_SIZE, + "batch_size": cfg.TRAIN.batch_size, "sampler": { "name": "BatchSampler", "drop_last": False, @@ -295,12 +266,10 @@ def loss_expr( # maunally build constraint constraint = {sup_constraint.name: sup_constraint} - # set training hyper-parameters - EPOCHS = 1000 - LEARNING_RATE = 0.001 - # initialize Adam optimizer - optimizer = ppsci.optimizer.Adam(LEARNING_RATE, weight_decay=WEIGHT_DECAY)(model) + optimizer = ppsci.optimizer.Adam( + cfg.TRAIN.learning_rate, weight_decay=cfg.TRAIN.weight_decay + )(model) # manually build validator eval_dataloader_cfg = { @@ -348,12 +317,15 @@ def metric_expr( solver = ppsci.solver.Solver( model, constraint, - OUTPUT_DIR, + cfg.output_dir, optimizer, - epochs=EPOCHS, - eval_during_train=True, - eval_freq=50, + epochs=cfg.TRAIN.epochs, + eval_during_train=cfg.TRAIN.eval_during_train, + eval_freq=cfg.TRAIN.eval_freq, + seed=cfg.seed, validator=validator, + checkpoint_path=cfg.TRAIN.checkpoint_path, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, ) # train model @@ -362,9 +334,133 @@ def metric_expr( # evaluate after finished training solver.eval() - PLOT_DIR = os.path.join(OUTPUT_DIR, "visual") + PLOT_DIR = os.path.join(cfg.output_dir, "visual") + os.makedirs(PLOT_DIR, exist_ok=True) + VISU_INDEX = 0 + + # visualize prediction after finished training + predict_and_save_plot(test_x, test_y, VISU_INDEX, solver, PLOT_DIR) + + +def evaluate(cfg: DictConfig): + # set random seed for reproducibility + ppsci.utils.misc.set_random_seed(cfg.seed) + # initialize logger + logger.init_logger("ppsci", os.path.join(cfg.output_dir, "train.log"), "info") + + # initialize datasets + with open(cfg.DATAX_PATH, "rb") as file: + x = pickle.load(file) + with open(cfg.DATAY_PATH, "rb") as file: + y = pickle.load(file) + + # slipt dataset to train dataset and test datatset + train_dataset, test_dataset = split_tensors(x, y, ratio=cfg.SLIPT_RATIO) + train_x, train_y = train_dataset + test_x, test_y = test_dataset + + # initialize model + model = ppsci.arch.UNetEx("input", "output", **cfg.MODEL.unetex) + + CHANNELS_WEIGHTS = np.reshape( + np.sqrt( + np.mean( + np.transpose(y, (0, 2, 3, 1)).reshape( + (cfg.SAMPLE_SIZE * cfg.X_SIZE * cfg.Y_SIZE, cfg.CHANNEL_SIZE) + ) + ** 
2, + axis=0, + ) + ), + (1, -1, 1, 1), + ) + + # define loss + def loss_expr( + output_dict: Dict[str, np.ndarray], + label_dict: Dict[str, np.ndarray] = None, + weight_dict: Dict[str, np.ndarray] = None, + ) -> float: + output = output_dict["output"] + y = label_dict["output"] + loss_u = (output[:, 0:1, :, :] - y[:, 0:1, :, :]) ** 2 + loss_v = (output[:, 1:2, :, :] - y[:, 1:2, :, :]) ** 2 + loss_p = (output[:, 2:3, :, :] - y[:, 2:3, :, :]).abs() + loss = (loss_u + loss_v + loss_p) / CHANNELS_WEIGHTS + return loss.sum() + + # manually build validator + eval_dataloader_cfg = { + "dataset": { + "name": "NamedArrayDataset", + "input": {"input": test_x}, + "label": {"output": test_y}, + }, + "batch_size": cfg.EVAL.batch_size, + "sampler": { + "name": "BatchSampler", + "drop_last": False, + "shuffle": False, + }, + } + + def metric_expr( + output_dict: Dict[str, np.ndarray], + label_dict: Dict[str, np.ndarray] = None, + weight_dict: Dict[str, np.ndarray] = None, + ) -> Dict[str, float]: + output = output_dict["output"] + y = label_dict["output"] + total_mse = ((output - y) ** 2).sum() / len(test_x) + ux_mse = ((output[:, 0, :, :] - test_y[:, 0, :, :]) ** 2).sum() / len(test_x) + uy_mse = ((output[:, 1, :, :] - test_y[:, 1, :, :]) ** 2).sum() / len(test_x) + p_mse = ((output[:, 2, :, :] - test_y[:, 2, :, :]) ** 2).sum() / len(test_x) + return { + "Total_MSE": total_mse, + "Ux_MSE": ux_mse, + "Uy_MSE": uy_mse, + "p_MSE": p_mse, + } + + sup_validator = ppsci.validate.SupervisedValidator( + eval_dataloader_cfg, + ppsci.loss.FunctionalLoss(loss_expr), + {"output": lambda out: out["output"]}, + {"MSE": ppsci.metric.FunctionalMetric(metric_expr)}, + name="mse_validator", + ) + validator = {sup_validator.name: sup_validator} + + # initialize solver + solver = ppsci.solver.Solver( + model, + output_dir=cfg.output_dir, + seed=cfg.seed, + validator=validator, + pretrained_model_path=cfg.EVAL.pretrained_model_path, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + ) + + # evaluate + solver.eval() + + PLOT_DIR = os.path.join(cfg.output_dir, "visual") os.makedirs(PLOT_DIR, exist_ok=True) VISU_INDEX = 0 # visualize prediction after finished training predict_and_save_plot(test_x, test_y, VISU_INDEX, solver, PLOT_DIR) + + +@hydra.main(version_base=None, config_path="./conf", config_name="deepcfd.yaml") +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + else: + raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + + +if __name__ == "__main__": + main() From d0de9e8293e17ccc08639d2dade07dfdfaf255e4 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Mon, 23 Oct 2023 10:58:24 +0800 Subject: [PATCH 2/5] fix --- examples/deepcfd/conf/deepcfd.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/deepcfd/conf/deepcfd.yaml b/examples/deepcfd/conf/deepcfd.yaml index ad932d13b..af9308f38 100644 --- a/examples/deepcfd/conf/deepcfd.yaml +++ b/examples/deepcfd/conf/deepcfd.yaml @@ -26,8 +26,8 @@ output_dir: ${hydra:run.dir} log_freq: 20 # set data file path -DATAX_PATH: ./data/dataX.pkl -DATAY_PATH: ./data/dataY.pkl +DATAX_PATH: ./datasets/dataX.pkl +DATAY_PATH: ./datasets/dataY.pkl SLIPT_RATIO: 0.7 # slipt dataset to train dataset and test datatset SAMPLE_SIZE: 981 # the shape of dataX and dataY is [SAMPLE_SIZE, CHANNEL_SIZE, X_SIZE, Y_SIZE] CHANNEL_SIZE: 3 From 8f3eb12c8b4499acfa0eb5d97a87830cc90aa1d6 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 27 Oct 2023 10:29:48 +0800 Subject: 
[PATCH 3/5] fix --- docs/zh/examples/deepcfd.md | 12 ++++++------ examples/deepcfd/conf/deepcfd.yaml | 15 ++++++++------- examples/deepcfd/deepcfd.py | 12 +++++------- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/docs/zh/examples/deepcfd.md b/docs/zh/examples/deepcfd.md index 1787551a9..3b763d433 100644 --- a/docs/zh/examples/deepcfd.md +++ b/docs/zh/examples/deepcfd.md @@ -4,11 +4,11 @@ ``` sh # linux - wget wget -P ./datasets/ https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataX.pkl - wget wget -P ./datasets/ https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataY.pkl + wget -P ./datasets/ https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataX.pkl + wget -P ./datasets/ https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataY.pkl # windows - # curl wget -P ./datasets/ https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataX.pkl --output dataX.pkl - # curl wget -P ./datasets/ https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataY.pkl --output dataY.pkl + # curl -o ./datasets/dataX.pkl https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataX.pkl + # curl -o ./datasets/dataX.pkl https://paddle-org.bj.bcebos.com/paddlescience/datasets/DeepCFD/dataY.pkl python deepcfd.py ``` @@ -118,14 +118,14 @@ examples/deepcfd/deepcfd.py:266:267 ### 3.4 超参数设定 接下来需要在配置文件中指定训练轮数,此处我们按实验经验,使用一千轮训练轮数。 -``` py linenums="47" title="examples/deepcfd/conf/deepcfd.yaml" +``` yaml linenums="47" title="examples/deepcfd/conf/deepcfd.yaml" --8<-- examples/deepcfd/conf/deepcfd.yaml:47:51 --8<-- ``` ### 3.5 优化器构建 -训练过程会调用优化器来更新模型参数,此处选择较为常用的 `Adam` 优化器,学习率设置为 0.001,权值衰减设置为 0.005。 +训练过程会调用优化器来更新模型参数,此处选择较为常用的 `Adam` 优化器,学习率设置为 0.001,权重衰减设置为 0.005。 ``` py linenums="269" title="examples/deepcfd/deepcfd.py" --8<-- diff --git a/examples/deepcfd/conf/deepcfd.yaml b/examples/deepcfd/conf/deepcfd.yaml index af9308f38..e7a5a488c 100644 --- a/examples/deepcfd/conf/deepcfd.yaml +++ b/examples/deepcfd/conf/deepcfd.yaml @@ -36,13 +36,14 @@ Y_SIZE: 79 # model settings MODEL: - unetex: - in_channel: 3 - out_channel: 3 - kernel_size: 5 - filters: [8, 16, 32, 32] - weight_norm: false - batch_norm: false + input_key: "input" + output_key: "output" + in_channel: 3 + out_channel: 3 + kernel_size: 5 + filters: [8, 16, 32, 32] + weight_norm: false + batch_norm: false # training settings TRAIN: diff --git a/examples/deepcfd/deepcfd.py b/examples/deepcfd/deepcfd.py index 2ec834a45..68c4e74ca 100644 --- a/examples/deepcfd/deepcfd.py +++ b/examples/deepcfd/deepcfd.py @@ -216,7 +216,7 @@ def train(cfg: DictConfig): test_x, test_y = test_dataset # initialize model - model = ppsci.arch.UNetEx("input", "output", **cfg.MODEL.unetex) + model = ppsci.arch.UNetEx(**cfg.MODEL) CHANNELS_WEIGHTS = np.reshape( np.sqrt( @@ -336,17 +336,16 @@ def metric_expr( PLOT_DIR = os.path.join(cfg.output_dir, "visual") os.makedirs(PLOT_DIR, exist_ok=True) - VISU_INDEX = 0 # visualize prediction after finished training - predict_and_save_plot(test_x, test_y, VISU_INDEX, solver, PLOT_DIR) + predict_and_save_plot(test_x, test_y, 0, solver, PLOT_DIR) def evaluate(cfg: DictConfig): # set random seed for reproducibility ppsci.utils.misc.set_random_seed(cfg.seed) # initialize logger - logger.init_logger("ppsci", os.path.join(cfg.output_dir, "train.log"), "info") + logger.init_logger("ppsci", os.path.join(cfg.output_dir, "eval.log"), "info") # initialize datasets with open(cfg.DATAX_PATH, "rb") as file: @@ -360,7 +359,7 @@ def evaluate(cfg: DictConfig): test_x, 
test_y = test_dataset # initialize model - model = ppsci.arch.UNetEx("input", "output", **cfg.MODEL.unetex) + model = ppsci.arch.UNetEx(**cfg.MODEL) CHANNELS_WEIGHTS = np.reshape( np.sqrt( @@ -446,10 +445,9 @@ def metric_expr( PLOT_DIR = os.path.join(cfg.output_dir, "visual") os.makedirs(PLOT_DIR, exist_ok=True) - VISU_INDEX = 0 # visualize prediction after finished training - predict_and_save_plot(test_x, test_y, VISU_INDEX, solver, PLOT_DIR) + predict_and_save_plot(test_x, test_y, 0, solver, PLOT_DIR) @hydra.main(version_base=None, config_path="./conf", config_name="deepcfd.yaml") From cc48c8d756fdb11b389544ee4961af4ef4d3020f Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 27 Oct 2023 10:39:25 +0800 Subject: [PATCH 4/5] fix --- docs/zh/examples/deepcfd.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/zh/examples/deepcfd.md b/docs/zh/examples/deepcfd.md index 3b763d433..500197967 100644 --- a/docs/zh/examples/deepcfd.md +++ b/docs/zh/examples/deepcfd.md @@ -120,7 +120,7 @@ examples/deepcfd/deepcfd.py:266:267 ``` yaml linenums="47" title="examples/deepcfd/conf/deepcfd.yaml" --8<-- -examples/deepcfd/conf/deepcfd.yaml:47:51 +examples/deepcfd/conf/deepcfd.yaml:47:52 --8<-- ``` From 34544a3d322a44eec845f187155a25269ff5e96d Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 27 Oct 2023 11:04:56 +0800 Subject: [PATCH 5/5] fix --- examples/deepcfd/deepcfd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/deepcfd/deepcfd.py b/examples/deepcfd/deepcfd.py index 68c4e74ca..cb4c042d4 100644 --- a/examples/deepcfd/deepcfd.py +++ b/examples/deepcfd/deepcfd.py @@ -278,7 +278,7 @@ def loss_expr( "input": {"input": test_x}, "label": {"output": test_y}, }, - "batch_size": 8, + "batch_size": cfg.EVAL.batch_size, "sampler": { "name": "BatchSampler", "drop_last": False,
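
Usage note: once this series is applied, the example is driven entirely by `examples/deepcfd/conf/deepcfd.yaml` plus Hydra command-line overrides, with `mode` selecting the `train`/`eval` branch in `main()`. The sketch below is illustrative only and assumes the defaults shipped in the patch; the pretrained-model path is a placeholder to be replaced with a checkpoint produced by your own training run.

```sh
# Train with the defaults from conf/deepcfd.yaml (mode defaults to "train").
python deepcfd.py

# Any config key can be overridden on the command line, e.g. a shorter run:
python deepcfd.py TRAIN.epochs=100 TRAIN.batch_size=32

# Evaluate a trained model. The path below is a placeholder — point it at the
# checkpoint saved under the Hydra run directory (outputs_deepcfd/<date>/<time>/...).
python deepcfd.py mode=eval EVAL.pretrained_model_path=path/to/trained/model
```

Note that the config's `hydra.job.config.override_dirname.exclude_keys` list already filters out `mode` and the checkpoint/pretrained-model paths, so evaluation overrides like the one above do not bloat the auto-generated output directory name.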