Skip to content

Commit fe9fbae

Browse files
committed
fix test_crf_decoding_op
1 parent ba57f73 commit fe9fbae

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

python/paddle/v2/fluid/tests/test_crf_decoding_op.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ def __init__(self, emission_weights, transition_weights,
2020
self.w = transition_weights[2:, :]
2121

2222
self.track = np.zeros(
23-
(seq_start_positions[-1], self.tag_num), dtype="int32")
23+
(seq_start_positions[-1], self.tag_num), dtype="int64")
2424
self.decoded_path = np.zeros(
25-
(seq_start_positions[-1], 1), dtype="int32")
25+
(seq_start_positions[-1], 1), dtype="int64")
2626

2727
def _decode_one_sequence(self, decoded_path, x):
2828
seq_len, tag_num = x.shape
2929
alpha = np.zeros((seq_len, tag_num), dtype="float64")
30-
track = np.zeros((seq_len, tag_num), dtype="int32")
30+
track = np.zeros((seq_len, tag_num), dtype="int64")
3131

3232
for i in range(tag_num):
3333
alpha[0, i] = self.a[i] + x[0, i]
@@ -125,10 +125,10 @@ def setUp(self):
125125
axis=0)
126126

127127
labels = np.random.randint(
128-
low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int32")
128+
low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int64")
129129
predicted_labels = np.ones(
130-
(lod[-1][-1], 1), dtype="int32") * (TAG_NUM - 1)
131-
expected_output = (labels == predicted_labels).astype("int32")
130+
(lod[-1][-1], 1), dtype="int64") * (TAG_NUM - 1)
131+
expected_output = (labels == predicted_labels).astype("int64")
132132

133133
self.inputs = {
134134
"Emission": (emission, lod),

0 commit comments

Comments
 (0)