Skip to content

Conversation

@GuoxiaWang
Copy link
Contributor

@GuoxiaWang GuoxiaWang commented Sep 8, 2021

PR types

New features

PR changes

APIs

Describe

Support a hybrid parallel inference helper class.

  • mp (model parallelism)
  • pp (pipeline parallelism)
  • mp + pp (combined)

Example (mp=1, pp=2):

# --- Environment setup for the hybrid-parallel inference example ---
import os
import numpy as np
import paddle
import paddle.fluid.layers as layers
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.utils.hybrid_parallel_inference import HybridParallelInferenceHelper
paddle.enable_static()

# Rank/device info is supplied by the Paddle distributed launcher via
# environment variables; the defaults allow single-process execution.
nranks = int(os.getenv("PADDLE_TRAINERS_NUM", 1))
rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
dev_id = int(os.getenv("FLAGS_selected_gpus", 0))

main_program = paddle.static.Program()
startup_program = paddle.static.Program()

# Initialize collective fleet only when launched with more than one process.
if nranks > 1:
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.without_graph_optimization = True
    fleet.init(is_collective=True, strategy=dist_strategy)

device = "gpu"

# Build the static graph: a length-5 while loop whose two matmul stages are
# pinned to different pipeline devices (gpu:0 and gpu:1) via device_guard.
with paddle.static.program_guard(main_program, startup_program):
    with paddle.fluid.device_guard(f'{device}:0'):
        X = paddle.static.data(name='X', shape=[None, 2], dtype='float32')

    with paddle.fluid.device_guard(f'{device}:all'):
        # Loop bound and counter are placed on every device so the While
        # condition can be evaluated everywhere.
        max_len = layers.fill_constant(
            shape=[1], dtype="int64", value=5, force_cpu=False, name="n")
        step_idx = layers.fill_constant(
            shape=[1], dtype="int64", value=0, force_cpu=False, name="i")

        data = layers.array_write(X, step_idx)

        # cond_int is an integer mirror of the boolean loop condition; its
        # name 'cond_int.tmp_0' is passed to gen_infer_program below so it
        # can be synchronized across pipeline stages.
        cond_int = layers.fill_constant(shape=[1], dtype="int64", value=0, force_cpu=False, name="cond_int")
        cond = layers.less_than(x=step_idx, y=max_len)
        while_op = layers.While(cond, is_test=True)

    with while_op.block():
        with paddle.fluid.device_guard(f'{device}:all'):
            # Read the previous step's output and advance the loop counter.
            input = layers.array_read(array=data, i=step_idx)
            layers.increment(x=step_idx, value=1.0, in_place=True)
            layers.array_write(input, i=step_idx, array=data)

        with paddle.fluid.device_guard(f'{device}:0'):
            # Pipeline stage 0: [batch, 2] x [2, 5] matmul, weights all 1.0.
            param_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1.0))
            weight1 = paddle.static.create_parameter(
                shape=[2, 5], dtype='float32', attr=param_attr, is_bias=False)
            hidden1 = paddle.matmul(input, weight1)

        with paddle.fluid.device_guard(f'{device}:1'):
            # Pipeline stage 1: [batch, 5] x [5, 2] matmul, weights all 2.0;
            # the step result is written back into the tensor array.
            param_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(2.0))
            weight2 = paddle.static.create_parameter(
                shape=[5, 2], dtype='float32', attr=param_attr, is_bias=False)
            hidden2 = paddle.matmul(hidden1, weight2)

            layers.array_write(hidden2, i=step_idx, array=data)

            # Update cond and assign it to cond_int; cond_int is the value
            # that will be synchronized across devices.
            layers.less_than(x=step_idx, y=max_len, cond=cond)
            layers.assign(layers.cast(cond, dtype="int32"), cond_int)

        with paddle.fluid.device_guard(f'{device}:all'):
            # This must be the last op of the while block and must run on
            # every device: it converts the synced cond_int back into cond.
            layers.assign(layers.cast(cond_int, dtype='bool'), cond)

    with paddle.fluid.device_guard(f'{device}:all'):
        # Copy the accumulated tensor array into `out` so it can be fetched.
        out = layers.create_array(data.dtype)
        layers.assign(data, out)

    with paddle.fluid.device_guard(f'{device}:all'):
        # Assign an empty lod_tensor_array to clear `data` between runs.
        layers.assign(layers.create_array(data.dtype), data)

# Generate the per-rank inference program (presumably partitioning ops by
# their device_guard annotation — see HybridParallelInferenceHelper docs) and
# set up communication groups when running distributed.
helper = HybridParallelInferenceHelper(startup_program, main_program, micro_batch_size=2, num_mp=1, num_pp=2, init_comm=nranks>1)
helper.gen_infer_program(['array_write_0.out'], ['cond_int.tmp_0'])

np.random.seed(2333)  # deterministic random inputs for the check below
exe = paddle.static.Executor(paddle.CUDAPlace(dev_id))
exe.run(startup_program)  # initialize parameters before running inference

def numpy_while(x, w1=1.0, w2=2.0, max_len=5):
    """Pure-NumPy reference for the static-graph while loop above.

    Repeatedly applies two constant-filled matmul layers, collecting every
    intermediate result (including the input itself) so the list can be
    compared elementwise against the fetched tensor array.

    Args:
        x: float32 array of shape [batch, 2], the initial input.
        w1: fill value for the first weight matrix of shape [2, 5].
        w2: fill value for the second weight matrix of shape [5, 2].
        max_len: number of loop iterations (matches the graph's loop bound).

    Returns:
        A list of ``max_len + 1`` arrays: ``[x, step_1, ..., step_max_len]``.
    """
    # np.full replaces the original np.empty + fill pair in one step.
    weight1 = np.full([2, 5], w1, dtype='float32')
    weight2 = np.full([5, 2], w2, dtype='float32')
    data = [x]
    # Renamed the loop variable: the original shadowed the builtin `input`.
    for _ in range(max_len):
        hidden1 = np.matmul(data[-1], weight1)
        data.append(np.matmul(hidden1, weight2))
    return data
    
# Compare the pipelined static-graph output against the NumPy reference
# for several random inputs.
for step in range(5):
    init_data = np.random.uniform(low=0.0, high=1.0, size=[2, 2]).astype('float32')
    [res] = exe.run(main_program, feed={"X": init_data}, fetch_list=[out])
    print('-------- step', step, ' --------')
    print(res)

    print('-------- numpy --------')
    res_np = numpy_while(init_data)
    print(res_np)

    # Every intermediate step (including the input) must match elementwise.
    assert len(res) == len(res_np)
    for d1, d2 in zip(res, res_np):
        np.testing.assert_allclose(d1, d2)

@paddle-bot-old
Copy link

paddle-bot-old bot commented Sep 8, 2021

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

XieYunshen
XieYunshen previously approved these changes Sep 10, 2021
Copy link
Contributor

@XieYunshen XieYunshen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120)

@sandyhouse sandyhouse merged commit dc3c845 into PaddlePaddle:develop Sep 13, 2021
AnnaTrainingG pushed a commit to AnnaTrainingG/Paddle that referenced this pull request Sep 29, 2021
* support hybrid parallel inference helper class
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants