Add a `select_columns` method to the Python DataFrame binding, plus a test.

* src/dataframe.rs: `PyDataFrame::select_columns` takes a variadic list of
  column names (`#[args(args = "*")]`), forwards them to
  `self.df.select_columns(&args)`, propagates any error with `?`, and wraps
  the resulting frame with `Self::new`.
* datafusion/tests/test_dataframe.py: the new test selects columns "b" and
  "a" (in that order) from the `df` fixture and checks that the first
  collected batch holds [4, 5, 6] followed by [1, 2, 3].

NOTE(review): the test is named `test_select_colums`; this looks like a typo
for `test_select_columns`. The name is left unchanged here.
NOTE(review): the trailing context line for `select` reads
`args: Vec) -> PyResult {`; its generic arguments appear to have been lost
in extraction. Confirm it against src/dataframe.rs before applying.

diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index e6b9ef1..99ff85a 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -61,6 +61,16 @@ def test_select(df):
     assert result.column(1) == pa.array([-3, -3, -3])
 
 
+def test_select_colums(df):
+    df = df.select_columns("b", "a")
+
+    # execute and collect the first (and only) batch
+    result = df.collect()[0]
+
+    assert result.column(0) == pa.array([4, 5, 6])
+    assert result.column(1) == pa.array([1, 2, 3])
+
+
 def test_filter(df):
     df = df.select(
         column("a") + column("b"),
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 7c21102..964f042 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -51,6 +51,12 @@ impl PyDataFrame {
         self.df.schema().into()
     }
 
+    #[args(args = "*")]
+    fn select_columns(&self, args: Vec<&str>) -> PyResult<Self> {
+        let df = self.df.select_columns(&args)?;
+        Ok(Self::new(df))
+    }
+
     #[args(args = "*")]
     fn select(&self, args: Vec) -> PyResult {
         let expr = args.into_iter().map(|e| e.into()).collect();