Skip to content

Commit d6d9825

Browse files
author
ssjia
committed
Update on "[ET-VK] Implement missing Vulkan operators for Parakeet TDT model"
Add missing operators needed for Parakeet TDT model support: - New symint ops: sym_sub, sym_floordiv, sym_mul in SymIntOps.cpp; register operator.floordiv and operator.mul as ephemeral ops in op_registry.py - New tensor ops: bitwise_not (via unary_op shader with uint8 DTYPE), logical_and (alias for bitwise_and dispatch) - Improve _to_copy: expand dtype support to FP_INT_BOOL_T and use pick_io_storage_fn to restrict to CONTIGUOUS_BUFFER for non-fp conversions - Fix where resize: compute output shape via broadcast across all tensor inputs instead of always using the second input's shape - Add symint support to split: use extract_int_or_symint_list instead of get_int_list in resize_split_node and split_with_sizes_copy_default - Mark scalar_tensor as supporting resize Differential Revision: [D95970159](https://our.internmc.facebook.com/intern/diff/D95970159/) cc manuelcandales digantdesai cbilgin [ghstack-poisoned]
2 parents 91ab881 + e67e10a commit d6d9825

2 files changed

Lines changed: 8 additions & 2 deletions

File tree

backends/vulkan/runtime/graph/ops/glsl/repeat_texture.glsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ void main() {
4545

4646
VEC4_T out_texel = VEC4_T(0);
4747

48-
int limit = min(
48+
const int limit = min(
4949
4, out_meta.sizes[packed_dim] - out_tidx.data[packed_dim]);
50-
for (int comp = 0; comp < 4; comp++) {
50+
for (int comp = 0; comp < limit; comp++) {
5151
TensorIndex4D in_tidx = out_tidx;
5252
in_tidx.data = ivec4(
5353
out_tidx.data.x % in_meta.sizes.x,

backends/vulkan/runtime/graph/ops/impl/Expand.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ void resize_expand_node(
2828
const std::vector<int64_t> target_sizes =
2929
graph->extract_int_or_symint_list(size_ref);
3030

31+
VK_CHECK_COND(
32+
target_sizes.size() >= in_sizes.size(),
33+
"expand: target sizes must have at least as many dims as input");
34+
VK_CHECK_COND(
35+
!target_sizes.empty(), "expand: target sizes must not be empty");
36+
3137
const size_t dim_offset = target_sizes.size() - in_sizes.size();
3238
std::vector<int64_t> out_sizes(target_sizes.size());
3339
for (size_t i = 0; i < target_sizes.size(); i++) {

0 commit comments

Comments
 (0)