I'm trying to perform string operations with NumPy's StringDType. As an example, I've attempted to join strings row-wise with a separator. In the past, NumPy's string operations were somewhat slower than Python list comprehensions, and I was hoping that the introduction of StringDType (which supports variable-length strings) would have improved them. However, I haven't been able to achieve a significant performance boost so far.
Are there better options for efficiently performing operations like string joining using NumPy's StringDType?
Here's a sample code where I test several methods that try to leverage vectorized operations:
import timeit
from functools import reduce
import numpy as np
import polars as pl
from numpy.dtypes import StringDType

def interleave_separator(arr, sep=', '):
    """Interleave a separator into a 2D array column-wise (costly helper)."""
    nrows, ncols = arr.shape
    interleaved = np.empty((nrows, 2 * ncols - 1), dtype=StringDType)
    interleaved[:, ::2] = arr
    interleaved[:, 1::2] = sep
    return interleaved

def strings_join_py(arr, sep=', '):
    """Python list comprehension."""
    return [sep.join(a) for a in arr]

def strings_join_pl(arr, sep=', '):
    """Polars join series of lists."""
    return arr.list.join(separator=sep)

def strings_join_np1(arr, sep=', '):
    """Numpy interleave separator and apply sum."""
    return np.sum(interleave_separator(arr, sep), axis=1)

def strings_join_np2(arr, sep=', '):
    """Numpy interleave separator and apply add.reduce."""
    return np.strings.add.reduce(interleave_separator(arr, sep), axis=1)

def strings_join_np3(arr, sep=', '):
    """Numpy/Python accumulate over columns row-wise."""
    return reduce(lambda x, y: x + sep + y, arr.T)

Check results:
np.random.seed(999)
choices = ["apple", "banana", "cherry", "salad"]
arr = np.random.choice(choices, size=(3, 3)).astype(StringDType)
sep = ", "
print('2D array:')
print(arr)
# [['apple' 'apple' 'banana']
# ['banana' 'apple' 'banana']
# ['salad' 'salad' 'banana']]
print('1D array joined by separator:')
print(strings_join_py(arr.tolist(), sep))
print(strings_join_pl(pl.Series(arr.tolist()), sep))
print(strings_join_np1(arr, sep))
print(strings_join_np2(arr, sep))
print(strings_join_np3(arr, sep))
# ['apple, apple, banana'
# 'banana, apple, banana'
# 'salad, salad, banana']
Run benchmarks:
np.random.seed(999)
choices = ["apple", "banana", "cherry", "salad"]
arr = np.random.choice(choices, size=(100_000, 10)).astype(StringDType)
lst = arr.tolist()
ser = pl.Series(lst)
sep = ", "
baseline = timeit.timeit(lambda: strings_join_py(lst, sep), number=5)
time_pl = timeit.timeit(lambda: strings_join_pl(ser, sep), number=5)
time_np1 = timeit.timeit(lambda: strings_join_np1(arr, sep), number=5)
time_np2 = timeit.timeit(lambda: strings_join_np2(arr, sep), number=5)
time_np3 = timeit.timeit(lambda: strings_join_np3(arr, sep), number=5)
print("Ratio compared to Python list comprehension (>1 is faster)")
print(f"pl: {baseline/time_pl:.2f}")
print(f"np1: {baseline/time_np1:.2f}")
print(f"np2: {baseline/time_np2:.2f}")
print(f"np3: {baseline/time_np3:.2f}")
# pl: 1.11 # Polars Series.list.join
# np1: 0.14 # interleaved np.sum
# np2: 0.14 # interleaved np.add.reduce
# np3: 0.19 # reduce np.add
Edit - Here’s a benchmark with an array shape of (500,000, 2):
# Ratio compared to Python list comprehension (>1 is faster)
# pl: 0.61
# np1: 0.31
# np2: 0.31
# np3: 1.57
The data type itself seems to perform well (see np3), but there do not yet seem to be enough specialized functions to make it practical for operations like this.
Edit: Observation
I've observed that NumPy's string ufunc (np.strings.add) is quite efficient if there aren’t many intermediate results to compute. As the number of accumulated columns increases, its efficiency declines compared to a Python list comprehension.
Here's a small benchmark showing the impact of intermediate results as the number of accumulated columns rises:
# Ratio compared to Python list comprehension (>1 is faster)
# Py_list_comp / np.strings.add: 0.77 - (shape (500000, 2))
# Py_list_comp / np.strings.add: 0.04 - (shape (1000, 1000))