Skip to content

Commit

Permalink
add unitest
Browse files Browse the repository at this point in the history
  • Loading branch information
CjhHa1 committed Nov 30, 2022
1 parent 3f1719e commit dffe14d
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def partition_cluster(
n: int,
m: int,
filter=[
complete_meshes,
complete_meshes.__func__,
],
) -> list:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_group_operators MODULES test_group_operators)
py_test_modules(test_pattern MODULES test_pattern)
py_test_modules(test_pattern_match MODULES test_pattern_match)
py_test_modules(test_cluster_partition MODULES test_cluster_partition)
py_test_modules(test_convert_to_process_meshes MODULES
test_convert_to_process_meshes)
endif()
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest


class TestClusterPartition(unittest.TestCase):
def test_cluster_partition(self):
clusters = [(5, 8), (1, 8), (4, 8), (16, 8)]
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
ClusterPartitionUtil,
)

device_meshes = []
for cluster in clusters:
n = cluster[0]
m = cluster[1]
device_mesh = ClusterPartitionUtil.partition_cluster(n, m)
device_meshes.append(device_mesh)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest


class TestConvertToProcessMeshes(unittest.TestCase):
def test_convert_to_process_meshes(self):
device_meshes = [[1, 8], [4, 8], [15, 8]]
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
convert_to_process_meshes,
)

process_meshes = []
for device_mesh in device_meshes:
process_mesh = convert_to_process_meshes(device_mesh)
process_meshes.append(process_mesh)


if __name__ == "__main__":
unittest.main()

0 comments on commit dffe14d

Please # to comment.