Skip to content

Commit 1070aea

Browse files
author
Miruna Oprescu
committed
Add performance tests and an example notebook
1 parent 762e0e4 commit 1070aea

File tree

4 files changed

+872
-2
lines changed

4 files changed

+872
-2
lines changed

econml/dml/dynamic_dml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ def _gen_ortho_learner_model_final(self, n_periods):
482482
return _LinearDynamicModelFinal(wrapped_final_model, n_periods=n_periods)
483483

484484
def _prefit(self, Y, T, *args, groups=None, only_final=False, **kwargs):
485-
u_periods = np.unique(np.bincount(groups.astype(int)))
485+
u_periods = np.unique(np.unique(groups, return_counts=True)[1])
486486
if len(u_periods) > 1:
487487
raise AttributeError(
488488
"Imbalanced panel. Method currently expects only panels with equal number of periods. Pad your data")

econml/tests/dgp.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
import numpy as np
4+
from econml.utilities import cross_product
5+
from statsmodels.tools.tools import add_constant
6+
7+
8+
class _BaseDynamicPanelDGP:
9+
10+
def __init__(self, n_periods, n_treatments, n_x):
11+
self.n_periods = n_periods
12+
self.n_treatments = n_treatments
13+
self.n_x = n_x
14+
return
15+
16+
def create_instance(self, *args, **kwargs):
17+
pass
18+
19+
def _gen_data_with_policy(self, n_units, policy_gen, random_seed=123):
20+
pass
21+
22+
def static_policy_data(self, n_units, tau, random_seed=123):
23+
def policy_gen(Tpre, X, period):
24+
return tau[period]
25+
return self._gen_data_with_policy(n_units, policy_gen, random_seed=random_seed)
26+
27+
def adaptive_policy_data(self, n_units, policy_gen, random_seed=123):
28+
return self._gen_data_with_policy(n_units, policy_gen, random_seed=random_seed)
29+
30+
def static_policy_effect(self, tau, mc_samples=1000):
31+
Y_tau, _, _, _ = self.static_policy_data(mc_samples, tau)
32+
Y_zero, _, _, _ = self.static_policy_data(
33+
mc_samples, np.zeros((self.n_periods, self.n_treatments)))
34+
return np.mean(Y_tau[np.arange(Y_tau.shape[0]) % self.n_periods == self.n_periods - 1]) - \
35+
np.mean(Y_zero[np.arange(Y_zero.shape[0]) %
36+
self.n_periods == self.n_periods - 1])
37+
38+
def adaptive_policy_effect(self, policy_gen, mc_samples=1000):
39+
Y_tau, _, _, _ = self.adaptive_policy_data(mc_samples, policy_gen)
40+
Y_zero, _, _, _ = self.static_policy_data(
41+
mc_samples, np.zeros((self.n_periods, self.n_treatments)))
42+
return np.mean(Y_tau[np.arange(Y_tau.shape[0]) % self.n_periods == self.n_periods - 1]) - \
43+
np.mean(Y_zero[np.arange(Y_zero.shape[0]) %
44+
self.n_periods == self.n_periods - 1])
45+
46+
47+
class DynamicPanelDGP(_BaseDynamicPanelDGP):
48+
49+
def __init__(self, n_periods, n_treatments, n_x):
50+
super().__init__(n_periods, n_treatments, n_x)
51+
52+
def create_instance(self, s_x, sigma_x=.5, sigma_y=.5, conf_str=5, hetero_strength=0, hetero_inds=None,
53+
autoreg=.5, state_effect=.5, random_seed=123):
54+
np.random.seed(random_seed)
55+
self.s_x = s_x
56+
self.conf_str = conf_str
57+
self.sigma_x = sigma_x
58+
self.sigma_y = sigma_y
59+
self.hetero_inds = hetero_inds.astype(
60+
int) if hetero_inds is not None else hetero_inds
61+
self.endo_inds = np.setdiff1d(
62+
np.arange(self.n_x), hetero_inds).astype(int)
63+
# The first s_x state variables are confounders. The final s_x variables are exogenous and can create
64+
# heterogeneity
65+
self.Alpha = np.random.uniform(-1, 1,
66+
size=(self.n_x, self.n_treatments))
67+
self.Alpha /= np.linalg.norm(self.Alpha, axis=1, ord=1, keepdims=True)
68+
self.Alpha *= state_effect
69+
if self.hetero_inds is not None:
70+
self.Alpha[self.hetero_inds] = 0
71+
72+
self.Beta = np.zeros((self.n_x, self.n_x))
73+
for t in range(self.n_x):
74+
self.Beta[t, :] = autoreg * np.roll(np.random.uniform(low=4.0**(-np.arange(
75+
0, self.n_x)), high=4.0**(-np.arange(1, self.n_x + 1))), t)
76+
if self.hetero_inds is not None:
77+
self.Beta[np.ix_(self.endo_inds, self.hetero_inds)] = 0
78+
self.Beta[np.ix_(self.hetero_inds, self.endo_inds)] = 0
79+
80+
self.epsilon = np.random.uniform(-1, 1, size=self.n_treatments)
81+
self.zeta = np.zeros(self.n_x)
82+
self.zeta[:self.s_x] = self.conf_str / self.s_x
83+
84+
self.y_hetero_effect = np.zeros(self.n_x)
85+
self.x_hetero_effect = np.zeros(self.n_x)
86+
if self.hetero_inds is not None:
87+
self.y_hetero_effect[self.hetero_inds] = np.random.uniform(.5 * hetero_strength, 1.5 * hetero_strength) / \
88+
len(self.hetero_inds)
89+
self.x_hetero_effect[self.hetero_inds] = np.random.uniform(.5 * hetero_strength, 1.5 * hetero_strength) / \
90+
len(self.hetero_inds)
91+
92+
self.true_effect = np.zeros((self.n_periods, self.n_treatments))
93+
self.true_effect[0] = self.epsilon
94+
for t in np.arange(1, self.n_periods):
95+
self.true_effect[t, :] = (self.zeta.reshape(
96+
1, -1) @ np.linalg.matrix_power(self.Beta, t - 1) @ self.Alpha)
97+
98+
self.true_hetero_effect = np.zeros(
99+
(self.n_periods, (self.n_x + 1) * self.n_treatments))
100+
self.true_hetero_effect[0, :] = cross_product(
101+
add_constant(self.y_hetero_effect.reshape(1, -1), has_constant='add'),
102+
self.epsilon.reshape(1, -1))
103+
for t in np.arange(1, self.n_periods):
104+
self.true_hetero_effect[t, :] = cross_product(
105+
add_constant(self.x_hetero_effect.reshape(1, -1), has_constant='add'),
106+
self.zeta.reshape(1, -1) @ np.linalg.matrix_power(self.Beta, t - 1) @ self.Alpha)
107+
return self
108+
109+
def hetero_effect_fn(self, t, x):
110+
if t == 0:
111+
return (np.dot(self.y_hetero_effect, x.flatten()) + 1) * self.epsilon
112+
else:
113+
return (np.dot(self.x_hetero_effect, x.flatten()) + 1) *\
114+
(self.zeta.reshape(1, -1) @ np.linalg.matrix_power(self.Beta, t - 1)
115+
@ self.Alpha).flatten()
116+
117+
def _gen_data_with_policy(self, n_units, policy_gen, random_seed=123):
118+
np.random.seed(random_seed)
119+
Y = np.zeros(n_units * self.n_periods)
120+
T = np.zeros((n_units * self.n_periods, self.n_treatments))
121+
X = np.zeros((n_units * self.n_periods, self.n_x))
122+
groups = np.zeros(n_units * self.n_periods)
123+
for t in range(n_units * self.n_periods):
124+
period = t % self.n_periods
125+
if period == 0:
126+
X[t] = np.random.normal(0, self.sigma_x, size=self.n_x)
127+
T[t] = policy_gen(np.zeros(self.n_treatments), X[t], period)
128+
else:
129+
X[t] = (np.dot(self.x_hetero_effect, X[t - 1]) + 1) * np.dot(self.Alpha, T[t - 1]) + \
130+
np.dot(self.Beta, X[t - 1]) + \
131+
np.random.normal(0, self.sigma_x, size=self.n_x)
132+
T[t] = policy_gen(T[t - 1], X[t], period)
133+
Y[t] = (np.dot(self.y_hetero_effect, X[t]) + 1) * np.dot(self.epsilon, T[t]) + \
134+
np.dot(X[t], self.zeta) + \
135+
np.random.normal(0, self.sigma_y)
136+
groups[t] = t // self.n_periods
137+
138+
return Y, T, X[:, self.hetero_inds] if self.hetero_inds else None, X[:, self.endo_inds], groups
139+
140+
def observational_data(self, n_units, gamma=0, s_t=1, sigma_t=0.5, random_seed=123):
141+
""" Generated observational data with some observational treatment policy parameters
142+
143+
Parameters
144+
----------
145+
n_units : how many units to observe
146+
gamma : what is the degree of auto-correlation of the treatments across periods
147+
s_t : sparsity of treatment policy; how many states does it depend on
148+
sigma_t : what is the std of the exploration/randomness in the treatment
149+
"""
150+
Delta = np.zeros((self.n_treatments, self.n_x))
151+
Delta[:, :s_t] = self.conf_str / s_t
152+
153+
def policy_gen(Tpre, X, period):
154+
return gamma * Tpre + (1 - gamma) * np.dot(Delta, X) + \
155+
np.random.normal(0, sigma_t, size=self.n_treatments)
156+
return self._gen_data_with_policy(n_units, policy_gen, random_seed=random_seed)

econml/tests/test_dynamic_dml.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from econml.inference import BootstrapInference, EmpiricalInferenceResults, NormalInferenceResults
1313
from econml.utilities import shape, hstack, vstack, reshape, cross_product
1414
import econml.tests.utilities # bugfix for assertWarns
15+
from econml.tests.dgp import DynamicPanelDGP
1516

1617

1718
class TestDynamicDML(unittest.TestCase):
@@ -79,7 +80,6 @@ def make_random(n, is_discrete, d):
7980
(d_y if d_y > 0 else 1) * (d_t_final if d_t_final > 0 else 1), 6)
8081

8182
all_infs = [None, 'auto', BootstrapInference(2)]
82-
#all_infs = [None, 'auto']
8383
est = DynamicDML(model_y=Lasso() if d_y < 1 else MultiTaskLasso(),
8484
model_t=LogisticRegression() if is_discrete else
8585
(Lasso() if d_t < 1 else MultiTaskLasso()),
@@ -256,3 +256,40 @@ def make_random(n, is_discrete, d):
256256
eff = est.effect(X) if not is_discrete else est.effect(
257257
X, T0='a', T1='b')
258258
self.assertEqual(shape(eff), effect_shape2)
259+
260+
def test_perf(self):
261+
np.random.seed(123)
262+
n_units = 400
263+
n_periods = 3
264+
n_treatments = 1
265+
n_x = 100
266+
s_x = 10
267+
s_t = 10
268+
hetero_strength = .5
269+
hetero_inds = np.arange(n_x - n_treatments, n_x)
270+
alpha_regs = [1e-4, 1e-3, 1e-2, 5e-2, .1, 1]
271+
272+
def lasso_model():
273+
return LassoCV(cv=3, alphas=alpha_regs, max_iter=500)
274+
# No heterogeneity
275+
dgp = DynamicPanelDGP(n_periods, n_treatments, n_x).create_instance(
276+
s_x, random_seed=1)
277+
Y, T, X, W, groups = dgp.observational_data(n_units, s_t=s_t, random_seed=12)
278+
est = DynamicDML(model_y=lasso_model(), model_t=lasso_model(), cv=3)
279+
est.fit(Y, T, X=X, W=W, groups=groups, inference="auto")
280+
np.testing.assert_allclose(est.intercept_, dgp.true_effect.flatten(), atol=1e-01)
281+
np.testing.assert_array_less(est.intercept__interval()[0], dgp.true_effect.flatten())
282+
np.testing.assert_array_less(dgp.true_effect.flatten(), est.intercept__interval()[1])
283+
# Heterogeneous effects
284+
hetero_strength = .5
285+
hetero_inds = np.arange(n_x - n_treatments, n_x)
286+
dgp = DynamicPanelDGP(n_periods, n_treatments, n_x).create_instance(
287+
s_x, hetero_strength=hetero_strength, hetero_inds=hetero_inds, random_seed=1)
288+
Y, T, X, W, groups = dgp.observational_data(n_units, s_t=s_t, random_seed=12)
289+
est.fit(Y, T, X=X, W=W, groups=groups, inference="auto")
290+
np.testing.assert_allclose(est.intercept_, dgp.true_effect.flatten(), atol=0.2)
291+
np.testing.assert_allclose(est.coef_, dgp.true_hetero_effect[:, hetero_inds + 1], atol=0.2)
292+
np.testing.assert_array_less(est.intercept__interval()[0], dgp.true_effect.flatten())
293+
np.testing.assert_array_less(dgp.true_effect.flatten(), est.intercept__interval()[1])
294+
np.testing.assert_array_less(est.coef__interval()[0], dgp.true_hetero_effect[:, hetero_inds + 1])
295+
np.testing.assert_array_less(dgp.true_hetero_effect[:, hetero_inds + 1], est.coef__interval()[1])

0 commit comments

Comments
 (0)