1+ import narwhals as nw
12import numpy as np
23import pandas as pd
4+ import polars as pl
35import pytest
46from sklearn .datasets import fetch_openml
57from sklearn .linear_model import LinearRegression
@@ -49,57 +51,56 @@ def test_alpha_param1():
4951 assert np .isclose (ifilter .fit_transform (X ), X_removed ).all ()
5052
5153
@pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame])
def test_alpha_param2(frame_func):
    """Check that ``alpha=0.0`` yields the input frame minus the filtered columns.

    The test runs once per dataframe constructor in the parametrization
    (pandas and polars), building the frame from a column-name -> values
    mapping so the same construction works for both backends.
    """
    column_names = [
        "crim",
        "zn",
        "indus",
        "chas",
        "nox",
        "rm",
        "age",
        "dis",
        "rad",
        "tax",
        "ptratio",
        "b",
        "lstat",
    ]
    # NOTE(review): OpenML dataset 531 is presumably the Boston housing data,
    # judging by the column names -- confirm against OpenML.
    features, _ = fetch_openml(data_id=531, return_X_y=True, as_frame=False, parser="liac-arff")

    # Transposing gives one array per column, paired with its name.
    df = frame_func(dict(zip(column_names, features.T)))

    ifilter = InformationFilter(columns=["b", "lstat"], alpha=0.0)

    # Expected result: the same data without the two filtered columns,
    # computed backend-agnostically through narwhals.
    expected = nw.from_native(df).drop(["b", "lstat"]).to_numpy()
    transformed = ifilter.fit_transform(df)

    assert np.isclose(transformed, expected).all()
7576
7677
@pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame])
def test_output_orthogonal_frame(frame_func):
    """Check that every filtered output column is orthogonal to "b" and "lstat".

    The test runs once per dataframe constructor in the parametrization
    (pandas and polars). Orthogonality is checked through the dot product of
    each output column with each filtered input column.
    """
    X, y = fetch_openml(data_id=531, return_X_y=True, as_frame=False, parser="liac-arff")
    cols = [
        "crim",
        "zn",
        "indus",
        "chas",
        "nox",
        "rm",
        "age",
        "dis",
        "rad",
        "tax",
        "ptratio",
        "b",
        "lstat",
    ]
    # Build from a name -> column mapping so both pandas and polars accept it.
    df = frame_func(dict(zip(cols, X.T)))
    X_fair = InformationFilter(columns=["b", "lstat"]).fit_transform(df)

    # Compare the magnitude of each dot product: a one-sided ``< 1e-5`` check
    # would also accept arbitrarily large negative values, which are just as
    # non-orthogonal as large positive ones.
    for sensitive in ("b", "lstat"):
        assert all(abs((c * df[sensitive]).sum()) < 1e-5 for c in X_fair.T)
100100
101101
102- def test_output_orthogonal_general_cols ():
102+ @pytest .mark .parametrize ("frame_func" , [pd .DataFrame , pl .DataFrame ])
103+ def test_output_orthogonal_general_cols (frame_func ):
103104 X , y = fetch_openml (data_id = 531 , return_X_y = True , as_frame = False , parser = "liac-arff" )
104105 cols = [
105106 "crim" ,
@@ -116,7 +117,7 @@ def test_output_orthogonal_general_cols():
116117 "b" ,
117118 "lstat" ,
118119 ]
119- df = pd . DataFrame ( X , columns = cols )
120+ df = frame_func ( dict ( zip ( cols , X . T )) )
120121 for col in cols :
121122 X_fair = InformationFilter (columns = col ).fit_transform (df )
122123 assert all ([(c * df [col ]).sum () < 1e-5 for c in X_fair .T ])
0 commit comments