Merged

24 commits
8696c3d  add nn.init.kaiming_uniform_  (zhiminzhang0830, Aug 7, 2025)
87d9313  update kaiming_uniform_  (zhiminzhang0830, Aug 11, 2025)
5b8a1a5  update unit test for kaiming_uniform_  (zhiminzhang0830, Aug 11, 2025)
3eaa19b  add nn.init.kaiming_uniform_  (zhiminzhang0830, Aug 7, 2025)
b8e6b9c  update kaiming_uniform_  (zhiminzhang0830, Aug 11, 2025)
e791c18  update unit test for kaiming_uniform_  (zhiminzhang0830, Aug 11, 2025)
e2b2cc9  add xavier_uniform_, kaiming_normal_, uniform_  (zhiminzhang0830, Aug 11, 2025)
d2d614a  add unit test for xavier_uniform_, kaiming_normal_, uniform_  (zhiminzhang0830, Aug 11, 2025)
50cfb5c  add xavier_normal_ and its unit test  (zhiminzhang0830, Aug 11, 2025)
2c08d23  add normal_ and its unit test  (zhiminzhang0830, Aug 11, 2025)
c4c6917  fix: remove 'block' parameter from init.*() function  (zhiminzhang0830, Aug 11, 2025)
d31e3d3  fix  (zhiminzhang0830, Aug 11, 2025)
b5ccf0a  add nn.init.trunc_normal_ and its unit test  (zhiminzhang0830, Aug 11, 2025)
31cdc8b  add nn.init.constant_, nn.init.ones_, nn.init.zeros_  (zhiminzhang0830, Aug 11, 2025)
44d9d26  support paddle.pir.Value type  (zhiminzhang0830, Aug 12, 2025)
5afa04c  add dirac_, eye_, orthogonal_  (zhiminzhang0830, Aug 12, 2025)
ecb4da0  update unit test for nn.init.*  (zhiminzhang0830, Aug 12, 2025)
0c8bfd1  update init  (zhiminzhang0830, Aug 12, 2025)
4d4334f  add paddle.pir.Value  (zhiminzhang0830, Aug 12, 2025)
08a85c6  update unit test for nn.init.orthogonal_  (zhiminzhang0830, Aug 12, 2025)
1d4550e  Merge remote-tracking branch 'upstream/develop' into init  (zhiminzhang0830, Aug 13, 2025)
d330635  fix unit test for nn.init.eye_  (zhiminzhang0830, Aug 14, 2025)
02555a7  fix: skip unit test on dcu  (zhiminzhang0830, Aug 14, 2025)
d10e3bc  Merge remote-tracking branch 'upstream/develop' into init  (zhiminzhang0830, Aug 14, 2025)
3 changes: 3 additions & 0 deletions python/paddle/nn/__init__.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.nn import init as init
Contributor:
There's no need to import init here; paddle.nn.init should automatically reach the functions in init.py. That's how torch does it.

zhiminzhang0830 (Contributor Author), Aug 11, 2025:

from . import functional, initializer, quant, utils # noqa: F401
from .clip import ClipGradByGlobalNorm, ClipGradByNorm, ClipGradByValue
from .decode import BeamSearchDecoder, dynamic_decode
@@ -319,4 +321,5 @@
    'LPPool2D',
    'ZeroPad1D',
    'ZeroPad3D',
    'init',
Contributor:
Is this meant to be exported as an API? It doesn't seem like it needs to be exported.

zhiminzhang0830 (Contributor Author):
It's probably because of __all__; if 'init' isn't exported, the following usage raises an error:

import paddle
tensor = paddle.zeros([32, 64])
paddle.nn.init.kaiming_uniform_(tensor)
AttributeError: module 'paddle.nn' has no attribute 'init'

]
68 changes: 68 additions & 0 deletions python/paddle/nn/init.py
@@ -0,0 +1,68 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import paddle

from .initializer.initializer import calculate_gain, compute_fans


def _no_grad_uniform_(tensor, a, b):
    with paddle.no_grad():
        tensor.set_value(
Contributor:
Can this run under static graph mode? If it can't, consider reusing the existing initializers, instantiating one first and then calling it:

init = paddle.nn.initializer()
init(param)
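(A concrete reading of the sketch above, assuming the class-based initializer API; KaimingUniform is used here only as an illustrative example:)

init = paddle.nn.initializer.KaimingUniform()
init(param)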

zhiminzhang0830 (Contributor Author): Fixed.

            paddle.uniform(shape=tensor.shape, dtype=tensor.dtype, min=a, max=b)
        )
    return tensor


def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ["fan_in", "fan_out"]
    if mode not in valid_modes:
        raise ValueError(
            f"Mode {mode} not supported, please use one of {valid_modes}"
        )

    fan_in, fan_out = compute_fans(tensor)

    return fan_in if mode == "fan_in" else fan_out


def kaiming_uniform_(
    tensor: paddle.Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
) -> paddle.Tensor:
    """Fill the input tensor in place using the Kaiming uniform method.

    Args:
        tensor (Tensor): Paddle Tensor.
        a (float, optional): The negative slope of the rectifier used after this layer.
            Defaults to 0.
        mode (str, optional): Mode to compute the fan. Choose from ["fan_in", "fan_out"].
            When set to 'fan_in', the fan_in of the Tensor is used for initialization;
            when set to 'fan_out', the out_features of the trainable Tensor is used.
            Default is 'fan_in'.
        nonlinearity (str, optional): Nonlinearity method name. Defaults to "leaky_relu".

    Returns:
        Tensor: The initialized tensor.
    """
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    k = math.sqrt(3.0) * std
    return _no_grad_uniform_(tensor, -k, k)
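
For context, a minimal usage sketch of the new function (assuming this PR's paddle.nn.init module is importable):

import paddle

w = paddle.empty([64, 128])
paddle.nn.init.kaiming_uniform_(w, a=0, mode="fan_in", nonlinearity="leaky_relu")
# fan_in = 64 and gain = sqrt(2), so std = sqrt(2)/8 and the bound is
# sqrt(3) * std, i.e. values fall within roughly (-0.306, 0.306)
print(float(w.min()), float(w.max()))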
74 changes: 39 additions & 35 deletions python/paddle/nn/initializer/initializer.py
@@ -109,44 +109,45 @@ def _check_block(self, block: paddle.pir.Block | None) -> paddle.pir.Block:

        return block

-    def _compute_fans(self, var: paddle.Tensor) -> tuple[int, int]:
-        """Compute the fan_in and the fan_out for layers
-        ... (method body unchanged; moved to the module-level function below)

+def compute_fans(var: paddle.Tensor) -> tuple[int, int]:
+    """Compute the fan_in and the fan_out for layers
+
+    This method computes the fan_in and the fan_out
+    for neural network layers, if not specified. It is
+    not possible to perfectly estimate fan_in and fan_out.
+    This method will estimate it correctly for matrix multiply and
+    convolutions.
+
+    Args:
+        var: variable for which fan_in and fan_out have to be computed.
+
+    Returns:
+        tuple of two integers (fan_in, fan_out).
+    """
+    shape = (
+        var._local_shape
+        if (isinstance(var, EagerParamBase) and var.is_dist())
+        else var.shape
+    )
+    if not shape or len(shape) == 0:
+        fan_in = fan_out = 1
+    elif len(shape) == 1:
+        fan_in = fan_out = shape[0]
+    elif len(shape) == 2:
+        # This is the case for simple matrix multiply
+        fan_in = shape[0]
+        fan_out = shape[1]
+    else:
+        # Assume this to be a convolutional kernel
+        # In PaddlePaddle, the shape of the kernel is like:
+        # [num_filters, num_filter_channels, ...] where the remaining
+        # dimensions are the filter_size
+        receptive_field_size = np.prod(shape[2:])
+        fan_in = shape[1] * receptive_field_size
+        fan_out = shape[0] * receptive_field_size
+
+    return (fan_in, fan_out)
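
As a quick sanity check of the fan computation (a worked example, not part of the diff): for a conv kernel of shape [32, 16, 3, 3], the receptive field size is 3 * 3 = 9, so fan_in = 16 * 9 = 144 and fan_out = 32 * 9 = 288.

import paddle

# Import path assumed from this PR's init.py, which does
# `from .initializer.initializer import calculate_gain, compute_fans`.
from paddle.nn.initializer.initializer import compute_fans

w = paddle.empty([32, 16, 3, 3])  # [num_filters, num_filter_channels, kh, kw]
print(compute_fans(w))  # (144, 288)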


def calculate_gain(
@@ -193,8 +194,11 @@ def calculate_gain(
    'conv2d': 1,
    'conv3d': 1,
    'conv1d_transpose': 1,
    'conv_transpose1d': 1,
    'conv2d_transpose': 1,
    'conv_transpose2d': 1,
    'conv3d_transpose': 1,
    'conv_transpose3d': 1,
    'tanh': 5.0 / 3,
    'relu': math.sqrt(2.0),
    'leaky_relu': math.sqrt(2.0 / (1 + param**2)),
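For reference, these gains combine with the fan exactly as in kaiming_uniform_ above: bound = sqrt(3) * gain / sqrt(fan). A quick check with the public calculate_gain (a sketch, not part of the diff):

import math

import paddle

g = paddle.nn.initializer.calculate_gain('leaky_relu', param=0.0)  # sqrt(2 / (1 + 0**2))
print(g)                                    # 1.4142...
print(math.sqrt(3.0) * g / math.sqrt(144))  # ~0.2041 for fan = 144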
4 changes: 2 additions & 2 deletions python/paddle/nn/initializer/kaiming.py
@@ -27,7 +27,7 @@
    in_dygraph_mode,
    in_pir_mode,
)
-from .initializer import Initializer, calculate_gain
+from .initializer import Initializer, calculate_gain, compute_fans

if TYPE_CHECKING:
    from .initializer import _NonLinearity
@@ -120,7 +120,7 @@ def forward(
            var, (framework.Variable, paddle.pir.core.ParameterMeta)
        )
        assert isinstance(block, (framework.Block, paddle.pir.Block))
-        f_in, f_out = self._compute_fans(var)
+        f_in, f_out = compute_fans(var)

        # If fan_in is passed, use it
        if self._mode == 'fan_in':
4 changes: 2 additions & 2 deletions python/paddle/nn/initializer/xavier.py
@@ -26,7 +26,7 @@
    in_dygraph_mode,
    in_pir_mode,
)
-from .initializer import Initializer
+from .initializer import Initializer, compute_fans

__all__ = []

@@ -109,7 +109,7 @@ def forward(
"xavier_init",
)

f_in, f_out = self._compute_fans(var)
f_in, f_out = compute_fans(var)

# If fan_in and fan_out are passed, use them
fan_in = f_in if self._fan_in is None else self._fan_in
96 changes: 96 additions & 0 deletions test/legacy_test/test_nn_init_function.py
@@ -0,0 +1,96 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import random
import unittest

import numpy as np
from scipy import stats

import paddle
from paddle import nn


def get_uniform_min_and_max(weight):
    min_value = np.min(weight)
    max_value = np.max(weight)
    return min_value, max_value


class TestKaimingUniform(unittest.TestCase):
    def _test_kaiming_uniform_common(self, tensor):
        init = paddle.nn.init.kaiming_uniform_
        init(tensor, a=0, mode="fan_in", nonlinearity="leaky_relu")
        init(tensor, a=-0.2, mode="fan_out", nonlinearity="leaky_relu")
        init(tensor, a=0, mode="fan_in", nonlinearity="relu")
        init(tensor, a=0, mode="fan_out", nonlinearity="relu")

    def test_kaiming_uniform_linear(self):
        linear = nn.Linear(40, 20)
        self._test_kaiming_uniform_common(linear.weight)

    def _create_random_nd_tensor(self, dims, size_min, size_max):
        size = [random.randint(size_min, size_max) for _ in range(dims)]
        tensor = paddle.zeros(size)
        return tensor

    def _is_uniform(self, tensor, a, b):
        samples = tensor.view([-1]).tolist()
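        # scipy's "uniform" is parameterized by (loc, scale), i.e. U(loc, loc + scale),
        # so args=(a, b - a) tests the samples against U(a, b).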
        p_value = stats.kstest(samples, "uniform", args=(a, (b - a)))[1]
        return p_value > 0.0001

    def _random_float(self, a, b):
        return (b - a) * random.random() + a

    def test_kaiming_uniform(self):
        for use_a in [True, False]:
            for dims in [2, 4]:
                for mode in ["fan_in", "fan_out"]:
                    input_tensor = self._create_random_nd_tensor(
                        dims, size_min=20, size_max=25
                    )
                    if use_a:
                        a = self._random_float(0.1, 2)
                        paddle.nn.init.kaiming_uniform_(
                            input_tensor, a=a, mode=mode
                        )
                    else:
                        a = 0
                        paddle.nn.init.kaiming_uniform_(input_tensor, mode=mode)

                    if dims == 2:
                        # This is the case for simple matrix multiply
                        fan_in = input_tensor.shape[0]
                        fan_out = input_tensor.shape[1]
                    else:
                        fan_in = input_tensor.shape[1]
                        fan_out = input_tensor.shape[0]

                    if input_tensor.dim() > 2:
                        fan_in *= input_tensor[0, 0].numel()
                        fan_out *= input_tensor[0, 0].numel()

                    if mode == "fan_in":
                        n = fan_in
                    else:
                        n = fan_out

                    expected_std = math.sqrt(2.0 / ((1 + a**2) * n))
                    bounds = expected_std * math.sqrt(3.0)
                    assert self._is_uniform(input_tensor, -bounds, bounds)


if __name__ == '__main__':
    unittest.main()
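
For intuition about the bound this test checks, a worked instance (not part of the diff): with mode="fan_in", a 2-D tensor of shape [40, 20] gives n = 40, so for a = 0 the expected std is sqrt(2/40) and the uniform bound is sqrt(3) times that.

import math

a, n = 0.0, 40
expected_std = math.sqrt(2.0 / ((1 + a**2) * n))  # ~0.2236
bound = expected_std * math.sqrt(3.0)             # ~0.3873
print(expected_std, bound)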