Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions python/pyspark/pandas/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
String functions on pandas-on-Spark Series
"""

import re
from functools import wraps
from typing import (
Any,
Expand Down Expand Up @@ -1116,6 +1117,12 @@ def findall(self, pat: str, flags: int = 0) -> "ps.Series":
All non-overlapping matches of pattern or regular expression in
each string of this Series.

Notes
-----
For regular expressions with more than one capture group, pandas-on-Spark
returns nested lists instead of pandas' tuple matches because Spark SQL
does not have a tuple type.

Examples

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we update the Examples accordingly? Seems like a short doctest showing multi-group output as nested lists would complement the Notes section.

--------
>>> s = ps.Series(['Lion', 'Monkey', 'Rabbit'])
Expand Down Expand Up @@ -1174,14 +1181,23 @@ def findall(self, pat: str, flags: int = 0) -> "ps.Series":
2 [b, b]
dtype: object
"""
num_groups = re.compile(pat, flags=flags).groups
str_dtype = is_str_dtype(self._data.dtype)
if num_groups > 1:
return_type = ArrayType(ArrayType(StringType(), containsNull=True), containsNull=True)
else:
return_type = ArrayType(StringType(), containsNull=True)

# type hint does not support to specify array type yet.
@pandas_udf( # type: ignore[call-overload]
returnType=ArrayType(StringType(), containsNull=True)
)
@pandas_udf(returnType=return_type) # type: ignore[call-overload]
def pudf(s: pd.Series) -> pd.Series:
ret = s.str.findall(pat, flags)
if num_groups > 1:
ret = ret.map(
lambda matches: [list(match) for match in matches]
if isinstance(matches, list)
else matches
)
if str_dtype:
# ArrayType does not support NaN, so replace with None
ret = ret.replace(np.nan, None)
Comment on lines 1194 to 1203

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems a bit orthogonal to the current PR, but is this a real concern nonetheless? @Yicong-Huang @HyukjinKwon PTAL ^^

Expand Down
14 changes: 14 additions & 0 deletions python/pyspark/pandas/tests/series/test_string_ops_adv.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,20 @@ def test_string_findall(self):
lambda x: x.str.findall("wh.*", flags=re.IGNORECASE), self.pser, ignore_null=True
)

pser = pd.Series(["abc-123 def-456", "no match", None])
pattern = "([a-z]+)-([0-9]+)"

def normalize_matches(matches): # type: ignore[no-untyped-def]
if isinstance(matches, (list, np.ndarray)):
return [list(match) for match in matches]
return matches

expected = pser.str.findall(pattern).map(normalize_matches)
actual = ps.from_pandas(pser).str.findall(pattern).to_pandas()
self.assertIsNone(actual.iloc[-1])
actual = actual.map(normalize_matches)
self.assert_eq(actual, expected)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: the new test is hand-rolled, if possible - please try to align with the rest of the suite (e.g. use existing helpers like check_func_on_series).

def test_string_index(self):
pser = pd.Series(["tea", "eat"])
self.check_func_on_series(lambda x: x.str.index("ea"), pser)
Expand Down