Skip to content

Commit a0d94f1

Browse files
authored
Feat: Series.shift Pyarrow Backend Implementation (#590)
1 parent 463b328 commit a0d94f1

File tree

3 files changed

+41
-15
lines changed

3 files changed

+41
-15
lines changed

narwhals/_arrow/expr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,9 @@ def sum(self) -> Self:
203203
def drop_nulls(self) -> Self:
204204
return reuse_series_implementation(self, "drop_nulls")
205205

206+
def shift(self, n: int) -> Self:
    """Shift the expression's values by ``n`` positions.

    No logic lives here: the call is forwarded, together with ``n``, to
    the series-level ``shift`` through ``reuse_series_implementation``.
    """
    method_name = "shift"
    return reuse_series_implementation(self, method_name, n)
208+
206209
def alias(self, name: str) -> Self:
207210
# Define this one manually, so that we can
208211
# override `output_names` and not increase depth

narwhals/_arrow/series.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,18 @@ def drop_nulls(self) -> ArrowSeries:
221221
pc = get_pyarrow_compute()
222222
return self._from_native_series(pc.drop_null(self._native_series))
223223

224+
def shift(self, n: int) -> Self:
    """Shift the values by ``n`` positions, filling the gap with nulls.

    A positive ``n`` moves values towards the end (nulls are inserted at
    the start); a negative ``n`` moves them towards the start (nulls are
    appended at the end); ``n == 0`` returns the data unchanged.

    The result always has the same length as the input: when ``abs(n)``
    is greater than or equal to the series length, every value is
    shifted out and the result is entirely null.

    Args:
        n: Number of positions to shift by; may be negative or zero.

    Returns:
        A new series wrapping the shifted data.
    """
    pa = get_pyarrow()
    ca = self._native_series

    if n == 0:
        return self._from_native_series(ca)

    length = len(ca)
    # Never insert more nulls than there are rows. Without this clamp a
    # shift larger than the series (or any shift of an empty series)
    # would return more rows than it was given.
    fill = min(abs(n), length)
    nulls = pa.nulls(fill, ca.type)

    if n > 0:
        # Keep the leading rows and push them back behind the nulls.
        kept = ca[: length - fill]
        pieces = [nulls, *kept.chunks]
    else:
        # Keep the trailing rows and pad the end with nulls.
        kept = ca[fill:]
        pieces = [*kept.chunks, nulls]

    # NOTE(review): concat_arrays yields a single contiguous array while
    # the n == 0 path passes the native object through untouched --
    # presumably _from_native_series accepts both; confirm.
    result = pa.concat_arrays(pieces)
    return self._from_native_series(result)
235+
224236
def std(self, ddof: int = 1) -> int:
225237
pc = get_pyarrow_compute()
226238
return pc.stddev(self._native_series, ddof=ddof) # type: ignore[no-any-return]
Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any
22

3-
import pytest
3+
import pyarrow as pa
44

55
import narwhals.stable.v1 as nw
66
from tests.utils import compare_dicts
@@ -13,10 +13,7 @@
1313
}
1414

1515

16-
def test_shift(request: Any, constructor: Any) -> None:
17-
if "pyarrow_table" in str(constructor):
18-
request.applymarker(pytest.mark.xfail)
19-
16+
def test_shift(constructor: Any) -> None:
2017
df = nw.from_native(constructor(data))
2118
result = df.with_columns(nw.col("a", "b", "c").shift(2)).filter(nw.col("i") > 1)
2219
expected = {
@@ -28,21 +25,35 @@ def test_shift(request: Any, constructor: Any) -> None:
2825
compare_dicts(result, expected)
2926

3027

31-
def test_shift_series(request: Any, constructor_eager: Any) -> None:
32-
if "pyarrow_table" in str(constructor_eager):
33-
request.applymarker(pytest.mark.xfail)
34-
28+
def test_shift_series(constructor_eager: Any) -> None:
3529
df = nw.from_native(constructor_eager(data), eager_only=True)
30+
result = df.with_columns(
31+
df["a"].shift(2),
32+
df["b"].shift(2),
33+
df["c"].shift(2),
34+
).filter(nw.col("i") > 1)
3635
expected = {
3736
"i": [2, 3, 4],
3837
"a": [0, 1, 2],
3938
"b": [1, 2, 3],
4039
"c": [5, 4, 3],
4140
}
42-
result = df.select(
43-
df["i"],
44-
df["a"].shift(2),
45-
df["b"].shift(2),
46-
df["c"].shift(2),
47-
).filter(nw.col("i") > 1)
41+
compare_dicts(result, expected)
42+
43+
44+
def test_shift_multi_chunk_pyarrow() -> None:
    """Check ``shift`` on a pyarrow table built from concatenated tables.

    The input is three copies of a 3-row table joined with
    ``pa.concat_tables``, so shifted values must carry across the seams
    between the copies. Forward, backward and zero shifts are checked.
    """
    base = pa.table({"a": [1, 2, 3]})
    stacked = pa.concat_tables([base, base, base])
    df = nw.from_native(stacked, eager_only=True)

    # (shift amount, expected contents of column "a")
    cases = (
        (1, [None, 1, 2, 3, 1, 2, 3, 1, 2]),
        (-1, [2, 3, 1, 2, 3, 1, 2, 3, None]),
        (0, [1, 2, 3, 1, 2, 3, 1, 2, 3]),
    )
    for offset, shifted in cases:
        result = df.select(nw.col("a").shift(offset))
        compare_dicts(result, {"a": shifted})

0 commit comments

Comments
 (0)