
Dynamic Shape Support #48

@0xWJ

Description


# Dynamic shape tracking

IR

  1. Add is_dynamic to IRObject
  2. Add dynamic_dims (of type Set[int]) to IRFullTensor
  3. When writing IR functions, we can rely entirely on dynamic dims/dynamic objects, and no longer need to take IRObject into consideration.
  4. [REMOVED] Before PAS, add ^ to all dynamic dim annotations, so graph transformation/autodist will not partition them.
    In the following stages, dynamic shapes can be safely ignored.
  5. Add is_dynamic to DimAnno, so PAS can check the value to determine whether a dimension should be partitioned.
  6. Potentially we can add more metadata to DimAnno,
    for example how this dim relates to the dims in the inputs.
  7. Add more metadata to inputs, for example:
```
input.metadata = {
    dynamic_dims: {
        0: and(mul(8), max(1000)),  # dim 0 has constraints: multiple of 8, < 1000
        1: max(1000),               # dim 1 has constraints: < 1000
        2: and(min(10), max(1000)), # dim 2 has constraints: 10 < x < 1000
        3: dynamic,                 # dim 3 is generally dynamic
        4: static,                  # dim 4 is static
    }
}
```
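A minimal sketch of how such per-dim constraints could be represented and checked at runtime. All names here (`mul`, `min_`, `max_`, `and_`, `check_shape`) are hypothetical stand-ins for the combinators above, not an existing API; `min`/`max` are suffixed with `_` only to avoid shadowing Python builtins:

```python
# Hypothetical sketch: constraints as predicates over a single dim size.
from typing import Callable, Dict, Sequence

Constraint = Callable[[int], bool]

def mul(n: int) -> Constraint:      # dim must be a multiple of n
    return lambda x: x % n == 0

def min_(n: int) -> Constraint:     # dim must be > n
    return lambda x: x > n

def max_(n: int) -> Constraint:     # dim must be < n
    return lambda x: x < n

def and_(*cs: Constraint) -> Constraint:
    return lambda x: all(c(x) for c in cs)

def check_shape(shape: Sequence[int], dynamic_dims: Dict[int, Constraint]) -> bool:
    """Validate a runtime shape against the per-dim constraint metadata."""
    return all(c(shape[d]) for d, c in dynamic_dims.items())

metadata = {
    0: and_(mul(8), max_(1000)),   # multiple of 8 and < 1000
    1: max_(1000),                 # < 1000
    2: and_(min_(10), max_(1000)), # 10 < x < 1000
}
```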

Communication:

  1. The communication primitives move, rdscatter, rvscatter, rdgather, rvgather, and broadcast have shape parameters, which are determined at compile time.
  2. When any input of a communication primitive is dynamic, we need to ensure that the shape information is shared across ranks at runtime, so we add an extra parameter to these primitives asking them to sync their shape information.

For example, for move:

```python
# original
def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype, src: int, dst: int, async_op=False):
    ...

# new: sync_shape asks src/dst to exchange the runtime shape first
def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype, src: int, dst: int, async_op=False, sync_shape=False):
    if sync_shape:
        if rank == src:
            # sender publishes its runtime shape
            torch.distributed.send(
                torch.tensor(shape, dtype=torch.int64, device='cuda'),
                dst,
            )
        else:
            # receiver replaces the compile-time shape with the synced one
            shape_tensor = torch.empty(len(shape), dtype=torch.int64, device='cuda')
            torch.distributed.recv(shape_tensor, src)
            shape = tuple(shape_tensor.tolist())
    ...
```

Ops

  1. IRDimops: based on annotations.
    Add an extra annotation for dependencies.
    I would choose ! for the dependency annotation. (Better alternatives?)

    You only need to add the ! annotation to const dimensions.
    For example,

    • 128!a,b means this dimension (const 128) depends on dimensions (or IRObjects in kwargs) a and b
    • 128!* means this dimension depends on all * dimensions.
    • 128! means this dimension is always dynamic (it may depend on non-tensor inputs).
    • 128 (without !) means this dimension is not dynamic.
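A possible parser for this `!` grammar, just as a sketch of the four cases above; the function name and exact syntax are assumptions, not a settled design:

```python
import re
from typing import List, Optional, Tuple

def parse_dim_anno(anno: str) -> Tuple[str, Optional[List[str]]]:
    """Split a const-dim annotation into (size, dependencies).

    Dependencies are returned as:
      None       -> no '!' present, dimension is static
      []         -> bare '!', always dynamic (may depend on non-tensor inputs)
      ['*']      -> depends on all dimensions
      ['a', 'b'] -> depends on the named dims / kwargs
    """
    m = re.fullmatch(r'(\d+)(?:!([\w*]+(?:,[\w*]+)*)?)?', anno)
    if m is None:
        raise ValueError(f'bad annotation: {anno!r}')
    size, deps = m.groups()
    if '!' not in anno:
        return size, None
    return size, deps.split(',') if deps else []
```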

    1.1 Normal case (matmul): `m k+, k+ n -> m n`. The dynamic-ness of m/n can be inferred from the inputs; no further effort is needed.

    1.2 Const:
    Consts in the input annotation are not a problem, but we must annotate the dependencies of consts in the output.

     ```
         Repeat(tensor([3], 4, 2))
         b^ -> 4 (2 b^)
     ```
    
     4 and 2 come from the kwargs input (repeats: `List[int]` = [4, 2]), so if `repeats[0]` is not dynamic, the first dimension of the output is not dynamic.
     For example, if `repeats[0]` is dynamic but `repeats[1]` is constant, we can change the annotation to
     `b^ -> 4! (2 b^)`
     Or should we have a more complicated annotation like `b^ -> 4!repeats[0] (2!repeats[1] b^)`? Too complicated.
    
     ```
     Conv2D(input=tensor([2, 3, 4, 4]), weight=tensor([3, 3, 1, 1]), stride=1, padding=0)
     n iC+ 4 4, oC iC+ 1 1 -> n oC 4 4
     ```
     The first `4` in the output annotation depends on padding/dilation/stride/input.shape[-2]/weight.shape[-2];
     the second `4` in the output annotation depends on padding/dilation/stride/input.shape[-1]/weight.shape[-1].
     Two options:
     1. `n iC+ iH iW, oC iC+ kH kW -> n oC 4!padding,dilation,stride,iH,kH 4!padding,dilation,stride,iW,kW`
     2. Resolve the dynamicness yourself:
        for example, if input.shape[-2] is dynamic you can use
        `n iC+ 4 4, oC iC+ 1 1 -> n oC 4! 4`
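For reference, the output spatial size of a 2D convolution is a function of exactly those parameters (the standard formula that torch.nn.Conv2d also documents), which is why the dependency list above looks the way it does:

```python
def conv2d_out_dim(in_size: int, kernel: int, stride: int = 1,
                   padding: int = 0, dilation: int = 1) -> int:
    """Output spatial size of Conv2D along one dimension (floor division,
    matching the torch.nn.Conv2d shape formula)."""
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# The example above: 4x4 input, 1x1 kernel, stride 1, padding 0 -> 4x4 output
```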
    

    1.3 `?`
    Nonzero(tensor([3, 4, 5]), as_tuple=True): `a^ b^ c^ -> ?, ?, ?`
    Nonzero(tensor([3, 4, 5])): `a^ b^ c^ -> ?`
    FullSlice(IRTensor([3, 4]), [True, False, True]): `a^ b^, ? -> ?`
    Should we treat all dims in the outputs as dynamic?
    What if some output dims are dynamic and some are not? (For example, `a b, ? -> a ? ?`, which is not legal now.)

    1.4 Ops with different behavior:
    Max(input, dim, keepdim)
    If dim/keepdim is not constant, we can't predict what the output shape looks like at runtime.
    Should we raise an error, or let it go (and treat all dims of the outputs as dynamic)?
    1.5 Hidden dimensions in input:
    Should we treat the dynamicness of all hidden dimensions the same as the input's?
    1.6 Custom ops:
    Rely entirely on the annotation, so 128!a,b is necessary. For official ops, we can use !a,b or just resolve the dynamicness in the implementation (return 128 or 128!), so specifying the dependency is optional.

  2. IRPyFunc:
    2.1 tensor.shape/tensor.size: set is_dynamic based on dynamic dims.
    2.2 other Python functions: set is_dynamic if any of their inputs is dynamic or is a tensor.

  3. IRFwOperation: If none of the input tensors has dynamic dims, then no output will have dynamic dims. Otherwise, all dims of the outputs are dynamic. Too strict? (autograd function)

  4. Other ops (IRPad/IRConv2d/etc.): ignore. I think we should remove them at some point.
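The conservative fallback in item 3 can be sketched as follows; `Tensor` here is an illustrative stand-in for IRFullTensor, not the actual class:

```python
from dataclasses import dataclass, field
from typing import List, Set

@dataclass
class Tensor:  # illustrative stand-in for IRFullTensor
    shape: List[int]
    dynamic_dims: Set[int] = field(default_factory=set)

def propagate_dynamic(inputs: List[Tensor], outputs: List[Tensor]) -> None:
    """Fallback for opaque ops (e.g. autograd functions): if no input has
    dynamic dims, outputs stay fully static; otherwise every output dim is
    conservatively marked dynamic."""
    if any(t.dynamic_dims for t in inputs):
        for out in outputs:
            out.dynamic_dims = set(range(len(out.shape)))
```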
