Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 0 additions & 193 deletions autoparallel/shardings/propagation_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,165 +453,6 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
return OpStrategy(all_strategies)


# ======================================
# the following ops require meta_tensor fix


@register_opschema_rule(torch.ops.aten.native_layer_norm.default)
def native_layer_norm_rule(mesh, op_schema):
    """Build the sharding strategies for ``aten.native_layer_norm``.

    Candidate strategies are produced by ``pointwise_strategy``. Any candidate
    whose input is sharded on one of the normalized (trailing) dimensions is
    dropped, because computing the statistics would require a reduction over a
    sharded dimension. Every surviving candidate gets its ``output_specs``
    replaced by a 3-tuple of specs for ``(out, mean, rstd)``.

    Args:
        mesh: Mesh handed in by the rule registry. It is not read here; the
            mesh carried by each candidate strategy is used instead.
        op_schema: Schema whose ``args_schema`` must hold exactly five entries:
            input, normalized_shape, weight, bias, eps.

    Returns:
        An ``OpStrategy`` holding only the valid candidates (mutated in place).
    """
    from torch.distributed.tensor._ops._math_ops import (
        Sequence,
        normalize_to_torch_size,
    )
    from torch.distributed.tensor._ops._pointwise_ops import pointwise_strategy

    # args must be: input, normalized_shape, weight, bias, eps
    assert len(op_schema.args_schema) == 5
    input_strategy, normalized_shape, *_ = op_schema.args_schema

    # the current layer norm implementation requires that the input
    # sharding is given in the form of an OpStrategy
    assert isinstance(input_strategy, OpStrategy)
    assert isinstance(normalized_shape, (int, Sequence, torch.Size))
    normalized_size = normalize_to_torch_size(normalized_shape)

    # first dimension that takes part in the normalization
    axis = input_strategy.ndim - len(normalized_size)

    output_strategy = pointwise_strategy(op_schema, linearity=False)

    kept = []
    for strategy in output_strategy.strategies:
        input_spec = strategy.input_specs[0]
        # sharding a normalized dim would need a cross-device reduction: skip
        if any(plc.is_shard() and plc.dim >= axis for plc in input_spec.placements):
            continue

        placements = strategy.output_specs.placements
        strategy_mesh = strategy.mesh

        # Rebuild the output tensor_meta from the input shape/dtype so the
        # strides are regenerated (LayerNorm forward returns a contiguous
        # tensor even if the input was non-contiguous).
        output_spec = DTensorSpec(
            mesh=strategy_mesh,
            placements=placements,
            tensor_meta=_gen_tensor_meta(
                input_spec.tensor_meta.shape, input_spec.tensor_meta.dtype
            ),
        )

        # mean / rstd keep the leading (non-normalized) dims and collapse each
        # normalized dim to size 1; they reuse the output placements, which is
        # valid because no placement shards a dim >= axis at this point.
        stats_shape = input_spec.tensor_meta.shape[:axis] + (1,) * len(
            normalized_size
        )
        mean_std_tgt_spec = DTensorSpec(
            mesh=strategy_mesh,
            placements=placements,
            tensor_meta=_gen_tensor_meta(stats_shape),
        )

        strategy.output_specs = (output_spec, mean_std_tgt_spec, mean_std_tgt_spec)
        kept.append(strategy)

    return OpStrategy(kept)


@register_opschema_rule(torch.ops.aten.native_layer_norm_backward.default)
def native_layer_norm_backward_rule(mesh, op_schema):
    """Build the sharding strategies for ``aten.native_layer_norm_backward``.

    Candidate strategies are produced by ``pointwise_strategy``. Any candidate
    whose ``input`` (tensor argument 1) is sharded on a normalized (trailing)
    dimension is dropped. Every surviving candidate gets its ``output_specs``
    replaced by a 3-tuple of specs for ``(grad_input, grad_weight, grad_bias)``.

    Args:
        mesh: Mesh handed in by the rule registry. It is not read here; the
            mesh carried by each candidate strategy is used instead.
        op_schema: Schema whose ``args_schema`` must hold exactly eight
            entries: grad_out, input, normalized_shape, mean, rstd, weight,
            bias, and one trailing argument that is ignored.

    Returns:
        An ``OpStrategy`` holding only the valid candidates (mutated in place).
    """
    from torch.distributed.tensor._ops._math_ops import (
        Sequence,
        normalize_to_torch_size,
    )
    from torch.distributed.tensor._ops._pointwise_ops import pointwise_strategy

    assert len(op_schema.args_schema) == 8
    _grad_out_strategy, input_strategy, normalized_shape, *_ = op_schema.args_schema

    assert isinstance(input_strategy, OpStrategy)
    assert isinstance(normalized_shape, (int, Sequence, torch.Size))
    normalized_size = normalize_to_torch_size(normalized_shape)

    # first dimension that takes part in the normalization
    axis = input_strategy.ndim - len(normalized_size)

    output_strategy = pointwise_strategy(op_schema, linearity=False)

    kept = []
    for strategy in output_strategy.strategies:
        input_spec = strategy.input_specs[1]
        # sharding a normalized dim would need a cross-device reduction: skip
        if any(plc.is_shard() and plc.dim >= axis for plc in input_spec.placements):
            continue

        strategy_mesh = strategy.mesh

        # Rebuild grad_input's tensor_meta from the input shape/dtype so the
        # strides are regenerated (LayerNorm backward returns a contiguous
        # gradient even if the input was non-contiguous).
        grad_input_spec = DTensorSpec(
            mesh=strategy_mesh,
            placements=strategy.output_specs.placements,
            tensor_meta=_gen_tensor_meta(
                input_spec.tensor_meta.shape, input_spec.tensor_meta.dtype
            ),
        )
        assert grad_input_spec.tensor_meta is not None

        # normalized_shape is not a tensor argument, so within input_specs the
        # weight and bias presumably sit at positions 4 and 5.
        # NOTE(review): this assumes weight and bias are both present; confirm
        # how input_specs is laid out when either one is None.
        weight_spec = strategy.input_specs[4]
        bias_spec = strategy.input_specs[5]

        # NOTE(review): the parameter gradients simply reuse the placements and
        # tensor_meta of the parameters themselves; confirm this is intended.
        weight_tgt_spec = DTensorSpec(
            mesh=strategy_mesh,
            placements=weight_spec.placements,
            tensor_meta=weight_spec.tensor_meta,
        )
        bias_tgt_spec = DTensorSpec(
            mesh=strategy_mesh,
            placements=bias_spec.placements,
            tensor_meta=bias_spec.tensor_meta,
        )

        strategy.output_specs = (grad_input_spec, weight_tgt_spec, bias_tgt_spec)
        kept.append(strategy)

    return OpStrategy(kept)


@register_opschema_rule(
[
torch.ops.prims.convert_element_type.default,
Expand All @@ -627,40 +468,6 @@ def convert_element_type_rule(mesh, op_schema):
return out_strat


@register_opschema_rule(torch.ops.aten.constant_pad_nd.default)
def constant_pad_nd_rule(mesh, op_schema):
    """Build the sharding strategies for ``aten.constant_pad_nd``.

    Starts from ``propagate_single_input_strategy`` and discards every
    strategy whose output is sharded on a dimension that receives non-zero
    padding. For each strategy that survives, the redistribute cost towards
    every discarded strategy index is set to infinity.

    Args:
        mesh: Mesh handed in by the rule registry; not read here.
        op_schema: Schema whose second argument is the flat ``pad`` list of
            (before, after) pairs, where pair ``i`` applies to dimension
            ``ndim - 1 - i``.

    Returns:
        An ``OpStrategy`` containing only the surviving strategies.
    """
    from torch.distributed.tensor._ops._tensor_ops import (
        propagate_single_input_strategy,
    )

    out_strat = propagate_single_input_strategy(op_schema)
    pad = op_schema.args_schema[1]
    ndim = len(out_strat.strategies[0].output_specs.tensor_meta.shape)

    # dimensions that actually get padded (pairs are listed last-dim first)
    padded_dims = []
    for pair_idx in range(len(pad) // 2):
        before, after = 2 * pair_idx, 2 * pair_idx + 1
        if pad[before] != 0 or pad[after] != 0:
            padded_dims.append(ndim - pair_idx - 1)

    dropped_indices = []
    survivors = []
    for position, candidate in enumerate(out_strat.strategies):
        shards_padded_dim = any(
            placement.is_shard() and placement.dim in padded_dims
            for placement in candidate.output_specs.placements
        )
        if shards_padded_dim:
            dropped_indices.append(position)
        else:
            survivors.append(candidate)

    # make transitions towards any dropped strategy index infinitely expensive
    for candidate in survivors:
        for position in dropped_indices:
            candidate.redistribute_cost[0][position] = math.inf

    return OpStrategy(survivors)


@register_opschema_rule(torch.ops.aten.split.Tensor)
def split_rule(mesh, op_schema):
strat = op_schema.args_schema
Expand Down
Loading