forked from scottkleinman/rollingwindows
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhelpers.py
More file actions
145 lines (115 loc) · 3.34 KB
/
helpers.py
File metadata and controls
145 lines (115 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""helpers.py.
Last Update: June 9 2024
"""
import re
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Union
import spacy
from spacy.language import Language
@dataclass
class Windows:
    """Container bundling a collection of rolling windows with its settings.

    Iterating over an instance iterates directly over ``windows``.
    """

    # The windows themselves; any iterable is accepted.
    windows: Iterable
    # NOTE(review): presumably the unit the windows are measured in
    # (e.g. characters or tokens) -- confirm against the callers.
    window_units: str
    # NOTE(review): presumably the window size, counted in `window_units`
    # -- confirm against the callers.
    n: int
    # Defaults to "strict". NOTE(review): looks like a span alignment
    # setting -- the accepted values are not visible in this file.
    alignment_mode: str = "strict"

    def __iter__(self):
        """Return an iterator over the stored windows."""
        return iter(self.windows)
def ensure_doc(
    input: Union[str, List[str], spacy.tokens.doc.Doc],
    nlp: Union[Language, str],
    batch_size: int = 1000,
) -> spacy.tokens.doc.Doc:
    """Convert a string or a list of strings to a spaCy doc.

    Args:
        input (Union[str, List[str], spacy.tokens.doc.Doc]): A string, a list
            of tokens, or a spaCy doc. A doc is returned as is; a list is
            joined with single spaces and then tokenized like a string.
        nlp (Union[Language, str]): The language model to use, or the name of
            a model to load with ``spacy.load``. It is not used when ``input``
            is already a doc.
        batch_size (int): The number of texts to accumulate in an internal
            buffer; passed through to ``nlp.tokenizer.pipe``.

    Returns:
        spacy.tokens.doc.Doc: The input doc itself, or a doc produced by the
            model's tokenizer only (so unannotated if derived from a string
            or list of tokens).

    Raises:
        TypeError: If ``input`` is not a string, a list, or a spaCy doc.
    """
    if isinstance(input, spacy.tokens.doc.Doc):
        return input

    # Validate the input before touching the model so that bad input fails
    # fast instead of after a potentially slow `spacy.load`.
    if isinstance(input, str):
        text = input
    elif isinstance(input, list):
        text = " ".join(input)
    else:
        # TypeError is still caught by callers using `except Exception`.
        raise TypeError(
            "Invalid data type. Input data must be a string, a list of strings, or a spaCy doc."
        )

    if isinstance(nlp, str):
        nlp = spacy.load(nlp)
    return list(nlp.tokenizer.pipe([text], batch_size=batch_size))[0]
def ensure_list(input: Any) -> list:
    """Wrap a value in a list unless it already is one.

    Args:
        input (Any): Any value.

    Returns:
        list: ``input`` itself when it is already a list (no copy is made);
            otherwise a new one-element list containing ``input``.
    """
    return input if isinstance(input, list) else [input]
def flatten(input: Union[dict, list, str]) -> Iterable:
    """Lazily flatten nested lists into a single stream of items.

    Args:
        input (Union[dict, list, str]): The iterable to walk, typically a
            list of lists or dicts.

    Yields:
        The items of ``input``, handled one by one as follows:

        - strings and non-iterable values are yielded unchanged;
        - lists are flattened recursively;
        - dicts contribute only their first value (an empty dict raises
          ``IndexError``);
        - any other iterable (tuple, set, ...) yields nothing.

    Notes:
        See https://stackoverflow.com/a/40857703.
    """
    for item in input:
        if isinstance(item, str) or not isinstance(item, Iterable):
            # Leaves: strings are deliberately not split into characters.
            yield item
        elif isinstance(item, list):
            yield from flatten(item)
        elif isinstance(item, dict):
            yield list(item.values())[0]
        # Remaining iterable types fall through and produce no output.
def regex_escape(s: str) -> str:
    """Backslash-escape a fixed set of regex metacharacters.

    The characters ``] [ ( ) { } ? * + . ^ $`` are each prefixed with a
    backslash; every other character (including ``|`` and the backslash
    itself) is left untouched.

    Args:
        s (str): The text to escape. A ``bytes`` value is also accepted and
            is processed with bytes patterns.

    Returns:
        The escaped text, of the same type (``str`` or ``bytes``) as ``s``.

    Note:
        See https://stackoverflow.com/a/78136529/22853742.
    """
    # Pick pattern and prefix matching the input type; `re` refuses to mix
    # str patterns with bytes subjects and vice versa.
    if isinstance(s, bytes):
        specials, backslash = rb"[][(){}?*+.^$]", b"\\"
    else:
        specials, backslash = r"[][(){}?*+.^$]", "\\"
    return re.sub(specials, lambda match: backslash + match.group(), s)
def spacy_rule_to_lower(
    patterns: Union[Dict, List[Dict]],
    old_key: Union[List[str], str] = ["TEXT", "ORTH"],
    new_key: str = "LOWER",
) -> Union[Dict, List]:
    """Rename keys in spaCy Rule Matcher patterns (by default to "LOWER").

    Only dictionary keys are renamed; the values are copied over unchanged
    (they are not lowercased). If several keys of one dict are renamed, the
    last one wins.

    Args:
        patterns (Union[Dict, List[Dict]]): A single pattern dict or a
            (possibly nested) list of spaCy Rule Matcher pattern dicts.
        old_key (Union[List[str], str]): A dictionary key or list of keys to
            rename. A string is treated as exactly one key.
        new_key (str): The new key name.

    Returns:
        A new dict (for dict input) or a new list (for list input) with the
        keys renamed. Anything that is neither a dict nor a list is returned
        unchanged.
    """
    # Wrap a bare string so `key in old_keys` compares whole keys; testing
    # against the string itself would be a substring match ("TE" in "TEXT").
    # The default list is only read, never mutated.
    old_keys = (old_key,) if isinstance(old_key, str) else tuple(old_key)

    def rename(node):
        """Return a copy of `node` with matching dict keys renamed."""
        if isinstance(node, dict):
            return {
                (new_key if key in old_keys else key): value
                for key, value in node.items()
            }
        if isinstance(node, list):
            # Recursing through the closure keeps the caller's old/new keys
            # in effect at every nesting level.
            return [rename(item) for item in node]
        # Leaves that are neither dict nor list pass through untouched.
        return node

    return rename(patterns)