Skip to content

Commit

Permalink
Strict type checking and py.typed
Browse files Browse the repository at this point in the history
  • Loading branch information
Avasam committed Feb 6, 2025
1 parent 57b1e74 commit 3e107e1
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 76 deletions.
196 changes: 132 additions & 64 deletions jaraco/path/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,71 +4,106 @@

from __future__ import annotations

import os
import re
import itertools
import functools
import calendar
import contextlib
import logging
import ctypes
import datetime
import functools
import glob
import tempfile
import platform
import ctypes
import importlib
import itertools
import logging
import os
import pathlib
from typing import TYPE_CHECKING, Mapping, Protocol, Union, runtime_checkable
import platform
import re
import tempfile
from collections.abc import Callable, Generator, Iterable
from typing import (
TYPE_CHECKING,
AnyStr,
List,
Literal,
Mapping,
Protocol,
Sized,
TypeVar,
Union,
no_type_check,
overload,
runtime_checkable,
)

if TYPE_CHECKING:
from _typeshed import (
FileDescriptorOrPath,
GenericPath,
ReadableBuffer,
StrOrBytesPath,
StrPath,
SupportsRead,
Unused,
)
from typing_extensions import Self

_StrOrBytesPathT = TypeVar("_StrOrBytesPathT", bound=StrOrBytesPath)
_SizedT = TypeVar("_SizedT", bound=Sized)


log = logging.getLogger(__name__)


def get_unique_pathname(path, root=''):
def get_unique_pathname(path: StrPath, root: StrPath = '') -> str:
"""Return a pathname possibly with a number appended to it so that it is
unique in the directory."""
path = os.path.join(root, path)
# consider the path supplied, then the paths with numbers appended
potentialPaths = itertools.chain((path,), __get_numbered_paths(path))
potentialPaths: Iterable[str] = itertools.chain((path,), __get_numbered_paths(path))
potentialPaths = itertools.filterfalse(os.path.exists, potentialPaths)
return next(potentialPaths)


def __get_numbered_paths(filepath):
def __get_numbered_paths(filepath: StrPath) -> map[str]:
"""Append numbers in sequential order to the filename or folder name
Numbers should be appended before the extension on a filename."""
format = '%s (%%d)%s' % splitext_files_only(filepath)
return map(lambda n: format % n, itertools.count(1))


def splitext_files_only(filepath):
@overload
def splitext_files_only(filepath: str) -> tuple[str, str]: ...
@overload
def splitext_files_only(
filepath: StrPath,
) -> tuple[StrPath, Literal['']] | tuple[str, str]: ...
def splitext_files_only(
filepath: StrPath,
) -> tuple[StrPath, Literal['']] | tuple[str, str]:
"Custom version of splitext that doesn't perform splitext on directories"
return (filepath, '') if os.path.isdir(filepath) else os.path.splitext(filepath)


def set_time(filename, mod_time):
def set_time(filename: FileDescriptorOrPath, mod_time: datetime.datetime) -> None:
"""
Set the modified time of a file
"""
log.debug('Setting modified time to %s', mod_time)
mtime = calendar.timegm(mod_time.utctimetuple())
mtime: float = calendar.timegm(mod_time.utctimetuple())
# utctimetuple discards microseconds, so restore it (for consistency)
mtime += mod_time.microsecond / 1000000
atime = os.stat(filename).st_atime
os.utime(filename, (atime, mtime))


def get_time(filename):
def get_time(filename: FileDescriptorOrPath) -> datetime.datetime:
"""
Get the modified time for a file as a datetime instance
"""
ts = os.stat(filename).st_mtime
return datetime.datetime.utcfromtimestamp(ts)


def insert_before_extension(filename, content):
def insert_before_extension(filename: StrPath, content: str) -> str:
"""
Given a filename and some content, insert the content just before
the extension.
Expand All @@ -81,7 +116,7 @@ def insert_before_extension(filename, content):
return ''.join(parts)


class DirectoryStack(list):
class DirectoryStack(List[str]):
r"""
...
Expand Down Expand Up @@ -109,25 +144,27 @@ class DirectoryStack(list):
True
"""

def pushd(self, new_dir):
def pushd(self, new_dir: str) -> None:
self.append(os.getcwd())
os.chdir(new_dir)

def popd(self):
def popd(self) -> str:
res = os.getcwd()
os.chdir(self.pop())
return res

@contextlib.contextmanager
def context(self, new_dir):
def context(self, new_dir: str) -> Generator[None]:
self.pushd(new_dir)
try:
yield
finally:
self.popd()


def recursive_glob(root, spec):
def recursive_glob(
root: GenericPath[AnyStr], spec: GenericPath[AnyStr]
) -> itertools.chain[AnyStr]:
"""
Like iglob, but recurse directories
Expand All @@ -150,16 +187,15 @@ def recursive_glob(root, spec):
return itertools.chain.from_iterable(glob.iglob(spec) for spec in specs)


def encode(name, system='NTFS'):
def encode(name: str, system: Literal['NTFS'] = 'NTFS') -> str:
"""
Encode the name for a suitable name in the given filesystem
>>> encode('Test :1')
'Test _1'
"""
assert system == 'NTFS', 'unsupported filesystem'
special_characters = r'<>:"/\|?*' + ''.join(map(chr, range(32)))
pattern = '|'.join(map(re.escape, special_characters))
pattern = re.compile(pattern)
pattern = re.compile('|'.join(map(re.escape, special_characters)))
return pattern.sub('_', name)


Expand All @@ -172,43 +208,55 @@ class save_to_file:
... assert 'foo' == pathlib.Path(filename).read_text(encoding='utf-8')
"""

def __init__(self, content):
def __init__(self, content: ReadableBuffer) -> None:
self.content = content

def __enter__(self):
def __enter__(self) -> str:
tf = tempfile.NamedTemporaryFile(delete=False)
tf.write(self.content)
tf.close()
self.filename = tf.name
return tf.name

def __exit__(self, type, value, traceback):
def __exit__(self, type: Unused, value: Unused, traceback: Unused) -> None:
os.remove(self.filename)


@contextlib.contextmanager
def tempfile_context(*args, **kwargs):
"""
A wrapper around tempfile.mkstemp to create the file in a context and
delete it after.
"""
fd, filename = tempfile.mkstemp(*args, **kwargs)
os.close(fd)
try:
yield filename
finally:
os.remove(filename)
if TYPE_CHECKING:

@contextlib.contextmanager
def tempfile_context(
suffix: AnyStr | None = None,
prefix: AnyStr | None = None,
dir: GenericPath[AnyStr] | None = None,
text: bool = False,
) -> Generator[AnyStr]: ...

else:

def replace_extension(new_ext, filename):
@contextlib.contextmanager
def tempfile_context(*args, **kwargs):
"""
A wrapper around tempfile.mkstemp to create the file in a context and
delete it after.
"""
fd, filename = tempfile.mkstemp(*args, **kwargs)
os.close(fd)
try:
yield filename
finally:
os.remove(filename)


def replace_extension(new_ext: AnyStr, filename: GenericPath[AnyStr]) -> AnyStr:
"""
>>> replace_extension('.pdf', 'myfile.doc')
'myfile.pdf'
"""
return os.path.splitext(filename)[0] + new_ext


def ExtensionReplacer(new_ext):
def ExtensionReplacer(new_ext: AnyStr) -> Callable[[GenericPath[AnyStr]], AnyStr]:
"""
A reusable function to replace a file's extension with another
Expand All @@ -223,11 +271,13 @@ def ExtensionReplacer(new_ext):
return functools.partial(replace_extension, new_ext)


def ensure_dir_exists(func):
"wrap a function that returns a dir, making sure it exists"
def ensure_dir_exists(
func: Callable[[], _StrOrBytesPathT],
) -> Callable[[], _StrOrBytesPathT]:
"""wrap a function that returns a dir, making sure it exists"""

@functools.wraps(func)
def make_if_not_present():
def make_if_not_present() -> _StrOrBytesPathT:
dir = func()
if not os.path.isdir(dir):
os.makedirs(dir)
Expand All @@ -236,7 +286,11 @@ def make_if_not_present():
return make_if_not_present


def read_chunks(file, chunk_size=2048, update_func=lambda x: None):
def read_chunks(
file: SupportsRead[_SizedT],
chunk_size: int = 2048,
update_func: Callable[[int], Unused] = lambda x: None,
) -> Generator[_SizedT]:
"""
Read file in chunks of size chunk_size (or smaller).
If update_func is specified, call it on every chunk with the amount
Expand All @@ -250,7 +304,7 @@ def read_chunks(file, chunk_size=2048, update_func=lambda x: None):
yield res


def is_hidden(path) -> bool:
def is_hidden(path: str) -> bool:
"""
Check whether a file is presumed hidden, either because
the pathname starts with dot or because the platform
Expand All @@ -262,23 +316,26 @@ def is_hidden(path) -> bool:
full_path = os.path.abspath(path)
name = os.path.basename(full_path)

def no(path):
def no(path: Unused) -> Literal[False]:
return False

platform_hidden = globals().get('is_hidden_' + platform.system(), no)
return name.startswith('.') or platform_hidden(full_path)


def is_hidden_Windows(path):
@no_type_check # can't assert platform.system in the function, and platform check is done outside
def is_hidden_Windows(path: str) -> bool:
res = ctypes.windll.kernel32.GetFileAttributesW(path)
assert res != -1
return bool(res & 2)


def is_hidden_Darwin(path):
def is_hidden_Darwin(path: str) -> bool:
Foundation = importlib.import_module('Foundation')
url = Foundation.NSURL.fileURLWithPath_(path)
res = url.getResourceValue_forKey_error_(None, Foundation.NSURLIsHiddenKey, None)
res: tuple[bool, bool] = url.getResourceValue_forKey_error_(
None, Foundation.NSURLIsHiddenKey, None
)
return res[1]


Expand All @@ -293,11 +350,11 @@ class Symlink(str):

@runtime_checkable
class TreeMaker(Protocol):
def __truediv__(self, other, /) -> Self: ...
def mkdir(self, *, exist_ok) -> object: ...
def write_text(self, content, /, *, encoding) -> object: ...
def write_bytes(self, content, /) -> object: ...
def symlink_to(self, target, /) -> object: ...
def __truediv__(self, key: StrPath, /) -> Self: ...
def mkdir(self, *, exist_ok: bool) -> object: ...
def write_text(self, data: str, /, *, encoding: str | None) -> object: ...
def write_bytes(self, data: ReadableBuffer, /) -> object: ...
def symlink_to(self, target: StrOrBytesPath, /) -> object: ...


def _ensure_tree_maker(obj: str | TreeMaker) -> TreeMaker:
Expand All @@ -307,7 +364,7 @@ def _ensure_tree_maker(obj: str | TreeMaker) -> TreeMaker:
def build(
spec: FilesSpec,
prefix: str | TreeMaker = pathlib.Path(),
):
) -> None:
"""
Build a set of files/directories, as described by the spec.
Expand Down Expand Up @@ -359,7 +416,14 @@ def _(content: Symlink, path: TreeMaker) -> None:
path.symlink_to(content)


class Recording:
class _ConcatenablePurePathLike(Protocol):
def __truediv__(self, other: StrPath, /) -> Self: ...
# A Protocol can't inherit from os.PathLike until Python 3.14,
# so just copying __fspath__ from typeshed's PathLike[str] instead
def __fspath__(self) -> str: ...


class Recording(TreeMaker):
"""
A TreeMaker object that records everything that would be written.
Expand All @@ -369,20 +433,24 @@ class Recording:
['foo/foo1.txt', 'bar.txt']
"""

def __init__(self, loc=pathlib.PurePosixPath(), record=None):
def __init__(
self,
loc: _ConcatenablePurePathLike = pathlib.PurePosixPath(),
record: list[str] | None = None,
) -> None:
self.loc = loc
self.record = record if record is not None else []

def __truediv__(self, other):
return Recording(self.loc / other, self.record)
def __truediv__(self, other: StrPath) -> Self:
return type(self)(self.loc / other, self.record)

def write_text(self, content, **kwargs):
def write_text(self, content: Unused, **kwargs: Unused) -> None:
self.record.append(str(self.loc))

write_bytes = write_text

def mkdir(self, **kwargs):
def mkdir(self, **kwargs: Unused) -> None:
return

def symlink_to(self, target):
def symlink_to(self, target: Unused) -> None:
pass
Empty file added jaraco/path/py.typed
Empty file.
Loading

0 comments on commit 3e107e1

Please # to comment.