add the basic apis for auto_parallel #33804
add the basic apis for auto_parallel #33804sandyhouse merged 60 commits intoPaddlePaddle:developfrom
Conversation
|
Thanks for your contribution! |
|
Sorry to inform you that bf24fb7's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
… auto_parallel_basic
… auto_parallel_basic
| And the first logical process is the one with id=2. | ||
|
|
||
| Args: | ||
| mesh (numpy.ndarray): an N-dimensional array describes the toplogy |
There was a problem hiding this comment.
这里参数类型用numpy.ndarray的原因是什么呢?从示例代码看的话,是不是用python的list就可以了?
|
|
||
| Args: | ||
| x (Tensor): the tensor to process. | ||
| mask (numpy.ndarray): the shape of `mask` must be the same as the ProcessMesh belonging to |
There was a problem hiding this comment.
这里的mask是否用python的list就可以?
|
|
||
| Args: | ||
| x (tensor): the tensor to process. | ||
| device (str): the device that the tensor `x` will be put on, e.g., 'gpu:0', 'cpu'. |
There was a problem hiding this comment.
set_offload_device什么情况下需要设置成'gpu:0',表示什么意思呢?
There was a problem hiding this comment.
从实际使用场景看,offload的使用需求是offload指定的tensor到cpu,此处已去掉gpu:0
| optional bool need_check_feed = 4 [ default = false ]; | ||
| optional bool is_parameter = 5 [ default = false ]; | ||
| optional bool stop_gradient = 6 [ default = false ]; | ||
| repeated Attr attrs = 7; |
There was a problem hiding this comment.
这些新增的字段,在保存模型的时候,会被存下来吗?
我看示例代码,模型定义的时候就会添加这些字段,模型定义完再调用模型保存的时候,是不是会把这些字段都保存下来?什么时候把这些字段去掉呢?
There was a problem hiding this comment.
自动并行主要包括以下几个主要过程:1. 使用自动并行接口标识关键tensor或op;2. 自动补全:补全所有tensor和op的分布式属性;3. 逻辑切分;4. 物理映射;5. 执行训练。其中步骤1-3会使用到此处新增的字段;所以该接口新增的字段会在步骤1-3完成后删除,且该过程用户无感知。
常规的模型保存过程是 执行部分训练或全部训练完成后进行模型保存,这时,新增字段已经完全删除。
但存在一个特殊的情形,即用户完成组网后即刻保存模型,这时相关的字段会被保存下来。但我们认为,这种特殊情形是不应该存在的,因为完成组网后即保存模型是没有意义的。
| mesh_id = self.attr(mesh_attr_name) | ||
| return _g_process_mesh_map[mesh_id] | ||
|
|
||
| def dims_mapping(self, name): |
There was a problem hiding this comment.
表示Tensor整体的维度概念时用dimension, 一般从1开始编号,1维Tensor,2维Tensor
表示Tensor第几维概念是用axis和axes,一般从0开始编号,Tensor的第1维,Tensor的第2维
这里看起来是表示整体维度的概念,建议直接用单数dim_mapping
PR types
New features
PR changes
Others
Describe
Usage: