Allow typecast fusion #602

Open: vtsilytskyiTT wants to merge 4 commits into main from vlad/typecast-fusion-fixed-2
@vtsilytskyiTT
Contributor

Problem description

ttl.typecast could not participate in fused ttl.compute regions, so dtype-changing expressions either broke fusion or failed during TTL-to-compute lowering. This also left mixed-dtype fusion behavior unclear, especially when f32 inputs required unpack_to_dest_fp32 handling.

What's changed

This change enables ttl.typecast fusion by deriving tile result types from each source op's tensor result instead of relying on a single default tile type. It updates compute lowering, Python AST/API type handling, and kernel config assignment so fused compute regions can preserve dtype-changing intermediates correctly.
The compute kernel config logic now tracks unpack_to_dest_fp32 per CB and detects conflicts where the same f32 CB is consumed by both FPU and SFPU strategies. Tests were added and updated to cover typecast fusion, mixed-dtype rejection, and the positive, negative, and conflict cases for unpack_to_dest_fp32.
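
The per-CB conflict check described above can be sketched roughly as follows. This is a minimal Python model, not the pass's actual code: the `Use` record, the `"fpu"`/`"sfpu"` strategy labels, and `find_unpack_conflicts` are hypothetical names introduced for illustration, and it assumes a conflict simply means that the same f32 CB index is consumed under both strategies.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Use:
    """One consumer of a circular buffer (hypothetical record)."""
    cb_index: int
    dtype: str     # e.g. "f32" or "bf16"
    strategy: str  # "fpu" or "sfpu" (illustrative labels)


def find_unpack_conflicts(uses):
    """Return the sorted indices of f32 CBs consumed by both strategies."""
    strategies_by_cb = defaultdict(set)
    for use in uses:
        # Only f32 CBs are relevant in this model.
        if use.dtype == "f32":
            strategies_by_cb[use.cb_index].add(use.strategy)
    return sorted(
        cb for cb, strategies in strategies_by_cb.items()
        if {"fpu", "sfpu"} <= strategies
    )


uses = [
    Use(0, "f32", "sfpu"),
    Use(0, "f32", "fpu"),    # same f32 CB under both strategies: conflict
    Use(1, "f32", "sfpu"),   # f32, but only one strategy: no conflict
    Use(2, "bf16", "fpu"),
    Use(2, "bf16", "sfpu"),  # both strategies, but not f32: ignored
]
print(find_unpack_conflicts(uses))
```

In this toy input only CB 0 is reported, since it is the only f32 CB seen under both strategies.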

Ticket

#264

Checklist

  • New/Existing tests provide coverage for changes

@vtsilytskyiTT vtsilytskyiTT requested a review from a team as a code owner May 19, 2026 17:34
@vtsilytskyiTT vtsilytskyiTT force-pushed the vlad/typecast-fusion-fixed-2 branch from 2e47eeb to 4fdbe5e Compare May 19, 2026 18:40
Comment thread python/ttl/ttl_api.py Outdated
@vtsilytskyiTT vtsilytskyiTT force-pushed the vlad/typecast-fusion-fixed-2 branch from 5eb5955 to 582baf0 Compare May 21, 2026 11:05
@vtsilytskyiTT
Contributor Author

Resolve conflicts.

Contributor

@zoecarver zoecarver left a comment

Very cool! LGTM. Thank you, Vlad!

@@ -557,7 +589,7 @@ static LogicalResult buildFusedCompute(Operation *sinkOp,
Value inputTile = tensorToTile[bcastOp.getInput()];
Value outputTile = body->getArguments().back(); // output block arg
auto bcastTileOp = createTileOpWithPlaceholderDstIndex<TileBcastOp>(
Contributor

This still uses outputTileType, which is the final fused sink type, not necessarily this particular op's result type. That breaks cases like typecast(bcast(...)), where the bcast result should keep its own dtype and the later ttl.tile_typecast should perform the conversion.

The same applies to the matmul special cases below. These special-case emitters should derive the result tile type from the current source op, like emitTileOpFor() does.
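
To illustrate the concern, here is a small Python model of the two ways of choosing a tile dtype. It is only a sketch under stated assumptions: `Op`, `tile_dtypes`, and the f32-to-bf16 cast are hypothetical stand-ins, not the real `buildFusedCompute` or `emitTileOpFor()` code.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    """Toy stand-in for a tensor-level op with a known result dtype."""
    name: str
    result_dtype: str
    operands: list = field(default_factory=list)


def tile_dtypes(sink, use_sink_type_everywhere):
    """Map each op name to the dtype of the tile op emitted for it."""
    emitted = {}

    def visit(op):
        for operand in op.operands:
            visit(operand)
        if use_sink_type_everywhere:
            # Behavior the comment objects to: every tile op gets the sink type.
            emitted[op.name] = sink.result_dtype
        else:
            # Suggested behavior: derive the type from the current source op.
            emitted[op.name] = op.result_dtype

    visit(sink)
    return emitted


# typecast(bcast(...)) with an assumed f32 -> bf16 conversion.
bcast = Op("bcast", "f32")
cast = Op("typecast", "bf16", [bcast])

print(tile_dtypes(cast, use_sink_type_everywhere=True))
print(tile_dtypes(cast, use_sink_type_everywhere=False))
```

With the sink type applied everywhere, the modeled bcast tile is already bf16, so the later typecast has nothing left to convert. With per-op derivation, the bcast keeps f32 and the typecast performs the conversion.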

}
}
for (int64_t cb : conflicts) {
fpuCBs[cb]->emitWarning()
Contributor

Should this be a pass error instead of an easy-to-miss warning, since it could lead to losing f32 precision?

@brnorris03
Contributor

Can you maybe add a couple of fused test cases, like typecast(matmul) or typecast(matmul + a)?

4 participants