
Commit 28adf54

Enhance GLM pipeline with additional visualization functions and model diagnostics
1 parent 9bad92a commit 28adf54

1 file changed

src/glm_pipeline.py

Lines changed: 151 additions & 1 deletion
@@ -1,6 +1,8 @@
 import pandas as pd
 import statsmodels.api as sm
 import matplotlib.pyplot as plt
+import numpy as np
+import seaborn as sns
 
 def load_data(filepath):
     df = pd.read_csv(filepath)
@@ -17,11 +19,159 @@ def fit_glm(X, y, family=sm.families.Binomial()):
     results = model.fit()
     return results
 
+def plot_residuals(results, X, y, save_path=None):
+    """Plot residual diagnostics for the GLM model."""
+    # Create a 2x2 subplot
+    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
+    fig.suptitle('GLM Model Diagnostics', fontsize=16)
+
+    # Get fitted values and residuals
+    fitted_values = results.fittedvalues
+    residuals = results.resid_pearson
+
+    # 1. Residuals vs Fitted
+    axes[0, 0].scatter(fitted_values, residuals, alpha=0.6)
+    axes[0, 0].axhline(y=0, color='red', linestyle='--')
+    axes[0, 0].set_xlabel('Fitted Values')
+    axes[0, 0].set_ylabel('Pearson Residuals')
+    axes[0, 0].set_title('Residuals vs Fitted')
+
+    # 2. Q-Q Plot
+    sm.qqplot(residuals, line='45', ax=axes[0, 1])
+    axes[0, 1].set_title('Q-Q Plot of Residuals')
+
+    # 3. Scale-Location Plot
+    sqrt_abs_residuals = np.sqrt(np.abs(residuals))
+    axes[1, 0].scatter(fitted_values, sqrt_abs_residuals, alpha=0.6)
+    axes[1, 0].set_xlabel('Fitted Values')
+    axes[1, 0].set_ylabel('√|Residuals|')
+    axes[1, 0].set_title('Scale-Location Plot')
+
+    # 4. Residuals vs Leverage
+    leverage = results.get_influence().hat_matrix_diag
+    axes[1, 1].scatter(leverage, residuals, alpha=0.6)
+    axes[1, 1].axhline(y=0, color='red', linestyle='--')
+    axes[1, 1].set_xlabel('Leverage')
+    axes[1, 1].set_ylabel('Pearson Residuals')
+    axes[1, 1].set_title('Residuals vs Leverage')
+
+    plt.tight_layout()
+
+    if save_path:
+        plt.savefig(save_path, dpi=300, bbox_inches='tight')
+    plt.show()
+
+def plot_feature_distributions(df, feature_cols, target_col, save_path=None):
+    """Plot distributions of features by target variable."""
+    n_features = len(feature_cols)
+    n_cols = min(3, n_features)
+    n_rows = (n_features + n_cols - 1) // n_cols
+
+    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows))
+    fig.suptitle(f'Feature Distributions by {target_col}', fontsize=16)
+
+    if n_features == 1:
+        axes = [axes]
+    elif n_rows == 1:
+        axes = axes.flatten() if n_features > 1 else [axes]
+    else:
+        axes = axes.flatten()
+
+    for i, feature in enumerate(feature_cols):
+        ax = axes[i]
+
+        # Create histograms for each target class
+        for target_val in df[target_col].unique():
+            subset = df[df[target_col] == target_val][feature]
+            ax.hist(subset, alpha=0.7, label=f'{target_col}={target_val}', bins=20)
+
+        ax.set_xlabel(feature)
+        ax.set_ylabel('Frequency')
+        ax.set_title(f'Distribution of {feature}')
+        ax.legend()
+        ax.grid(True, alpha=0.3)
+
+    # Hide empty subplots
+    for i in range(n_features, len(axes)):
+        axes[i].set_visible(False)
+
+    plt.tight_layout()
+
+    if save_path:
+        plt.savefig(save_path, dpi=300, bbox_inches='tight')
+    plt.show()
+
+def plot_correlation_matrix(df, feature_cols, save_path=None):
+    """Plot correlation matrix of features."""
+    corr_matrix = df[feature_cols].corr()
+
+    plt.figure(figsize=(10, 8))
+    sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', center=0,
+                square=True, linewidths=0.5)
+    plt.title('Feature Correlation Matrix')
+    plt.tight_layout()
+
+    if save_path:
+        plt.savefig(save_path, dpi=300, bbox_inches='tight')
+    plt.show()
+
+def plot_coefficients(results, feature_names, save_path=None):
+    """Plot model coefficients with confidence intervals."""
+    params = results.params
+    conf_int = results.conf_int()
+
+    fig, ax = plt.subplots(figsize=(10, 6))
+
+    y_pos = np.arange(len(feature_names))
+
+    # Plot coefficients
+    ax.barh(y_pos, params, alpha=0.7)
+
+    # Plot confidence intervals
+    for i, (param, (lower, upper)) in enumerate(zip(params, conf_int.values)):
+        ax.plot([lower, upper], [i, i], 'k-', linewidth=2)
+        ax.plot([lower, lower], [i-0.1, i+0.1], 'k-', linewidth=2)
+        ax.plot([upper, upper], [i-0.1, i+0.1], 'k-', linewidth=2)
+
+    ax.axvline(x=0, color='red', linestyle='--', alpha=0.7)
+    ax.set_yticks(y_pos)
+    ax.set_yticklabels(feature_names)
+    ax.set_xlabel('Coefficient Value')
+    ax.set_title('Model Coefficients with 95% Confidence Intervals')
+    ax.grid(True, alpha=0.3)
+
+    plt.tight_layout()
+
+    if save_path:
+        plt.savefig(save_path, dpi=300, bbox_inches='tight')
+    plt.show()
+
 def main():
+    # Load and preprocess data
     df = load_data("data/sample.csv")
-    X, y = preprocess_data(df, "default", ["age", "income", "balance"])
+    feature_cols = ["age", "income", "balance"]
+    target_col = "default"
+
+    X, y = preprocess_data(df, target_col, feature_cols)
+
+    # Create visualizations before fitting the model
+    print("Creating exploratory data visualizations...")
+    plot_feature_distributions(df, feature_cols, target_col)
+    plot_correlation_matrix(df, feature_cols)
+
+    # Fit the GLM model
     results = fit_glm(X, y)
     print(results.summary())
+
+    # Create model diagnostic plots
+    print("Creating model diagnostic plots...")
+    plot_residuals(results, X, y)
+
+    # Plot coefficients
+    feature_names = ['const'] + feature_cols
+    plot_coefficients(results, feature_names)
+
+    # Save the model
     results.save("models/glm_model.pickle")
 
 if __name__ == "__main__":