Skip to content

Commit

Permalink
A few cleanups to make lib use easier.
Browse files Browse the repository at this point in the history
  • Loading branch information
jlevy committed Nov 28, 2024
1 parent 172b138 commit 6db93ce
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 44 deletions.
4 changes: 2 additions & 2 deletions repren/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .repren import main
from .repren import main, multi_replace, rewrite_file, rewrite_files

__all__ = ["main"]
__all__ = ["main", "rewrite_file", "rewrite_files", "multi_replace"]
130 changes: 89 additions & 41 deletions repren/repren.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,20 +251,22 @@

import argparse
import bisect
from dataclasses import dataclass
import importlib.metadata
import os
import re
import shutil
import sys
from typing import BinaryIO, Callable, List, Match, Optional, Pattern, Tuple, TypeAlias
from dataclasses import dataclass
from typing import BinaryIO, Callable, List, Match, Optional, Pattern, Tuple

# Type aliases for clarity.
PatternType: TypeAlias = Tuple[Pattern[bytes], bytes]
FileHandle: TypeAlias = BinaryIO
MatchType: TypeAlias = Match[bytes]
PatternPair: TypeAlias = Tuple[MatchType, bytes]
TransformFunc: TypeAlias = Callable[[bytes], Tuple[bytes, "_MatchCounts"]]
PatternType = Tuple[Pattern[bytes], bytes]
FileHandle = BinaryIO
MatchType = Match[bytes]
PatternPair = Tuple[MatchType, bytes]
TransformFunc = Callable[[bytes], Tuple[bytes, "_MatchCounts"]]
LogFunc = Callable[[str], None]
FailHandler = Callable[[str, Optional[Exception]], None]

# Get the version from package metadata.
VERSION: str
Expand All @@ -282,17 +284,29 @@
DEFAULT_EXCLUDE_PAT: str = r"\."


def log(op: Optional[str], msg: str) -> None:
if op:
msg = "- %s: %s" % (op, msg)
def no_log(msg: str) -> None:
    """
    Discard `msg`. A do-nothing logger used as the default `log` callback.
    """
    return None


def print_stderr(msg: str) -> None:
    """
    Print `msg` as a single line on standard error.
    """
    stream = sys.stderr
    print(msg, file=stream)


def fail(msg: str) -> None:
def _fail_exit(msg: str, e: Optional[Exception] = None) -> None:
if e:
msg = "%s: %s" % (msg, e) if msg else str(e)
print("error: " + msg, file=sys.stderr)
sys.exit(1)


def _fail_exception(msg: str, e: Optional[Exception] = None) -> None:
raise ValueError(msg) from e


# By default, fail with exceptions in case we want to use this as a library.
_fail: FailHandler = _fail_exception


def safe_decode(b: bytes) -> str:
"""
Safely decode bytes to a string for logging purposes.
Expand All @@ -314,6 +328,7 @@ class _Tally:

_tally: _Tally = _Tally()


# --- String matching ---


Expand All @@ -322,7 +337,9 @@ def _overlap(match1: MatchType, match2: MatchType) -> bool:


def _sort_drop_overlaps(
matches: List[PatternPair], source_name: Optional[str] = None
matches: List[PatternPair],
source_name: Optional[str] = None,
log: LogFunc = no_log,
) -> List[PatternPair]:
"""Select and sort a set of disjoint intervals, omitting ones that overlap."""
non_overlaps: List[PatternPair] = []
Expand All @@ -333,9 +350,9 @@ def _sort_drop_overlaps(
(prev_match, _) = non_overlaps[index - 1]
if _overlap(prev_match, match):
log(
source_name,
"Skipping overlapping match '%s' of '%s' that overlaps '%s' of '%s' on its left"
"- %s: Skipping overlapping match '%s' of '%s' that overlaps '%s' of '%s' on its left"
% (
source_name,
safe_decode(match.group()),
safe_decode(match.re.pattern),
safe_decode(prev_match.group()),
Expand All @@ -347,9 +364,9 @@ def _sort_drop_overlaps(
(next_match, _) = non_overlaps[index]
if _overlap(next_match, match):
log(
source_name,
"Skipping overlapping match '%s' of '%s' that overlaps '%s' of '%s' on its right"
"- %s: Skipping overlapping match '%s' of '%s' that overlaps '%s' of '%s' on its right"
% (
source_name,
safe_decode(match.group()),
safe_decode(match.re.pattern),
safe_decode(next_match.group()),
Expand Down Expand Up @@ -388,14 +405,19 @@ def multi_replace(
patterns: List[PatternType],
is_path: bool = False,
source_name: Optional[str] = None,
log: LogFunc = no_log,
) -> Tuple[bytes, _MatchCounts]:
"""Replace all occurrences in the input given a list of patterns (regex,
replacement), simultaneously, so that no replacement affects any other."""
"""
Replace all occurrences in the input given a list of patterns (regex,
replacement), simultaneously, so that no replacement affects any other.
"""
matches: List[PatternPair] = []
for regex, replacement in patterns:
for match in regex.finditer(input_bytes):
matches.append((match, replacement))
valid_matches: List[PatternPair] = _sort_drop_overlaps(matches, source_name=source_name)
valid_matches: List[PatternPair] = _sort_drop_overlaps(
matches, source_name=source_name, log=log
)
result: bytes = _apply_replacements(input_bytes, valid_matches)

global _tally
Expand Down Expand Up @@ -488,14 +510,17 @@ def all_case_variants(expr: str) -> List[str]:


def make_parent_dirs(path: str) -> str:
    """
    Ensure parent directories of a file are created as needed.

    Returns `path` unchanged so the call can be used inline.
    """
    dirname = os.path.dirname(path)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is a parent.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    return path


def move_file(source_path: str, dest_path: str, clobber: bool = False) -> None:
"""
Move a file, adding a numeric suffix if the destination already exists.
"""
if not clobber:
trailing_num = re.compile(r"(.*)[.]\d+$")
i = 1
Expand All @@ -514,6 +539,9 @@ def transform_stream(
stream_out: BinaryIO,
by_line: bool = False,
) -> _MatchCounts:
"""
Transform a stream of bytes, either line-by-line or at once in memory.
"""
counts = _MatchCounts()
if by_line:
for line in stream_in: # line will be bytes
Expand Down Expand Up @@ -543,8 +571,8 @@ def transform_file(
) -> _MatchCounts:
"""
Transform full contents of file at source_path with specified function,
either line-by-line or at once in memory, writing dest_path atomically and keeping a backup.
Source and destination may be the same path.
either line-by-line or at once in memory, writing dest_path atomically and
keeping a backup. Source and destination may be the same path.
"""
counts = _MatchCounts()
global _tally
Expand Down Expand Up @@ -590,36 +618,41 @@ def transform_file(


def rewrite_file(
    path: str,
    patterns: List[PatternType],
    do_renames: bool = False,
    do_contents: bool = False,
    by_line: bool = False,
    dry_run: bool = False,
    log: LogFunc = no_log,
) -> None:
    """
    Rewrite and/or rename the given file, making simultaneous changes according to the
    given list of patterns.

    When `do_renames` is set, the patterns are applied to the path itself to compute
    the destination path. When `do_contents` is set, they are applied to the file
    contents. Modifications and renames are reported through `log`.
    """
    # Patterns operate on bytes, so match against the encoded path and decode the
    # result back to str for filesystem operations.
    path_bytes = path.encode("utf-8")
    if do_renames:
        dest_path_bytes = multi_replace(path_bytes, patterns, is_path=True, log=log)[0]
    else:
        dest_path_bytes = path_bytes
    dest_path = dest_path_bytes.decode("utf-8")

    def replace_contents(contents):
        return multi_replace(contents, patterns, source_name=path, log=log)

    # No content transform is passed unless contents are to be rewritten.
    transform = replace_contents if do_contents else None
    counts = transform_file(transform, path, dest_path, by_line=by_line, dry_run=dry_run)
    if counts.found > 0:
        log("- modify: %s: %s matches" % (path, counts.found))
    if dest_path != path:
        log("- rename: %s -> %s" % (path, dest_path))


def walk_files(paths: List[str], exclude_pat: str = DEFAULT_EXCLUDE_PAT) -> List[str]:
out: List[str] = []
exclude_re = re.compile(exclude_pat)
for path in paths:
if not os.path.exists(path):
fail("path not found: %s" % path)
_fail("path not found: %s" % path, None)
if os.path.isfile(path):
out.append(path)
else:
Expand All @@ -643,10 +676,16 @@ def rewrite_files(
exclude_pat: str = DEFAULT_EXCLUDE_PAT,
by_line: bool = False,
dry_run: bool = False,
log: LogFunc = no_log,
) -> None:
"""
Walk the given `root_paths`, rewriting and/or renaming files making simultaneous
changes according to the given list of patterns. Set `log` if you wish to log info
in `dry_run` mode.
"""
paths = walk_files(root_paths, exclude_pat=exclude_pat)
paths.sort() # Ensure deterministic order of file processing.
log(None, "Found %s files in: %s" % (len(paths), ", ".join(root_paths)))
log("Found %s files in: %s" % (len(paths), ", ".join(root_paths)))
for path in paths:
rewrite_file(
path,
Expand All @@ -655,6 +694,7 @@ def rewrite_files(
do_contents=do_contents,
by_line=by_line,
dry_run=dry_run,
log=log,
)


Expand Down Expand Up @@ -698,7 +738,7 @@ def parse_patterns(
)
)
except Exception as e:
fail("error parsing pattern: %s: %s" % (e, bits))
_fail("error parsing pattern: %s: %s" % (e, bits), e)
return patterns


Expand Down Expand Up @@ -787,17 +827,28 @@ def main() -> None:
dest="dry_run",
action="store_true",
)
parser.add_argument(
"-q",
"--quiet",
help="quiet mode: suppress all output except errors",
dest="quiet",
action="store_true",
)
parser.add_argument("root_paths", nargs="*", help="root paths to process")

options = parser.parse_args()

global _fail
_fail = _fail_exit
log: LogFunc = print_stderr if not options.quiet else no_log

if options.dry_run:
log(None, "Dry run: No files will be changed")
log("Dry run: No files will be changed")

options.do_contents = not options.do_renames
options.do_renames = options.do_renames or options.do_full

# log(None, "Settings: %s" % options)
# log("Settings: %s" % options)

if options.pat_file:
if options.from_pat or options.to_pat:
Expand All @@ -824,7 +875,7 @@ def main() -> None:
)

if len(patterns) == 0:
fail("found no parse patterns")
_fail("found no parse patterns", None)

def format_flags(flags: int) -> str:
flags_str = "|".join([s for s in ["IGNORECASE", "DOTALL"] if flags & getattr(re, s)])
Expand All @@ -833,7 +884,6 @@ def format_flags(flags: int) -> str:
return flags_str

log(
None,
("Using %s patterns:\n " % len(patterns))
+ "\n ".join(
[
Expand All @@ -858,10 +908,10 @@ def format_flags(flags: int) -> str:
exclude_pat=options.exclude_pat,
by_line=by_line,
dry_run=options.dry_run,
log=log,
)

log(
None,
"Read %s files (%s chars), found %s matches (%s skipped due to overlaps)"
% (
_tally.files,
Expand All @@ -872,7 +922,6 @@ def format_flags(flags: int) -> str:
)
change_words = "Dry run: Would have changed" if options.dry_run else "Changed"
log(
None,
"%s %s files (%s rewritten and %s renamed)"
% (change_words, _tally.files_changed, _tally.files_rewritten, _tally.renames),
)
Expand All @@ -881,11 +930,10 @@ def format_flags(flags: int) -> str:
parser.error("can't specify --renames on stdin; give filename arguments")
if options.dry_run:
parser.error("can't specify --dry-run on stdin; give filename arguments")
transform = lambda contents: multi_replace(contents, patterns)
transform = lambda contents: multi_replace(contents, patterns, log=log)
transform_stream(transform, sys.stdin.buffer, sys.stdout.buffer, by_line=by_line)

log(
None,
"Read %s chars, made %s replacements (%s skipped due to overlaps)"
% (_tally.chars, _tally.valid_matches, _tally.matches - _tally.valid_matches),
)
Expand Down
2 changes: 1 addition & 1 deletion tests/tests-clean.log
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ run || expect_error
usage: repren.py [-h] [--version] [--from FROM_PAT] [--to TO_PAT]
[-p PAT_FILE] [--full] [--renames] [--literal] [-i]
[--dotall] [--preserve-case] [-b] [--exclude EXCLUDE_PAT]
[--at-once] [-t] [-n]
[--at-once] [-t] [-n] [-q]
[root_paths ...]
repren.py: error: must specify --patterns or both --from and --to
(got expected error: status 2)
Expand Down

0 comments on commit 6db93ce

Please sign in to comment.