185 changes: 131 additions & 54 deletions paddle/phi/infermeta/spmd_rules/cross_entropy_with_softmax.cc
@@ -32,35 +32,51 @@ void GetCrossEntropyNotations(int x_ndim,
int axis,
bool soft_label,
bool use_softmax,
std::string* x_axes,
std::string* label_axes,
std::string* x_axes_src,
std::string* x_axes_dst,
std::string* label_axes_src,
std::string* label_axes_dst,
std::string* loss_axes,
std::string* softmax_out_axes) {
std::string* softmax_out_axes_src,
std::string* softmax_out_axes_dst,
bool support_shard_softmax_dim = false) {
std::string alphabet =
"abcdefghijlmnopqrstuvwxyz"; // k for softmax_normalize axis
*x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet);
(*x_axes)[axis] = 'k';
*label_axes = *x_axes;
*x_axes_src = GetBroadcastAxes(x_ndim, x_ndim, alphabet);
(*x_axes_src)[axis] = 'k';
*x_axes_dst = *x_axes_src;
if (!support_shard_softmax_dim) {
(*x_axes_dst)[axis] = '1';
}

*label_axes_src = *x_axes_src;
*label_axes_dst = *x_axes_dst;
if (!soft_label) {
(*label_axes)[axis] = '1';
(*label_axes_src)[axis] = '1';
(*label_axes_dst)[axis] = '1';
}
*loss_axes = *x_axes;

*loss_axes = *x_axes_src;
(*loss_axes)[axis] = '1';

// optional output
if (use_softmax) {
*softmax_out_axes = *x_axes;
*softmax_out_axes_src = *x_axes_src;
*softmax_out_axes_dst = *x_axes_dst;
} else {
*softmax_out_axes = "";
*softmax_out_axes_src = "";
*softmax_out_axes_dst = "";
}
}
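
Reviewer note: to make the new src/dst split concrete, here is a minimal standalone sketch (not part of the patch) that mirrors the notation logic above for an assumed case of x_ndim = 2, axis = 1, hard labels, use_softmax = true:

#include <cassert>
#include <string>

int main() {
  // Mirrors GetCrossEntropyNotations for x_ndim = 2, axis = 1.
  std::string x_axes_src = "ak";  // softmax axis renamed to reserved 'k'
  std::string x_axes_dst = x_axes_src;
  const bool support_shard_softmax_dim = false;  // dynamic-mode default
  if (!support_shard_softmax_dim) x_axes_dst[1] = '1';  // force replication

  std::string label_axes_src = x_axes_src;
  std::string label_axes_dst = x_axes_dst;
  const bool soft_label = false;
  if (!soft_label) {  // hard labels occupy a size-1 class axis
    label_axes_src[1] = '1';
    label_axes_dst[1] = '1';
  }

  std::string loss_axes = x_axes_src;
  loss_axes[1] = '1';  // the loss is reduced over the softmax axis

  assert(x_axes_src == "ak" && x_axes_dst == "a1");
  assert(label_axes_src == "a1" && label_axes_dst == "a1");
  assert(loss_axes == "a1");  // softmax_out_axes_* simply follow x_axes_*
  return 0;
}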

SpmdInfo CrossEntropyWithSoftmaxInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& label,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis) {
SpmdInfo CrossEntropyWithSoftmaxInferSpmdBase(const DistMetaTensor& x,
const DistMetaTensor& label,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis,
bool support_shard_softmax_dim) {
// Step0: Verify input args based on cross_entropy_with_softmax logic

EXTRACT_SHAPE_AND_DIST_ATTR(x);
@@ -111,21 +127,26 @@ SpmdInfo CrossEntropyWithSoftmaxInferSpmd(const DistMetaTensor& x,
}

// Step1: Build Einsum Notation
std::string x_axes, label_axes, loss_axes, softmax_out_axes;
std::string x_axes_src, x_axes_dst, label_axes_src, label_axes_dst, loss_axes,
softmax_out_axes_src, softmax_out_axes_dst;
GetCrossEntropyNotations(x_ndim,
axis,
soft_label,
use_softmax,
&x_axes,
&label_axes,
&x_axes_src,
&x_axes_dst,
&label_axes_src,
&label_axes_dst,
&loss_axes,
&softmax_out_axes);
&softmax_out_axes_src,
&softmax_out_axes_dst,
support_shard_softmax_dim);

// Step2: Sharding Propagation
// Step2.1: merge input shardings
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(
{{x_axes, x_dims_mapping_src}, {label_axes, label_dims_mapping_src}});
ShardingMergeForTensors({{x_axes_src, x_dims_mapping_src},
{label_axes_src, label_dims_mapping_src}});

// Step2.2: infer output dims mappings
TensorDistAttr loss_dist_attr_dst =
@@ -135,25 +156,25 @@ SpmdInfo CrossEntropyWithSoftmaxInferSpmd(const DistMetaTensor& x,
TensorDistAttr softmax_out_dist_attr_dst =
CopyTensorDistAttrForOutput(x_dist_attr_src);
softmax_out_dist_attr_dst.set_dims_mapping(
GetDimsMappingForAxes(softmax_out_axes, axis_to_dim_map));
GetDimsMappingForAxes(softmax_out_axes_dst, axis_to_dim_map));

// Step2.3: update input dims mappings with merged one
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(
GetDimsMappingForAxes(x_axes, axis_to_dim_map));
GetDimsMappingForAxes(x_axes_dst, axis_to_dim_map));
TensorDistAttr label_dist_attr_dst =
CopyTensorDistAttrForOutput(label_dist_attr_src);
label_dist_attr_dst.set_dims_mapping(
GetDimsMappingForAxes(label_axes, axis_to_dim_map));
GetDimsMappingForAxes(label_axes_dst, axis_to_dim_map));

VLOG(4) << "CrossEntropyInferSpmd:";
VLOG(4) << "axis: [" << axis << "], ignore_index: [" << ignore_index
<< "], numeric_stable_mode: ["
<< (numeric_stable_mode ? "true" : "false") << "], use_softmax: ["
<< (use_softmax ? "true" : "false") << "], soft_label: ["
<< (soft_label ? "true" : "false") << "].";
VLOG(4) << "Einsum notation: [" << x_axes << "," << label_axes << " --> "
<< softmax_out_axes << "," << loss_axes << "].\n"
VLOG(4) << "Einsum notation: [" << x_axes_src << "," << label_axes_src
<< " --> " << softmax_out_axes_src << "," << loss_axes << "].\n"
<< "X shape: [" << str_join(x_shape) << "], x_dims_mapping_src: ["
<< str_join(x_dims_mapping_src) << "], x_dims_mapping_dst: ["
<< str_join(x_dist_attr_dst.dims_mapping()) << "]\n Label shape: ["
@@ -174,6 +195,40 @@ SpmdInfo CrossEntropyWithSoftmaxInferSpmd(const DistMetaTensor& x,
{softmax_out_dist_attr_dst, loss_dist_attr_dst}};
}
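
Reviewer note: Step 2 is easiest to follow with numbers. The following self-contained sketch reproduces only the observable effect of ShardingMergeForTensors plus GetDimsMappingForAxes for one assumed case (x "ak" mapped [0, 1], label "a1" mapped [0, -1], the same setup the new test below uses); it is an illustration, not the library implementation:

#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Merged axis -> mesh-dim bindings; a '1' axis never binds a mesh dim.
  std::unordered_map<char, int64_t> axis_to_dim = {{'a', 0}, {'k', 1}};
  auto dims_for = [&](const std::string& axes) {
    std::vector<int64_t> dims;
    for (char c : axes) {
      dims.push_back(
          (c == '1' || axis_to_dim.count(c) == 0) ? -1 : axis_to_dim.at(c));
    }
    return dims;
  };
  // Dynamic entry point: every dst notation reads "a1", so x, label,
  // softmax_out and loss all come out as [0, -1] (class axis replicated).
  assert((dims_for("a1") == std::vector<int64_t>{0, -1}));
  // Static entry point: x_axes_dst stays "ak", so x may keep [0, 1].
  assert((dims_for("ak") == std::vector<int64_t>{0, 1}));
  return 0;
}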

SpmdInfo CrossEntropyWithSoftmaxInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& label,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis) {
return CrossEntropyWithSoftmaxInferSpmdBase(x,
label,
soft_label,
use_softmax,
numeric_stable_mode,
ignore_index,
axis,
false);
}

SpmdInfo CrossEntropyWithSoftmaxInferSpmdStatic(const DistMetaTensor& x,
const DistMetaTensor& label,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis) {
return CrossEntropyWithSoftmaxInferSpmdBase(x,
label,
soft_label,
use_softmax,
numeric_stable_mode,
ignore_index,
axis,
true);
}
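
Reviewer note: the two public wrappers now differ only in the trailing flag. A usage sketch (compile-context only: x and label are DistMetaTensors built as in the test file below, and the dims mappings in the comments are what the notation logic implies, not captured output):

// x: shape [32, 48], dims mapping [0, 1]; label: [32, 48], [0, -1]; axis = 1.
auto dynamic_info = phi::distributed::CrossEntropyWithSoftmaxInferSpmd(
    x, label, /*soft_label=*/false, /*use_softmax=*/true,
    /*numeric_stable_mode=*/true, /*ignore_index=*/1, /*axis=*/1);
// support_shard_softmax_dim = false: the softmax axis is forced to -1,
// so every input and output tensor ends up with dims mapping [0, -1].

auto static_info = phi::distributed::CrossEntropyWithSoftmaxInferSpmdStatic(
    x, label, false, true, true, 1, 1);
// support_shard_softmax_dim = true: x (and softmax_out) may keep the
// softmax axis sharded, i.e. [0, 1].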

SpmdInfo CrossEntropyWithSoftmaxInferSpmdReverse(
const DistMetaTensor& x,
const DistMetaTensor& label,
@@ -220,32 +275,37 @@ SpmdInfo CrossEntropyWithSoftmaxInferSpmdReverse(
axis = x_ndim + axis;
}

std::string x_axes, label_axes, loss_axes, softmax_out_axes;
std::string x_axes, x_axes_dst, label_axes_src, label_axes_dst, loss_axes,
softmax_out_axes_src, softmax_out_axes_dst;
GetCrossEntropyNotations(x_ndim,
axis,
soft_label,
use_softmax,
&x_axes,
&label_axes,
&x_axes_dst,
&label_axes_src,
&label_axes_dst,
&loss_axes,
&softmax_out_axes);
&softmax_out_axes_src,
&softmax_out_axes_dst,
true);

// Step2: Sharding Propagation
// Step2.1 merge output dims mappings
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors({{loss_axes, loss_dims_mapping_src},
{softmax_out_axes, s_out_dims_mapping_src}});
{softmax_out_axes_src, s_out_dims_mapping_src}});

// Step2.2 infer inputs' dims mappings from merged dims mapping
std::vector<int64_t> x_dims_mapping, label_dims_mapping;
// infer X's dims mapping
x_dims_mapping = GetDimsMappingForAxes(x_axes, axis_to_dim_map);
x_dims_mapping = GetDimsMappingForAxes(x_axes_dst, axis_to_dim_map);
// infer label's dims mapping
label_dims_mapping = GetDimsMappingForAxes(label_axes, axis_to_dim_map);
label_dims_mapping = GetDimsMappingForAxes(label_axes_dst, axis_to_dim_map);

// Step2.3 update outputs' dims mappings with merged dims mapping
std::vector<int64_t> s_out_dims_mapping_dst =
GetDimsMappingForAxes(softmax_out_axes, axis_to_dim_map);
GetDimsMappingForAxes(softmax_out_axes_dst, axis_to_dim_map);
std::vector<int64_t> loss_dims_mapping_dst =
GetDimsMappingForAxes(loss_axes, axis_to_dim_map);

@@ -290,8 +350,8 @@ SpmdInfo CrossEntropyWithSoftmaxInferSpmdReverse(
<< (numeric_stable_mode ? "true" : "false") << "], use_softmax: ["
<< (use_softmax ? "true" : "false") << "], soft_label: ["
<< (soft_label ? "true" : "false") << "].";
VLOG(4) << "Einsum notation: [" << x_axes << "," << label_axes << " --> "
<< softmax_out_axes << "," << loss_axes << "].\n"
VLOG(4) << "Einsum notation: [" << x_axes << "," << label_axes_src << " --> "
<< softmax_out_axes_src << "," << loss_axes << "].\n"
<< "Loss shape: [" << str_join(loss_shape)
<< "], loss_dims_mapping_src: [" << str_join(loss_dims_mapping_src)
<< "], loss_dims_mapping_dst: [" << str_join(loss_dims_mapping_dst)
@@ -313,25 +373,39 @@ void GetCrossEntropyGradNotations(int loss_ndim,
int axis,
bool soft_label,
bool use_softmax,
std::string* label_axes,
std::string* softmax_axes,
std::string* loss_grad_axes) {
std::string* label_axes_src,
std::string* label_axes_dst,
std::string* softmax_axes_src,
std::string* softmax_axes_dst,
std::string* loss_grad_axes,
bool support_shard_softmax_dim = false) {
std::string alphabet =
"abcdefghijlmnopqrstuvwxyz"; // k for softmax_normalize axis
auto x_axes = alphabet.substr(0, loss_ndim);
x_axes[axis] = 'k';
*label_axes = x_axes;
auto x_axes_src = alphabet.substr(0, loss_ndim);
x_axes_src[axis] = 'k';
auto x_axes_dst = x_axes_src;
if (!support_shard_softmax_dim) {
x_axes_dst[axis] = '1';
}
*label_axes_src = x_axes_src;
*label_axes_dst = x_axes_dst;
if (!soft_label) {
(*label_axes)[axis] = '1';
(*label_axes_src)[axis] = '1';
(*label_axes_dst)[axis] = '1';
}

*loss_grad_axes = x_axes;
*loss_grad_axes = x_axes_src;
(*loss_grad_axes)[axis] = '1';
// optional output
if (use_softmax) {
*softmax_axes = x_axes;
*softmax_axes_src = x_axes_src;
*softmax_axes_dst = x_axes_dst;
if (!soft_label) {
(*softmax_axes_dst)[axis] = '1';
}
} else {
*softmax_axes = "";
*softmax_axes_src = "";
*softmax_axes_dst = "";
}
}
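
Reviewer note: the grad notations follow the same pattern with one extra clamp: for hard labels the softmax dst notation also pins the class axis to '1', presumably because x_grad = softmax - one_hot(label) needs the whole class row to locate the label index. A comment-only trace for loss_ndim = 2, axis = 1, hard labels, use_softmax = true, and the default support_shard_softmax_dim = false used by the caller below:

//   label_axes_src   = "a1",  label_axes_dst   = "a1"
//   softmax_axes_src = "ak",  softmax_axes_dst = "a1"
//   loss_grad_axes   = "a1"
// Every dst notation replicates the class axis, which is consistent with
// the updated backward expectations in the test file.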

@@ -351,29 +425,32 @@ SpmdInfo CrossEntropyWithSoftmaxGradInferSpmd(const DistMetaTensor& label,
axis = loss_grad_ndim + axis;
}

std::string label_axes, softmax_axes, loss_grad_axes;
std::string label_axes_src, label_axes_dst, softmax_axes_src,
softmax_axes_dst, loss_grad_axes;
GetCrossEntropyGradNotations(loss_grad_ndim,
axis,
soft_label,
use_softmax,
&label_axes,
&softmax_axes,
&label_axes_src,
&label_axes_dst,
&softmax_axes_src,
&softmax_axes_dst,
&loss_grad_axes);

std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors({{label_axes, label_dims_mapping_src},
{softmax_axes, softmax_dims_mapping_src},
ShardingMergeForTensors({{label_axes_src, label_dims_mapping_src},
{softmax_axes_src, softmax_dims_mapping_src},
{loss_grad_axes, loss_grad_dims_mapping_src}});

auto label_dist_attr_dst = CopyTensorDistAttrForOutput(label_dist_attr_src);
auto label_dims_mapping_dst =
GetDimsMappingForAxes(label_axes, axis_to_dim_map, true);
GetDimsMappingForAxes(label_axes_dst, axis_to_dim_map, true);
label_dist_attr_dst.set_dims_mapping(label_dims_mapping_dst);

auto softmax_dist_attr_dst =
CopyTensorDistAttrForOutput(softmax_dist_attr_src);
auto softmax_dims_mapping_dst =
GetDimsMappingForAxes(softmax_axes, axis_to_dim_map, true);
GetDimsMappingForAxes(softmax_axes_dst, axis_to_dim_map, true);
softmax_dist_attr_dst.set_dims_mapping(softmax_dims_mapping_dst);

auto loss_grad_dist_attr_dst =
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/spmd_rules/cross_entropy_with_softmax.h
@@ -28,6 +28,14 @@ SpmdInfo CrossEntropyWithSoftmaxInferSpmd(const DistMetaTensor& x,
int ignore_index,
int axis);

SpmdInfo CrossEntropyWithSoftmaxInferSpmdStatic(const DistMetaTensor& x,
const DistMetaTensor& label,
bool soft_label,
bool use_softmax,
bool numeric_stable_mode,
int ignore_index,
int axis);

SpmdInfo CrossEntropyWithSoftmaxInferSpmdReverse(
const DistMetaTensor& x,
const DistMetaTensor& label,
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/infermeta/spmd_rules/rules.cc
@@ -586,12 +586,12 @@ PD_REGISTER_SPMD_RULE(tile,
// cross_entropy_with_softmax
PD_REGISTER_SPMD_RULE(
cross_entropy_with_softmax,
PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmd),
PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmdStatic),
PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmdReverse));

PD_REGISTER_SPMD_RULE(
softmax_with_cross_entropy,
PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmd),
PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmdStatic),
PD_INFER_SPMD(phi::distributed::CrossEntropyWithSoftmaxInferSpmdReverse));

// fused_linear_param_grad_add got no reverse infer spmd rule
32 changes: 29 additions & 3 deletions test/cpp/auto_parallel/cross_entropy_softmax_spmd_rule_test.cc
@@ -60,6 +60,32 @@ TEST(CrossEntropyInferSpmd, Ctor) {
<< std::endl;
}

// test sharding along softmax axis.
{
x_dist_attr.set_dims_mapping(std::vector<int64_t>({0, 1}));
label_dist_attr.set_dims_mapping(std::vector<int64_t>({0, -1}));
phi::distributed::DistMetaTensor x(phi::make_ddim(x_shape), x_dist_attr);
phi::distributed::DistMetaTensor label(phi::make_ddim(x_shape),
label_dist_attr);
int axis = 1;

auto spmdinfo =
CrossEntropyWithSoftmaxInferSpmd(x, label, false, true, true, 1, axis);

EXPECT_EQ(spmdinfo.first.size(), 2UL);
EXPECT_EQ(spmdinfo.second.size(), 2UL);
check_dim_mapping(spmdinfo.first[0], {0, -1});
check_dim_mapping(spmdinfo.first[1], {0, -1});
check_dim_mapping(spmdinfo.second[0], {0, -1});
check_dim_mapping(spmdinfo.second[1], {0, -1});
check_partial_dims(spmdinfo.second[0], {});

VLOG(4) << "Test CrossEntropyWithSoftmaxInferSpmd sharding on other axes."
<< std::endl
<< std::endl
<< std::endl;
}
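
Reviewer note: a short trace of why the new case lands on [0, -1] everywhere (this goes through the dynamic entry point, so support_shard_softmax_dim = false):

// merge: x "ak" -> [0, 1], label "a1" -> [0, -1]   =>  {a: 0, k: 1}
// read-back: every dst notation is "a1"            =>  [0, -1]
// The requested shard on the softmax axis (the 1 in x's [0, 1]) is
// dropped, matching the four check_dim_mapping expectations above.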

// backward
{
std::vector<int64_t> loss_shape = {32, 1};
@@ -82,10 +108,10 @@

EXPECT_EQ(spmdinfo.first.size(), 3UL);
EXPECT_EQ(spmdinfo.second.size(), 1UL);
check_dim_mapping(spmdinfo.first[0], {0, 1});
check_dim_mapping(spmdinfo.first[1], {0, 1});
check_dim_mapping(spmdinfo.first[0], {0, -1});
check_dim_mapping(spmdinfo.first[1], {0, -1});
check_dim_mapping(spmdinfo.first[2], {0, -1});
check_dim_mapping(spmdinfo.second[0], {0, 1});
check_dim_mapping(spmdinfo.second[0], {0, -1});
check_partial_dims(spmdinfo.second[0], {});

VLOG(4)
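
Reviewer note: the flipped backward expectations ([0, 1] -> [0, -1]) follow from the same default: CrossEntropyWithSoftmaxGradInferSpmd calls GetCrossEntropyGradNotations without the new flag, so support_shard_softmax_dim stays false and every dst notation pins the class axis. In trace form:

// With support_shard_softmax_dim = false, every dst notation reads "a1"
// at the class axis, so label, softmax, loss_grad and x_grad all map to
// [0, -1] regardless of soft_label.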