# Dynamic shape tracking
## IR
- Add `is_dynamic` to `IRObject`.
- Add `dynamic_dims` (of type `Set[int]`) to `IRFullTensor`.
- When writing IR functions, we can rely entirely on dynamic dims / dynamic objects, and no longer need to take `IRObject` into consideration.
- [REMOVED]
- Before PAS, add `^` to all dynamic dim annotations, so graph transformation/autodist will not partition them. In the following stages, dynamic shapes can then be safely ignored.
- Add `is_dynamic` to `DimAnno`, so PAS can check the value to determine whether a dimension should be partitioned.
- Potentially we can add more metadata to `DimAnno`, for example how this dim relates to dims in the inputs.
- Add more metadata to inputs, for example:
```python
input.metadata = {
    dynamic_dims: {
        0: and(mul(8), max(1000)),   # dim 0 has constraints: multiple of 8, < 1000
        1: max(1000),                # dim 1 has constraints: < 1000
        2: and(min(10), max(1000)),  # dim 2 has constraints: 10 < x < 1000
        3: dynamic,                  # dim 3 is generally dynamic
        4: static,                   # dim 4 is static
    }
}
```
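The per-dim constraints above could be modeled as simple predicate combinators. A minimal sketch, assuming such a design; the names `mul`, `max_`, `min_`, and `and_` are hypothetical stand-ins for the DSL in the snippet above (trailing underscores avoid shadowing Python built-ins), not an existing API:

```python
# Hypothetical combinators mirroring the dynamic_dims constraint DSL above.
def mul(n):
    """Dim value must be a multiple of n."""
    return lambda v: v % n == 0

def max_(n):
    """Dim value must be < n."""
    return lambda v: v < n

def min_(n):
    """Dim value must be > n."""
    return lambda v: v > n

def and_(*constraints):
    """All constraints must hold."""
    return lambda v: all(c(v) for c in constraints)

# Example: the constraint table from the metadata sketch above.
dynamic_dims = {
    0: and_(mul(8), max_(1000)),   # multiple of 8, < 1000
    1: max_(1000),                 # < 1000
    2: and_(min_(10), max_(1000)), # 10 < x < 1000
}
```

At runtime, a rank could validate an incoming shape by calling each dim's predicate on the observed size.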
## Communication
- Communication primitives `move`, `rdscatter`, `rvscatter`, `rdgather`, `rvgather`, and `broadcast` have `shape` parameters, which are determined at compile time.
- When any input of a communication primitive is dynamic, we need to ensure that the shape information is shared across ranks at runtime, so we add an extra parameter asking these primitives to sync their shape information.
For example, for `move`:

```python
# original
def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype,
         src: int, dst: int, async_op=False):
    ...

# new
def move(tensor: Optional[torch.Tensor], shape: Tuple[int], dtype: torch.dtype,
         src: int, dst: int, async_op=False, sync_shape=False):
    if sync_shape:
        if rank == src:
            # src ships its runtime shape before the tensor itself
            torch.distributed.send(
                torch.tensor(shape, dtype=torch.int64, device='cuda'),
                dst,
            )
        else:
            # dst receives the runtime shape and overrides the compile-time one
            shape_tensor = torch.empty(len(shape), dtype=torch.int64, device='cuda')
            torch.distributed.recv(shape_tensor, src)
            shape = tuple(shape_tensor.tolist())
    ...
```
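The handshake's control flow can be exercised without a process group by abstracting the send/recv calls into injected callables. A hypothetical single-process sketch (the function `sync_shape_handshake` and the `channel` buffer are illustrations, not part of the proposal):

```python
from typing import Callable, List, Tuple

def sync_shape_handshake(shape: Tuple[int, ...], is_src: bool,
                         send: Callable, recv: Callable) -> Tuple[int, ...]:
    """Mirror of the sync_shape branch above: src ships its runtime shape,
    dst discards its compile-time shape and uses the received one."""
    if is_src:
        send(list(shape))      # stands in for torch.distributed.send
        return tuple(shape)
    return tuple(recv())       # stands in for torch.distributed.recv

# Emulate src -> dst with a shared in-process buffer.
channel: List[list] = []
src_shape = sync_shape_handshake((8, 128), True, channel.append, None)
dst_shape = sync_shape_handshake((1, 1), False, None, channel.pop)
```

Here `dst_shape` ends up equal to `src_shape` regardless of the placeholder shape the dst rank was compiled with.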
## Ops
- `IRDimops`: based on annotation.
  Add an extra annotation for dependencies. I would choose `!` for the dependency annotation (better alternatives?). You only need to add the `!` annotation to const dimensions. For example, `128!a,b` means this dimension (const `128`) depends on dimensions (or `IRObject`s in kwargs) `a` and `b`; `128!*` means this dimension depends on all dimensions; `128!` means this dimension is always dynamic (it may depend on non-tensor inputs); `128` (without `!`) means this dimension is not dynamic.
  1.1 Normal case (matmul): `m k+, k+ n -> m n`. The dynamic-ness of `m`/`n` can be inferred from the inputs. No further effort needed.

  1.2 Const: consts in the input annotation are not a problem, but we must annotate the dependencies of consts in the output.

  ```
  Repeat(tensor([3], 4, 2))
  b^ -> 4 (2 b^)
  ```

  `4` and `2` come from the kwargs input (`repeats: List[int] = [4, 2]`), so if `repeats[0]` is not dynamic, the first dimension of the output is not dynamic. If `repeats[0]` is dynamic but `repeats[1]` is constant, we can change the annotation to `b^ -> 4! (2 b^)`. Or should we have a more complicated annotation like `b^ -> 4!repeats[0] (2!repeats[1] b^)`? Too complicated.

  ```
  Conv2D(input=tensor([2, 3, 4, 4]), weight=tensor([3, 3, 1, 1]), stride=1, padding=0)
  n iC+ 4 4, oC iC+ 1 1 -> n oC 4 4
  ```

  The first `4` in the output annotation depends on padding/dilation/stride/`input.shape[-2]`/`weight.shape[-2]`; the second `4` depends on padding/dilation/stride/`input.shape[-1]`/`weight.shape[-1]`. Two options:
  1. `n iC+ iH iW, oC iC+ kH kW -> n oC 4!padding,dilation,stride,iH,kH 4!padding,dilation,stride,iW,kW`
  2. Resolve the dynamic-ness yourself: for example, if `input.shape[-2]` is dynamic, you can use `n iC+ 4 4, oC iC+ 1 1 -> n oC 4! 4`.

  1.3 Unknown output shapes (`?`):
  ```
  Nonzero(tensor([3,4,5]), as_tuple=True)
  a^ b^ c^ -> ?, ?, ?

  Nonzero(tensor([3,4,5]))
  a^ b^ c^ -> ?

  FullSlice(IRTensor([3, 4]), [True, False, True])
  a^ b^, ? -> ?
  ```
  Should we treat all dims in the outputs as dynamic? What if some output dims are dynamic and some are not? (For example, `a b, ? -> a ? ?` is not legal now.)

  1.4 Op with different behavior:
  `Max(input, dim, keepdim)`: if `dim`/`keepdim` is not constant, we can't predict what the output dims look like at runtime. Should we raise an error or let it go (and treat all dims of the outputs as dynamic)?
  1.5 Hidden dimension in input: should we treat the dynamic-ness of all hidden dimensions the same as the input's?
  1.6 Custom op: rely entirely on the annotation, so `128!a,b` is necessary. For official ops, we can use `!a,b` or just resolve the dynamic-ness in the implementation (return `128` or `128!`), so specifying the dependency is optional.
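The proposed `!` syntax for const dimensions admits a straightforward grammar. A minimal parsing sketch, assuming the four forms described above (`128`, `128!a,b`, `128!*`, `128!`); the function name and return convention are hypothetical:

```python
from typing import List, Optional, Tuple, Union

def parse_const_dim(anno: str) -> Tuple[int, Union[None, str, List[str]]]:
    """Parse one const-dimension annotation into (size, deps), where deps is
    None for a static dim, 'all' for `!*`, [] for an always-dynamic `!`,
    or the list of referenced dims/kwargs for `!a,b`."""
    if '!' not in anno:
        return int(anno), None            # plain const: not dynamic
    size, _, deps = anno.partition('!')
    if deps == '*':
        return int(size), 'all'           # depends on all dimensions
    if deps == '':
        return int(size), []              # always dynamic
    return int(size), deps.split(',')     # depends on the listed names
```

A registry could then mark the output dim dynamic whenever any of its parsed dependencies is dynamic.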
- `IRPyFunc`:
  2.1 `tensor.shape`/`tensor.size`: set `is_dynamic` based on dynamic dims.
  2.2 Other python functions: set `is_dynamic` if any of their inputs is dynamic or a tensor.
- `IRFwOperation`: if none of the input tensors has dynamic dims, then none of the outputs will have dynamic dims. Otherwise, all dims of the outputs are dynamic. Too strict? (autograd function)
- Other ops (`IRPad`/`IRConv2d`/etc.): ignore. I think we should remove them at some point.
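The conservative `IRFwOperation` rule above can be sketched with plain sets of dynamic dim indices standing in for IR tensors; the function name and data representation are illustrative, not the real implementation:

```python
from typing import List, Set

def propagate_dynamic_dims(input_dyn: List[Set[int]],
                           output_ranks: List[int]) -> List[Set[int]]:
    """input_dyn: per-input sets of dynamic dim indices.
    output_ranks: rank of each output tensor.
    Returns per-output sets of dynamic dim indices."""
    if not any(input_dyn):
        # no input has dynamic dims -> every output is fully static
        return [set() for _ in output_ranks]
    # otherwise, conservatively mark every output dim dynamic
    return [set(range(r)) for r in output_ranks]
```

This is exactly the "too strict?" trade-off: a single dynamic input dim poisons every output dim, which is safe but forfeits partitioning opportunities on dims that are provably static.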