diff --git a/test/ir/pir/cinn/symbolic/test_llama_mlp_dy.py b/test/ir/pir/cinn/symbolic/test_llama_mlp_dy.py
index a21cb66c3143ae..c765e1c7c20b73 100644
--- a/test/ir/pir/cinn/symbolic/test_llama_mlp_dy.py
+++ b/test/ir/pir/cinn/symbolic/test_llama_mlp_dy.py
@@ -34,13 +34,28 @@ def __init__(self):
         self.hidden_size = 768
         self.intermediate_size = 1008
         self.gate_proj = nn.Linear(
-            self.hidden_size, self.intermediate_size, bias_attr=False
+            self.hidden_size,
+            self.intermediate_size,
+            weight_attr=paddle.ParamAttr(
+                initializer=nn.initializer.Constant(value=0.5)
+            ),
+            bias_attr=False,
         )
         self.up_proj = nn.Linear(
-            self.hidden_size, self.intermediate_size, bias_attr=False
+            self.hidden_size,
+            self.intermediate_size,
+            weight_attr=paddle.ParamAttr(
+                initializer=nn.initializer.Constant(value=0.5)
+            ),
+            bias_attr=False,
         )
         self.down_proj = nn.Linear(
-            self.intermediate_size, self.hidden_size, bias_attr=False
+            self.intermediate_size,
+            self.hidden_size,
+            weight_attr=paddle.ParamAttr(
+                initializer=nn.initializer.Constant(value=0.5)
+            ),
+            bias_attr=False,
         )
 
     def forward(self, x):
diff --git a/test/ir/pir/cinn/symbolic/test_llama_mlp_st.py b/test/ir/pir/cinn/symbolic/test_llama_mlp_st.py
index f164fb7afc5d0a..74606daf05eae7 100644
--- a/test/ir/pir/cinn/symbolic/test_llama_mlp_st.py
+++ b/test/ir/pir/cinn/symbolic/test_llama_mlp_st.py
@@ -34,13 +34,28 @@ def __init__(self):
         self.hidden_size = 768
         self.intermediate_size = 1008
         self.gate_proj = nn.Linear(
-            self.hidden_size, self.intermediate_size, bias_attr=False
+            self.hidden_size,
+            self.intermediate_size,
+            weight_attr=paddle.ParamAttr(
+                initializer=nn.initializer.Constant(value=0.5)
+            ),
+            bias_attr=False,
         )
         self.up_proj = nn.Linear(
-            self.hidden_size, self.intermediate_size, bias_attr=False
+            self.hidden_size,
+            self.intermediate_size,
+            weight_attr=paddle.ParamAttr(
+                initializer=nn.initializer.Constant(value=0.5)
+            ),
+            bias_attr=False,
         )
         self.down_proj = nn.Linear(
-            self.intermediate_size, self.hidden_size, bias_attr=False
+            self.intermediate_size,
+            self.hidden_size,
+            weight_attr=paddle.ParamAttr(
+                initializer=nn.initializer.Constant(value=0.5)
+            ),
+            bias_attr=False,
         )
 
     def forward(self, x):
diff --git a/test/ir/pir/cinn/symbolic/test_multiple_subgraph_dy.py b/test/ir/pir/cinn/symbolic/test_multiple_subgraph_dy.py
index b2b993d522c3fa..b0de86534ffe08 100644
--- a/test/ir/pir/cinn/symbolic/test_multiple_subgraph_dy.py
+++ b/test/ir/pir/cinn/symbolic/test_multiple_subgraph_dy.py
@@ -31,7 +31,12 @@ class MultipleSubgraph(nn.Layer):
     def __init__(self):
         super().__init__()
         self.hidden_size = 768
-        self.mlp = nn.Linear(self.hidden_size, self.hidden_size)
+        self.weight_attr = paddle.ParamAttr(
+            initializer=nn.initializer.Constant(value=0.5)
+        )
+        self.mlp = nn.Linear(
+            self.hidden_size, self.hidden_size, weight_attr=self.weight_attr
+        )
 
     def exp_sub(self, x):
         y = paddle.exp(x)
diff --git a/test/ir/pir/cinn/symbolic/test_multiple_subgraph_st.py b/test/ir/pir/cinn/symbolic/test_multiple_subgraph_st.py
index df77f0f2811e36..0df0796b0df143 100644
--- a/test/ir/pir/cinn/symbolic/test_multiple_subgraph_st.py
+++ b/test/ir/pir/cinn/symbolic/test_multiple_subgraph_st.py
@@ -31,7 +31,12 @@ class MultipleSubgraph(nn.Layer):
     def __init__(self):
         super().__init__()
         self.hidden_size = 768
-        self.mlp = nn.Linear(self.hidden_size, self.hidden_size)
+        self.weight_attr = paddle.ParamAttr(
+            initializer=nn.initializer.Constant(value=0.5)
+        )
+        self.mlp = nn.Linear(
+            self.hidden_size, self.hidden_size, weight_attr=self.weight_attr
+        )
 
     def exp_sub(self, x):
         y = paddle.exp(x)