Skip to content

lstm错误 #24300

@3wGTA

Description

@3wGTA

python 3.7
paddle 1.72-gpu

V = 32
h = 3
emb_size=5
max_len = 7 #最大为5,每个句子最长为5
hidden_size = 2
num_layers = 1

label = fluid.data(name='label', shape=[None, 1], dtype='int64')
x = fluid.data(name='t', shape=[None], dtype='int64',lod_level=1)
y = fluid.data(name='h', shape=[None], dtype='int64',lod_level=1)


w = fluid.ParamAttr(name='emb_vec', initializer=fluid.initializer.NumpyArrayInitializer(weight), trainable=False)

emb_x = fluid.embedding(input=x, size=[32,5], param_attr=w)
emb_y = fluid.embedding(input=y, size=[32,5], param_attr=w)

pad_value = fluid.layers.assign(input=np.array([0.0], dtype=np.float32))
pad_x,info_x = fluid.layers.sequence_pad(emb_x,pad_value)
pad_y,info_y = fluid.layers.sequence_pad(emb_y,pad_value)
batch_size=5
init_h = fluid.layers.fill_constant([num_layers, batch_size, hidden_size], 'float32', 0)
init_c = fluid.layers.fill_constant([num_layers, batch_size, hidden_size], 'float32', 0)
# lstm 网络

# 返回的形状是 batch_size, seq_len, hiddensize
lstm_x, x_last_h, x_last_c = fluid.layers.lstm(pad_x, init_h, init_c, max_len, hidden_size, num_layers,is_bidirec=True)

use_gpu = True
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
main_program = fluid.default_main_program()
feeder = fluid.DataFeeder(feed_list=['t', 'h','label'], place=place)
exe.run(fluid.default_startup_program())

fetch_var = [x, y, emb_x, emb_y, pad_x, pad_y, lstm_x, x_last_h]

for i, data in enumerate(train_reader()):
    print(data)
    result = exe.run(
        main_program,
        feeder.feed(data),
        fetch_list=fetch_var,
        return_numpy=False
        )
    break

经过多次测试,发现经过lstm网络之后输出的形状是[batch_size, seq_len, hidden_size]
并且lstm中的双向是没有效果的,是否使用双向,得到的结果都是上述的形状

本次输入数据
[([9, 1, 3, 8], [8, 5], 0), ([0, 3, 2, 5], [8, 3], 0), ([9, 5], [4, 5, 9, 3, 2, 3, 7], 1), ([4, 5, 5, 4], [7, 5, 2], 1), ([6, 7], [5, 4, 6, 0, 4], 1)]

batch_size=5
fetch的结果如下
t
[[0, 4, 8, 10, 14, 16]]
[9 1 3 8 0 3 2 5 9 5 4 5 5 4 6 7]
(16,)


h
[[0, 2, 4, 11, 14, 19]]
[8 5 8 3 4 5 9 3 2 3 7 7 5 2 5 4 6 0 4]
(19,)


embedding_0.tmp_0
[[0, 4, 8, 10, 14, 16]]
[[0. 1. 1. 0. 1.]
[1. 0. 0. 1. 0.]
[0. 1. 0. 1. 0.]
[0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0.]
[0. 1. 0. 1. 0.]
[1. 0. 1. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 1. 0. 1.]
[0. 1. 0. 0. 0.]
[0. 1. 1. 1. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 1. 1. 0.]
[0. 0. 1. 1. 1.]
[0. 0. 1. 1. 0.]]
(16, 5)


embedding_1.tmp_0
[[0, 2, 4, 11, 14, 19]]
[[0. 0. 0. 0. 1.]
[0. 1. 0. 0. 0.]
[0. 0. 0. 0. 1.]
[0. 1. 0. 1. 0.]
[0. 1. 1. 1. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 1. 0. 1.]
[0. 1. 0. 1. 0.]
[1. 0. 1. 0. 0.]
[0. 1. 0. 1. 0.]
[0. 0. 1. 1. 0.]
[0. 0. 1. 1. 0.]
[0. 1. 0. 0. 0.]
[1. 0. 1. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 1. 1. 0.]
[0. 0. 1. 1. 1.]
[0. 0. 0. 0. 0.]
[0. 1. 1. 1. 0.]]
(19, 5)


sequence_pad_0.tmp_0
[]
[[[0. 1. 1. 0. 1.]
[1. 0. 0. 1. 0.]
[0. 1. 0. 1. 0.]
[0. 0. 0. 0. 1.]]

[[0. 0. 0. 0. 0.]
[0. 1. 0. 1. 0.]
[1. 0. 1. 0. 0.]
[0. 1. 0. 0. 0.]]

[[0. 1. 1. 0. 1.]
[0. 1. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]

[[0. 1. 1. 1. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 1. 1. 0.]]

[[0. 0. 1. 1. 1.]
[0. 0. 1. 1. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]]
(5, 4, 5)


sequence_pad_1.tmp_0
[]
[[[0. 0. 0. 0. 1.]
[0. 1. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]

[[0. 0. 0. 0. 1.]
[0. 1. 0. 1. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]

[[0. 1. 1. 1. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 1. 0. 1.]
[0. 1. 0. 1. 0.]
[1. 0. 1. 0. 0.]
[0. 1. 0. 1. 0.]
[0. 0. 1. 1. 0.]]

[[0. 0. 1. 1. 0.]
[0. 1. 0. 0. 0.]
[1. 0. 1. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]

[[0. 1. 0. 0. 0.]
[0. 1. 1. 1. 0.]
[0. 0. 1. 1. 1.]
[0. 0. 0. 0. 0.]
[0. 1. 1. 1. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]]
(5, 7, 5)


cudnn_lstm_0.tmp_0
[]
[[[-0.01163395 -0.03404206]
[-0.03346499 0.09141013]
[-0.01192071 -0.00806206]
[ 0.01852721 0.02037505]]

[[-0.02112382 0.0178064 ]
[ 0.00036139 0.02811656]
[-0.04091306 -0.03953137]
[ 0.01994075 0.04983816]]

[[-0.03010507 -0.02586269]
[-0.01206493 0.04903913]
[-0.02649166 0.01360786]
[-0.02587583 0.01543753]]

[[ 0.00416536 -0.02817909]
[ 0.04569617 0.04708511]
[-0.03898652 -0.01187311]
[ 0.00342977 0.01862298]]

[[-0.02776574 -0.04738818]
[-0.06142154 0.09363435]
[-0.03319343 0.01389923]
[-0.00538392 0.0187942 ]]]
(5, 4, 2)

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions