From bad76782dfd97a8218a0a7ac61f21e5ff35b9daa Mon Sep 17 00:00:00 2001 From: w-brock <77081497+w-brock@users.noreply.github.com> Date: Mon, 25 Oct 2021 14:18:09 -0400 Subject: [PATCH] Add tag to relative_x_is_y --- replacy/default_match_hooks.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/replacy/default_match_hooks.py b/replacy/default_match_hooks.py index f385c3a..2f7d05e 100644 --- a/replacy/default_match_hooks.py +++ b/replacy/default_match_hooks.py @@ -210,8 +210,8 @@ def relative_x_is_y(children_or_ancestors: str, pos_or_dep: str, value: Union[st if children_or_ancestors not in ["children", "ancestors"]: raise ValueError("children_or_ancestors must be set to either `children` or `ancestors`") - if pos_or_dep not in ["pos", "dep"]: - raise ValueError("pos_or_dep must be set to either `pos` or `dep`!") + if pos_or_dep not in ["pos", "dep", "tag"]: + raise ValueError("pos_or_dep must be set to either `pos`, `dep`, or `tag`!") def _in_children(doc, start, end): if end >= len(doc): @@ -222,6 +222,8 @@ def _in_children(doc, start, end): return any([child.pos_ == val for tok in match_span for child in tok.children]) elif pos_or_dep == "dep": return any([child.dep_ == val for tok in match_span for child in tok.children]) + elif pos_or_dep == "tag": + return any([child.tag_ == val for tok in match_span for child in tok.children]) def _in_ancestors(doc, start, end): if end >= len(doc): @@ -240,6 +242,12 @@ def _in_ancestors(doc, start, end): if ancestor and ancestor.dep_ == val: return True return False + if pos_or_dep == "tag": + for t in match_span: + ancestor = list(t.ancestors)[0] if len(list(t.ancestors)) else None + if ancestor and ancestor.tag_ == val: + return True + return False if children_or_ancestors == "children": return _in_children