Skip to content

Commit e4bfa4a

Browse files
BiynXudiadestiny
authored andcommitted
[CINN] fix multi subgraph unittest (PaddlePaddle#61909)
1 parent eee8919 commit e4bfa4a

4 files changed

Lines changed: 48 additions & 8 deletions

File tree

test/ir/pir/cinn/symbolic/test_llama_mlp_dy.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,28 @@ def __init__(self):
3434
self.hidden_size = 768
3535
self.intermediate_size = 1008
3636
self.gate_proj = nn.Linear(
37-
self.hidden_size, self.intermediate_size, bias_attr=False
37+
self.hidden_size,
38+
self.intermediate_size,
39+
weight_attr=paddle.ParamAttr(
40+
initializer=nn.initializer.Constant(value=0.5)
41+
),
42+
bias_attr=False,
3843
)
3944
self.up_proj = nn.Linear(
40-
self.hidden_size, self.intermediate_size, bias_attr=False
45+
self.hidden_size,
46+
self.intermediate_size,
47+
weight_attr=paddle.ParamAttr(
48+
initializer=nn.initializer.Constant(value=0.5)
49+
),
50+
bias_attr=False,
4151
)
4252
self.down_proj = nn.Linear(
43-
self.intermediate_size, self.hidden_size, bias_attr=False
53+
self.intermediate_size,
54+
self.hidden_size,
55+
weight_attr=paddle.ParamAttr(
56+
initializer=nn.initializer.Constant(value=0.5)
57+
),
58+
bias_attr=False,
4459
)
4560

4661
def forward(self, x):

test/ir/pir/cinn/symbolic/test_llama_mlp_st.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,28 @@ def __init__(self):
3434
self.hidden_size = 768
3535
self.intermediate_size = 1008
3636
self.gate_proj = nn.Linear(
37-
self.hidden_size, self.intermediate_size, bias_attr=False
37+
self.hidden_size,
38+
self.intermediate_size,
39+
weight_attr=paddle.ParamAttr(
40+
initializer=nn.initializer.Constant(value=0.5)
41+
),
42+
bias_attr=False,
3843
)
3944
self.up_proj = nn.Linear(
40-
self.hidden_size, self.intermediate_size, bias_attr=False
45+
self.hidden_size,
46+
self.intermediate_size,
47+
weight_attr=paddle.ParamAttr(
48+
initializer=nn.initializer.Constant(value=0.5)
49+
),
50+
bias_attr=False,
4151
)
4252
self.down_proj = nn.Linear(
43-
self.intermediate_size, self.hidden_size, bias_attr=False
53+
self.intermediate_size,
54+
self.hidden_size,
55+
weight_attr=paddle.ParamAttr(
56+
initializer=nn.initializer.Constant(value=0.5)
57+
),
58+
bias_attr=False,
4459
)
4560

4661
def forward(self, x):

test/ir/pir/cinn/symbolic/test_multiple_subgraph_dy.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ class MultipleSubgraph(nn.Layer):
3131
def __init__(self):
3232
super().__init__()
3333
self.hidden_size = 768
34-
self.mlp = nn.Linear(self.hidden_size, self.hidden_size)
34+
self.weight_attr = paddle.ParamAttr(
35+
initializer=nn.initializer.Constant(value=0.5)
36+
)
37+
self.mlp = nn.Linear(
38+
self.hidden_size, self.hidden_size, weight_attr=self.weight_attr
39+
)
3540

3641
def exp_sub(self, x):
3742
y = paddle.exp(x)

test/ir/pir/cinn/symbolic/test_multiple_subgraph_st.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ class MultipleSubgraph(nn.Layer):
3131
def __init__(self):
3232
super().__init__()
3333
self.hidden_size = 768
34-
self.mlp = nn.Linear(self.hidden_size, self.hidden_size)
34+
self.weight_attr = paddle.ParamAttr(
35+
initializer=nn.initializer.Constant(value=0.5)
36+
)
37+
self.mlp = nn.Linear(
38+
self.hidden_size, self.hidden_size, weight_attr=self.weight_attr
39+
)
3540

3641
def exp_sub(self, x):
3742
y = paddle.exp(x)

0 commit comments

Comments
 (0)