[SPARK-42751][PS] Support str.findall with capture groups #56533
base: master
Changes from all commits
92a3413
d3fb798
4c84a24
ef9a01b
94626dd
2234e4b
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| String functions on pandas-on-Spark Series | ||
| """ | ||
|
|
||
| + import re | ||
| from functools import wraps | ||
| from typing import ( | ||
| Any, | ||
|
|
@@ -1116,6 +1117,12 @@ def findall(self, pat: str, flags: int = 0) -> "ps.Series": | |
| All non-overlapping matches of pattern or regular expression in | ||
| each string of this Series. | ||
|
|
||
| + Notes | ||
| + ----- | ||
| + For regular expressions with more than one capture group, pandas-on-Spark | ||
| + returns nested lists instead of pandas' tuple matches because Spark SQL | ||
| + does not have a tuple type. | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> s = ps.Series(['Lion', 'Monkey', 'Rabbit']) | ||
|
|
@@ -1174,14 +1181,23 @@ def findall(self, pat: str, flags: int = 0) -> "ps.Series": | |
| 2 [b, b] | ||
| dtype: object | ||
| """ | ||
| + num_groups = re.compile(pat, flags=flags).groups | ||
| str_dtype = is_str_dtype(self._data.dtype) | ||
| + if num_groups > 1: | ||
| + return_type = ArrayType(ArrayType(StringType(), containsNull=True), containsNull=True) | ||
| + else: | ||
| + return_type = ArrayType(StringType(), containsNull=True) | ||
|
|
||
| # type hint does not support to specify array type yet. | ||
| - @pandas_udf( # type: ignore[call-overload] | ||
| - returnType=ArrayType(StringType(), containsNull=True) | ||
| - ) | ||
| + @pandas_udf(returnType=return_type) # type: ignore[call-overload] | ||
| def pudf(s: pd.Series) -> pd.Series: | ||
| ret = s.str.findall(pat, flags) | ||
| + if num_groups > 1: | ||
| + ret = ret.map( | ||
| + lambda matches: [list(match) for match in matches] | ||
| + if isinstance(matches, list) | ||
| + else matches | ||
| + ) | ||
| if str_dtype: | ||
| # ArrayType does not support NaN, so replace with None | ||
| ret = ret.replace(np.nan, None) | ||
|
Comment on lines 1194 to 1203
Member
Seems a bit orthogonal to the current PR, but is this a real concern nonetheless? @Yicong-Huang @HyukjinKwon PTAL ^^
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -86,6 +86,20 @@ def test_string_findall(self): | |
| lambda x: x.str.findall("wh.*", flags=re.IGNORECASE), self.pser, ignore_null=True | ||
| ) | ||
|
|
||
| + pser = pd.Series(["abc-123 def-456", "no match", None]) | ||
| + pattern = "([a-z]+)-([0-9]+)" | ||
|
|
||
| + def normalize_matches(matches): # type: ignore[no-untyped-def] | ||
| + if isinstance(matches, (list, np.ndarray)): | ||
| + return [list(match) for match in matches] | ||
| + return matches | ||
|
|
||
| + expected = pser.str.findall(pattern).map(normalize_matches) | ||
| + actual = ps.from_pandas(pser).str.findall(pattern).to_pandas() | ||
| + self.assertIsNone(actual.iloc[-1]) | ||
| + actual = actual.map(normalize_matches) | ||
| + self.assert_eq(actual, expected) | ||
|
|
||
|
Member
Nit: the new test is hand-rolled. If possible, please try to align with the rest of the suite (e.g. use existing helpers like
||
| def test_string_index(self): | ||
| pser = pd.Series(["tea", "eat"]) | ||
| self.check_func_on_series(lambda x: x.str.index("ea"), pser) | ||
|
|
||
Should we update the Examples accordingly? Seems like a short doctest showing multi-group output as nested lists would complement the Notes section.
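
For reference, the tuple-to-list conversion that such a doctest would demonstrate can be sketched with the standard library alone. This is only an illustration, not the PR's code path: it assumes pandas' `str.findall` behaves like `re.findall` on each element, and the helper name `to_nested_lists` is hypothetical (it mirrors the lambda in `pudf`). The pattern and input string are borrowed from the new test.

```python
import re

# Pattern and sample string taken from the new test in this PR.
pattern = "([a-z]+)-([0-9]+)"

# Same group-count check the PR uses to pick the return type.
num_groups = re.compile(pattern).groups


def to_nested_lists(matches):
    """Hypothetical helper: turn a list of tuple matches into a list of lists.

    Non-list inputs (for example None for a missing value) pass through unchanged.
    """
    if isinstance(matches, list):
        return [list(match) for match in matches]
    return matches


# With more than one capture group, re.findall yields one tuple per match.
raw = re.findall(pattern, "abc-123 def-456")

# Spark SQL has no tuple type, so each tuple becomes an inner list.
nested = to_nested_lists(raw)

print(num_groups)  # 2
print(raw)         # [('abc', '123'), ('def', '456')]
print(nested)      # [['abc', '123'], ['def', '456']]
```

A string with no match yields an empty list, and a missing value stays `None`, which lines up with the `assertIsNone(actual.iloc[-1])` check in the test.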