11import pandas as pd
22import statsmodels .api as sm
33import matplotlib .pyplot as plt
4+ import numpy as np
5+ import seaborn as sns
46
57def load_data (filepath ):
68 df = pd .read_csv (filepath )
@@ -17,11 +19,159 @@ def fit_glm(X, y, family=sm.families.Binomial()):
1719 results = model .fit ()
1820 return results
1921
22+ def plot_residuals (results , X , y , save_path = None ):
23+ """Plot residual diagnostics for the GLM model."""
24+ # Create a 2x2 subplot
25+ fig , axes = plt .subplots (2 , 2 , figsize = (12 , 10 ))
26+ fig .suptitle ('GLM Model Diagnostics' , fontsize = 16 )
27+
28+ # Get fitted values and residuals
29+ fitted_values = results .fittedvalues
30+ residuals = results .resid_pearson
31+
32+ # 1. Residuals vs Fitted
33+ axes [0 , 0 ].scatter (fitted_values , residuals , alpha = 0.6 )
34+ axes [0 , 0 ].axhline (y = 0 , color = 'red' , linestyle = '--' )
35+ axes [0 , 0 ].set_xlabel ('Fitted Values' )
36+ axes [0 , 0 ].set_ylabel ('Pearson Residuals' )
37+ axes [0 , 0 ].set_title ('Residuals vs Fitted' )
38+
39+ # 2. Q-Q Plot
40+ sm .qqplot (residuals , line = '45' , ax = axes [0 , 1 ])
41+ axes [0 , 1 ].set_title ('Q-Q Plot of Residuals' )
42+
43+ # 3. Scale-Location Plot
44+ sqrt_abs_residuals = np .sqrt (np .abs (residuals ))
45+ axes [1 , 0 ].scatter (fitted_values , sqrt_abs_residuals , alpha = 0.6 )
46+ axes [1 , 0 ].set_xlabel ('Fitted Values' )
47+ axes [1 , 0 ].set_ylabel ('√|Residuals|' )
48+ axes [1 , 0 ].set_title ('Scale-Location Plot' )
49+
50+ # 4. Residuals vs Leverage
51+ leverage = results .get_influence ().hat_matrix_diag
52+ axes [1 , 1 ].scatter (leverage , residuals , alpha = 0.6 )
53+ axes [1 , 1 ].axhline (y = 0 , color = 'red' , linestyle = '--' )
54+ axes [1 , 1 ].set_xlabel ('Leverage' )
55+ axes [1 , 1 ].set_ylabel ('Pearson Residuals' )
56+ axes [1 , 1 ].set_title ('Residuals vs Leverage' )
57+
58+ plt .tight_layout ()
59+
60+ if save_path :
61+ plt .savefig (save_path , dpi = 300 , bbox_inches = 'tight' )
62+ plt .show ()
63+
64+ def plot_feature_distributions (df , feature_cols , target_col , save_path = None ):
65+ """Plot distributions of features by target variable."""
66+ n_features = len (feature_cols )
67+ n_cols = min (3 , n_features )
68+ n_rows = (n_features + n_cols - 1 ) // n_cols
69+
70+ fig , axes = plt .subplots (n_rows , n_cols , figsize = (5 * n_cols , 4 * n_rows ))
71+ fig .suptitle (f'Feature Distributions by { target_col } ' , fontsize = 16 )
72+
73+ if n_features == 1 :
74+ axes = [axes ]
75+ elif n_rows == 1 :
76+ axes = axes .flatten () if n_features > 1 else [axes ]
77+ else :
78+ axes = axes .flatten ()
79+
80+ for i , feature in enumerate (feature_cols ):
81+ ax = axes [i ]
82+
83+ # Create histograms for each target class
84+ for target_val in df [target_col ].unique ():
85+ subset = df [df [target_col ] == target_val ][feature ]
86+ ax .hist (subset , alpha = 0.7 , label = f'{ target_col } ={ target_val } ' , bins = 20 )
87+
88+ ax .set_xlabel (feature )
89+ ax .set_ylabel ('Frequency' )
90+ ax .set_title (f'Distribution of { feature } ' )
91+ ax .legend ()
92+ ax .grid (True , alpha = 0.3 )
93+
94+ # Hide empty subplots
95+ for i in range (n_features , len (axes )):
96+ axes [i ].set_visible (False )
97+
98+ plt .tight_layout ()
99+
100+ if save_path :
101+ plt .savefig (save_path , dpi = 300 , bbox_inches = 'tight' )
102+ plt .show ()
103+
104+ def plot_correlation_matrix (df , feature_cols , save_path = None ):
105+ """Plot correlation matrix of features."""
106+ corr_matrix = df [feature_cols ].corr ()
107+
108+ plt .figure (figsize = (10 , 8 ))
109+ sns .heatmap (corr_matrix , annot = True , cmap = 'coolwarm' , center = 0 ,
110+ square = True , linewidths = 0.5 )
111+ plt .title ('Feature Correlation Matrix' )
112+ plt .tight_layout ()
113+
114+ if save_path :
115+ plt .savefig (save_path , dpi = 300 , bbox_inches = 'tight' )
116+ plt .show ()
117+
118+ def plot_coefficients (results , feature_names , save_path = None ):
119+ """Plot model coefficients with confidence intervals."""
120+ params = results .params
121+ conf_int = results .conf_int ()
122+
123+ fig , ax = plt .subplots (figsize = (10 , 6 ))
124+
125+ y_pos = np .arange (len (feature_names ))
126+
127+ # Plot coefficients
128+ ax .barh (y_pos , params , alpha = 0.7 )
129+
130+ # Plot confidence intervals
131+ for i , (param , (lower , upper )) in enumerate (zip (params , conf_int .values )):
132+ ax .plot ([lower , upper ], [i , i ], 'k-' , linewidth = 2 )
133+ ax .plot ([lower , lower ], [i - 0.1 , i + 0.1 ], 'k-' , linewidth = 2 )
134+ ax .plot ([upper , upper ], [i - 0.1 , i + 0.1 ], 'k-' , linewidth = 2 )
135+
136+ ax .axvline (x = 0 , color = 'red' , linestyle = '--' , alpha = 0.7 )
137+ ax .set_yticks (y_pos )
138+ ax .set_yticklabels (feature_names )
139+ ax .set_xlabel ('Coefficient Value' )
140+ ax .set_title ('Model Coefficients with 95% Confidence Intervals' )
141+ ax .grid (True , alpha = 0.3 )
142+
143+ plt .tight_layout ()
144+
145+ if save_path :
146+ plt .savefig (save_path , dpi = 300 , bbox_inches = 'tight' )
147+ plt .show ()
148+
20149def main ():
150+ # Load and preprocess data
21151 df = load_data ("data/sample.csv" )
22- X , y = preprocess_data (df , "default" , ["age" , "income" , "balance" ])
152+ feature_cols = ["age" , "income" , "balance" ]
153+ target_col = "default"
154+
155+ X , y = preprocess_data (df , target_col , feature_cols )
156+
157+ # Create visualizations before fitting the model
158+ print ("Creating exploratory data visualizations..." )
159+ plot_feature_distributions (df , feature_cols , target_col )
160+ plot_correlation_matrix (df , feature_cols )
161+
162+ # Fit the GLM model
23163 results = fit_glm (X , y )
24164 print (results .summary ())
165+
166+ # Create model diagnostic plots
167+ print ("Creating model diagnostic plots..." )
168+ plot_residuals (results , X , y )
169+
170+ # Plot coefficients
171+ feature_names = ['const' ] + feature_cols
172+ plot_coefficients (results , feature_names )
173+
174+ # Save the model
25175 results .save ("models/glm_model.pickle" )
26176
27177if __name__ == "__main__" :
0 commit comments