Skip to content

Commit

Permalink
adding support for remote to local operations in shutil.py::copytree (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
brno32 authored Jun 27, 2023
1 parent deda4d0 commit 285ad28
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/smbclient/shutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def copytree(
source path and the destination path as arguments. By default copy() is used, but any function that supports the
same signature (like copy()) can be used.
In this current form, copytree() only supports remote to remote copies over SMB.
In this current form, copytree() only supports remote to remote copies over SMB, or remote to local copies.
:param src: The source directory to copy.
:param dst: The destination directory to copy to.
Expand All @@ -296,7 +296,11 @@ def copytree(
:return: The dst path.
"""
dir_entries = list(scandir(src, **kwargs))
makedirs(dst, exist_ok=dirs_exist_ok, **kwargs)

if is_remote_path(dst):
makedirs(dst, exist_ok=dirs_exist_ok, **kwargs)
else:
os.makedirs(dst, exist_ok=dirs_exist_ok)

ignored = []
if ignore is not None:
Expand Down
27 changes: 27 additions & 0 deletions tests/test_smbclient_shutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,33 @@ def ignore(name, children):
assert fd.read() == "file3.txt"


def test_copytree_with_local_dst(smb_share, tmp_path):
src_dirname = "%s\\source" % smb_share
dst_dirname = str(tmp_path / "target")

makedirs("%s\\dir1\\subdir1" % src_dirname)
with open_file("%s\\file1.txt" % src_dirname, mode="w") as fd:
fd.write("file1.txt")
with open_file("%s\\dir1\\file2.txt" % src_dirname, mode="w") as fd:
fd.write("file2.txt")
with open_file("%s\\dir1\\subdir1\\file3.txt" % src_dirname, mode="w") as fd:
fd.write("file3.txt")

actual = copytree(src_dirname, dst_dirname)
assert actual == dst_dirname

assert sorted(list(os.listdir(dst_dirname))) == ["dir1", "file1.txt"]
assert sorted(list(os.listdir(os.path.join(dst_dirname, "dir1")))) == ["file2.txt", "subdir1"]
assert sorted(list(os.listdir(os.path.join(dst_dirname, "dir1", "subdir1")))) == ["file3.txt"]

with open(os.path.join(dst_dirname, "file1.txt")) as fd:
assert fd.read() == "file1.txt"
with open(os.path.join(dst_dirname, "dir1", "file2.txt")) as fd:
assert fd.read() == "file2.txt"
with open(os.path.join(dst_dirname, "dir1", "subdir1", "file3.txt")) as fd:
assert fd.read() == "file3.txt"


@pytest.mark.skipif(
os.name != "nt" and not os.environ.get("SMB_FORCE", False), reason="Samba does not update timestamps"
)
Expand Down

0 comments on commit 285ad28

Please # to comment.