Skip to content

Commit 69b1251

Browse files
ssjiaSS-JIA
authored andcommitted
[ET-VK][ez] Fix IndexError in Vulkan partitioner DtypeSetList/TensorRepSetList
Pull Request resolved: #18048 The `__getitem__` methods of `DtypeSetList` and `TensorRepSetList` in `utils.py` could raise an `IndexError` when the index is greater than or equal to the length of the list. This can happen when partitioning ops whose number of inputs or outputs exceeds the number of entries in the dtype/tensor-rep specification list. Fix by returning an empty set in this case, matching the intent of the existing broadcasting logic. ghstack-source-id: 353546684 @exported-using-ghexport Differential Revision: [D95970163](https://our.internmc.facebook.com/intern/diff/D95970163/)
1 parent c5f9a5a commit 69b1251

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

backends/vulkan/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ def __getitem__(self, idx: int) -> DtypeSet:
9191
# Broadcasting: single set applies to all positions
9292
if idx > 0 and len(self.vals) == 1:
9393
return self.vals[0]
94+
if idx >= len(self.vals):
95+
return set()
9496
return self.vals[idx]
9597

9698
def is_empty(self) -> bool:
@@ -1227,8 +1229,9 @@ def __len__(self):
12271229
def __getitem__(self, idx: int) -> TensorRepSet:
12281230
if idx > 0 and len(self) == 1:
12291231
return self.vals[0]
1230-
else:
1231-
return self.vals[idx]
1232+
if idx >= len(self.vals):
1233+
return set()
1234+
return self.vals[idx]
12321235

12331236
def __setitem__(self, idx: int, val: TensorRepSet) -> None:
12341237
if idx > 0 and len(self.vals) == 1:

0 commit comments

Comments
 (0)