Skip to content
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
8696c3d
add nn.init.kaiming_uniform_
zhiminzhang0830 Aug 7, 2025
87d9313
update kaiming_uniform_
zhiminzhang0830 Aug 11, 2025
5b8a1a5
update unit test for kaiming_uniform_
zhiminzhang0830 Aug 11, 2025
3eaa19b
add nn.init.kaiming_uniform_
zhiminzhang0830 Aug 7, 2025
b8e6b9c
update kaiming_uniform_
zhiminzhang0830 Aug 11, 2025
e791c18
update unit test for kaiming_uniform_
zhiminzhang0830 Aug 11, 2025
e2b2cc9
add xavier_uniform_, kaiming_normal_, uniform_
zhiminzhang0830 Aug 11, 2025
d2d614a
add unit test for xavier_uniform_, kaiming_normal_, uniform_
zhiminzhang0830 Aug 11, 2025
50cfb5c
add xavier_normal_ and its unit test
zhiminzhang0830 Aug 11, 2025
2c08d23
add normal_ and its unit test
zhiminzhang0830 Aug 11, 2025
c4c6917
fix: remove 'block' parameter from init.*() function
zhiminzhang0830 Aug 11, 2025
d31e3d3
fix
zhiminzhang0830 Aug 11, 2025
b5ccf0a
add nn.init.trunc_normal_ and its unit test
zhiminzhang0830 Aug 11, 2025
31cdc8b
add nn.init.constant_, nn.init.ones_, nn.init.zeros_
zhiminzhang0830 Aug 11, 2025
44d9d26
support paddle.pir.Value type
zhiminzhang0830 Aug 12, 2025
5afa04c
add dirac_, eye_, orthogonal_
zhiminzhang0830 Aug 12, 2025
ecb4da0
update unit test for nn.init.*
zhiminzhang0830 Aug 12, 2025
0c8bfd1
update init
zhiminzhang0830 Aug 12, 2025
4d4334f
add paddle.pir.Value
zhiminzhang0830 Aug 12, 2025
08a85c6
update unit test for nn.init.orthogonal_
zhiminzhang0830 Aug 12, 2025
1d4550e
Merge remote-tracking branch 'upstream/develop' into init
zhiminzhang0830 Aug 13, 2025
d330635
fix unit test for nn.init.eye_
zhiminzhang0830 Aug 14, 2025
02555a7
fix: skip unit test on dcu
zhiminzhang0830 Aug 14, 2025
d10e3bc
Merge remote-tracking branch 'upstream/develop' into init
zhiminzhang0830 Aug 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/paddle/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import functional, initializer, quant, utils # noqa: F401

from . import functional, init, initializer, quant, utils # noqa: F401
from .clip import ClipGradByGlobalNorm, ClipGradByNorm, ClipGradByValue
from .decode import BeamSearchDecoder, dynamic_decode

Expand Down
318 changes: 318 additions & 0 deletions python/paddle/nn/init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import paddle

from ..base.framework import in_dygraph_mode, in_pir_mode
from .initializer.constant import Constant
from .initializer.dirac import Dirac
from .initializer.initializer import calculate_gain # noqa: F401
from .initializer.kaiming import KaimingNormal, KaimingUniform
from .initializer.normal import Normal, TruncatedNormal
from .initializer.orthogonal import Orthogonal
from .initializer.uniform import Uniform
from .initializer.xavier import XavierNormal, XavierUniform


def kaiming_uniform_(
    tensor: paddle.Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
) -> paddle.Tensor | None:
    """Fill ``tensor`` in place using the Kaiming uniform method.

    Args:
        tensor (Tensor): Paddle Tensor to initialize.
        a (float, optional): Negative slope of the rectifier used after this
            layer. Default: 0.
        mode (str, optional): How to compute the fan; one of ``"fan_in"`` or
            ``"fan_out"``. ``"fan_in"`` scales by the input-unit count,
            ``"fan_out"`` by the output-unit count. Default: ``"fan_in"``.
        nonlinearity (str, optional): Name of the following nonlinearity.
            Default: ``"leaky_relu"``.

    Returns:
        Tensor: The initialized tensor.
    """
    # Delegate to the class-based initializer, which mutates in place.
    initializer = KaimingUniform(
        negative_slope=a, nonlinearity=nonlinearity, mode=mode
    )
    return initializer(tensor)


def kaiming_normal_(
    tensor: paddle.Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
) -> paddle.Tensor | None:
    """Fill ``tensor`` in place using the Kaiming normal method.

    Args:
        tensor (Tensor): Paddle Tensor to initialize.
        a (float, optional): Negative slope of the rectifier used after this
            layer. Default: 0.
        mode (str, optional): How to compute the fan; one of ``"fan_in"`` or
            ``"fan_out"``. ``"fan_in"`` scales by the input-unit count,
            ``"fan_out"`` by the output-unit count. Default: ``"fan_in"``.
        nonlinearity (str, optional): Name of the following nonlinearity.
            Default: ``"leaky_relu"``.

    Returns:
        Tensor: The initialized tensor.
    """
    # Delegate to the class-based initializer, which mutates in place.
    return KaimingNormal(
        negative_slope=a, nonlinearity=nonlinearity, mode=mode
    )(tensor)


def xavier_uniform_(
    tensor: paddle.Tensor,
    gain: float = 1.0,
    fan_in: float | None = None,
    fan_out: float | None = None,
) -> paddle.Tensor | None:
    """Fill ``tensor`` in place using the Xavier uniform method.

    Args:
        tensor (Tensor): Paddle Tensor to initialize.
        gain (float, optional): Scaling factor applied to the distribution
            bounds. Default: 1.0.
        fan_in (float|None, optional): Override for the fan-in; inferred from
            the Tensor when ``None``. Default: ``None``.
        fan_out (float|None, optional): Override for the fan-out; inferred
            from the Tensor when ``None``. Default: ``None``.

    Returns:
        Tensor: The initialized tensor.
    """
    # Delegate to the class-based initializer, which mutates in place.
    initializer = XavierUniform(gain=gain, fan_in=fan_in, fan_out=fan_out)
    return initializer(tensor)


def xavier_normal_(
    tensor: paddle.Tensor,
    gain: float = 1.0,
    fan_in: float | None = None,
    fan_out: float | None = None,
) -> paddle.Tensor | None:
    """Fill ``tensor`` in place using the Xavier normal method.

    Args:
        tensor (Tensor): Paddle Tensor to initialize.
        gain (float, optional): Scaling factor applied to the standard
            deviation. Default: 1.0.
        fan_in (float|None, optional): Override for the fan-in; inferred from
            the Tensor when ``None``. Default: ``None``.
        fan_out (float|None, optional): Override for the fan-out; inferred
            from the Tensor when ``None``. Default: ``None``.

    Returns:
        Tensor: The initialized tensor.
    """
    # Delegate to the class-based initializer, which mutates in place.
    return XavierNormal(gain=gain, fan_in=fan_in, fan_out=fan_out)(tensor)


def uniform_(
    tensor: paddle.Tensor,
    a: float = 0.0,
    b: float = 1.0,
) -> paddle.Tensor | None:
    """Modify tensor inplace using uniform method.

    Args:
        tensor (Tensor): Paddle Tensor.
        a (float, optional): Lower boundary of the uniform distribution. Default is 0.0.
        b (float, optional): Upper boundary of the uniform distribution. Default is 1.0.

    Returns:
        Tensor: Initialized tensor.
    """
    # Map the torch-style (a, b) bounds onto the Uniform initializer's
    # (low, high) parameters.
    init = Uniform(low=a, high=b)

    return init(tensor)


def normal_(
    tensor: paddle.Tensor,
    mean: float = 0.0,
    std: float = 1.0,
) -> paddle.Tensor | None:
    """Fill ``tensor`` in place with values drawn from a normal distribution.

    Args:
        tensor (Tensor): Paddle Tensor to initialize.
        mean (float|complex, optional): Mean of the normal distribution.
            Default: 0.0.
        std (float, optional): Standard deviation of the normal distribution.
            Default: 1.0.

    Returns:
        Tensor: The initialized tensor.
    """
    # Delegate to the class-based initializer, which mutates in place.
    return Normal(mean=mean, std=std)(tensor)


def trunc_normal_(
    tensor: paddle.Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
) -> paddle.Tensor | None:
    """Fill ``tensor`` in place with values from a truncated normal distribution.

    Args:
        tensor (Tensor): Paddle Tensor to initialize.
        mean (float|complex, optional): Mean of the normal distribution.
            Default: 0.0.
        std (float, optional): Standard deviation of the normal distribution.
            Default: 1.0.
        a (float, optional): Minimum cutoff value. Default: -2.0.
        b (float, optional): Maximum cutoff value. Default: 2.0.

    Returns:
        Tensor: The initialized tensor.
    """
    # Delegate to the class-based initializer, which mutates in place.
    initializer = TruncatedNormal(mean=mean, std=std, a=a, b=b)
    return initializer(tensor)


def constant_(
    tensor: paddle.Tensor,
    val: float,
) -> paddle.Tensor | None:
    """Modify tensor inplace using constant method.

    Args:
        tensor (Tensor): Paddle Tensor.
        val (float): Constant value used to fill the tensor.

    Returns:
        Tensor: Initialized tensor.
    """
    init = Constant(value=val)

    return init(tensor)


def ones_(
    tensor: paddle.Tensor,
) -> paddle.Tensor | None:
    """Fill the input Tensor with the scalar value 1.

    Args:
        tensor (Tensor): Paddle Tensor to initialize.

    Returns:
        Tensor: The initialized tensor.
    """
    # Reuse the constant initializer with a fixed fill value of one.
    return Constant(value=1.0)(tensor)


def zeros_(
    tensor: paddle.Tensor,
) -> paddle.Tensor | None:
    """Fill the input Tensor with the scalar value 0.

    Args:
        tensor (Tensor): Paddle Tensor to initialize.

    Returns:
        Tensor: The initialized tensor.
    """
    # Reuse the constant initializer with a fixed fill value of zero.
    return Constant(value=0.0)(tensor)


def dirac_(
    tensor: paddle.Tensor,
    groups: int = 1,
) -> paddle.Tensor | None:
    """Initialize the 3D/4D/5D Tensor with Dirac delta function.

    Args:
        tensor (Tensor): Paddle Tensor.
        groups (int, optional): 0-dimension of the Tensor will be divided by groups,
            each group has the same value. Default: 1.
    Returns:
        Tensor: Initialized tensor.
    """
    init = Dirac(groups=groups)

    return init(tensor)


def eye_(
    tensor: paddle.Tensor,
) -> paddle.Tensor | None:
    """Fill the 2-dimensional input Tensor with the identity matrix.

    Args:
        tensor (Tensor): Paddle Tensor to initialize; must be 2-D.
    Returns:
        Tensor: The initialized tensor (``None`` in dygraph mode, where the
        input is modified in place).
    """

    ndim = len(tensor.shape)
    if ndim != 2:
        raise AssertionError(
            f"Only support 2 dimensional tensor, but got {ndim}."
        )

    dygraph = in_dygraph_mode()
    if not dygraph and not in_pir_mode():
        raise NotImplementedError(
            'Only support run in dygraph mode or PIR mode.'
        )

    rows, cols = tensor.shape
    identity = paddle.eye(rows, cols, dtype=tensor.dtype)
    if dygraph:
        # Dygraph: copy the identity values into the existing tensor storage.
        identity._share_underline_tensor_to(tensor)
        return None
    # PIR: return the freshly built identity value for the caller to use.
    return identity


def orthogonal_(
    tensor: paddle.Tensor,
    gain: float = 1,
) -> paddle.Tensor | None:
    """Fill the input Tensor with a (semi) orthogonal matrix.

    Args:
        tensor (Tensor): Paddle Tensor to initialize.
        gain (float, optional): Multiplication coefficient applied to the
            initialized tensor. Default: 1.0.
    Returns:
        Tensor: The initialized tensor.
    """
    # Delegate to the class-based initializer, which mutates in place.
    initializer = Orthogonal(gain=gain)
    return initializer(tensor)
4 changes: 3 additions & 1 deletion python/paddle/nn/initializer/dirac.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ def __call__(
isinstance(var, framework.EagerParamBase) and var.is_dist()
), "Currently, dirac initializer not support lazy init for dist param."
block = self._check_block(block)
assert isinstance(var, (framework.Variable, pir.core.ParameterMeta))
assert isinstance(
var, (framework.Variable, paddle.pir.Value, pir.core.ParameterMeta)
)
assert isinstance(block, (framework.Block, pir.Block))
check_variable_and_dtype(
var, "Out", ['float16', 'bfloat16', 'float32', 'float64'], 'Dirac'
Expand Down
6 changes: 6 additions & 0 deletions python/paddle/nn/initializer/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@
"conv2d",
"conv3d",
"conv1d_transpose",
"conv_transpose1d",
"conv2d_transpose",
"conv_transpose2d",
"conv3d_transpose",
"conv_transpose3d",
"tanh",
"relu",
"leaky_relu",
Expand Down Expand Up @@ -193,8 +196,11 @@ def calculate_gain(
'conv2d': 1,
'conv3d': 1,
'conv1d_transpose': 1,
'conv_transpose1d': 1,
'conv2d_transpose': 1,
'conv_transpose2d': 1,
'conv3d_transpose': 1,
'conv_transpose3d': 1,
'tanh': 5.0 / 3,
'relu': math.sqrt(2.0),
'leaky_relu': math.sqrt(2.0 / (1 + param**2)),
Expand Down
7 changes: 6 additions & 1 deletion python/paddle/nn/initializer/kaiming.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,12 @@ def forward(
), "Currently, kaiming initializer not support lazy init for dist param."
block = self._check_block(block)
assert isinstance(
var, (framework.Variable, paddle.pir.core.ParameterMeta)
var,
(
framework.Variable,
paddle.pir.Value,
paddle.pir.core.ParameterMeta,
),
)
assert isinstance(block, (framework.Block, paddle.pir.Block))
f_in, f_out = self._compute_fans(var)
Expand Down
6 changes: 5 additions & 1 deletion python/paddle/nn/initializer/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,11 @@ def forward(
core.eager.Tensor,
)
else:
expected = (framework.Variable, paddle.pir.core.ParameterMeta)
expected = (
framework.Variable,
paddle.pir.Value,
paddle.pir.core.ParameterMeta,
)

assert isinstance(var, expected)
assert isinstance(block, (framework.Block, pir.Block))
Expand Down
Loading
Loading