diff --git a/Sources/SwiftSlang/SLTypeLayout.h b/Sources/SwiftSlang/SLTypeLayout.h index c9b4dba..42771e6 100644 --- a/Sources/SwiftSlang/SLTypeLayout.h +++ b/Sources/SwiftSlang/SLTypeLayout.h @@ -38,6 +38,10 @@ NS_ASSUME_NONNULL_BEGIN /// Corresponds to TypeLayoutReflection::getFieldByIndex() - (nullable SLVariableLayoutReflection *)getFieldByIndex:(unsigned int)index; +/// Corresponds to TypeLayoutReflection::getResourceResultType() +/// Returns the element type for resource types (e.g., float4 in RWTexture2D). +- (nullable SLTypeReflection *)getResourceResultType; + @end NS_ASSUME_NONNULL_END diff --git a/Sources/SwiftSlang/SLTypeLayout.mm b/Sources/SwiftSlang/SLTypeLayout.mm index fde4d69..e3db413 100644 --- a/Sources/SwiftSlang/SLTypeLayout.mm +++ b/Sources/SwiftSlang/SLTypeLayout.mm @@ -79,4 +79,11 @@ - (nullable SLVariableLayoutReflection *)getFieldByIndex:(unsigned int)index { return [[SLVariableLayoutReflection alloc] initWithVariableLayoutReflectionPtr:field]; } +- (nullable SLTypeReflection *)getResourceResultType { + if (!_typeLayout) return nil; + slang::TypeReflection* resultType = _typeLayout->getResourceResultType(); + if (!resultType) return nil; + return [[SLTypeReflection alloc] initWithTypeReflectionPtr:resultType]; +} + @end diff --git a/Tests/SwiftSlangTests/SwiftSlangTests.swift b/Tests/SwiftSlangTests/SwiftSlangTests.swift index 5eccb55..afab6ba 100644 --- a/Tests/SwiftSlangTests/SwiftSlangTests.swift +++ b/Tests/SwiftSlangTests/SwiftSlangTests.swift @@ -321,6 +321,56 @@ final class SwiftSlangTests: XCTestCase { // MARK: - Resource Types + func testResourceResultType() throws { + let globalSession = try SLGlobalSession.create() + let profile = globalSession.findProfile("sm_5_0") + let targetDesc = SLTargetDesc(format: .metal, profile: profile) + let sessionDesc = SLSessionDesc() + sessionDesc.targets = [targetDesc] + let session = try globalSession.createSession(with: sessionDesc) + + let source = """ + RWTexture2D rwTex4; + RWTexture2D rwTex1; + RWTexture2D rwTexH4; + [shader("compute")] + [numthreads(1,1,1)] + void csMain(uint3 tid : SV_DispatchThreadID) { + rwTex4[tid.xy] = float4(0,0,0,0); + rwTex1[tid.xy] = 0; + rwTexH4[tid.xy] = half4(0,0,0,0); + } + """ + let module = try session.loadModule(fromSourceString: "Test", path: "", source: source) + let entryPoint = try module.entryPoint(at: 0) + let composite = try session.createCompositeComponentType(with: module, entryPoints: [entryPoint]) + let linked = try composite.link() + let params = try linked.getShaderParameters(0) + + // RWTexture2D + let tex4Param = try XCTUnwrap(params.first { $0.name == "rwTex4" }) + let tex4TypeLayout = try XCTUnwrap(tex4Param.typeLayout) + let tex4ResultType = try XCTUnwrap(tex4TypeLayout.getResourceResultType()) + XCTAssertEqual(tex4ResultType.getKind(), .vector) + XCTAssertEqual(tex4ResultType.getScalarType(), .float32) + XCTAssertEqual(tex4ResultType.getElementCount(), 4) + + // RWTexture2D + let tex1Param = try XCTUnwrap(params.first { $0.name == "rwTex1" }) + let tex1TypeLayout = try XCTUnwrap(tex1Param.typeLayout) + let tex1ResultType = try XCTUnwrap(tex1TypeLayout.getResourceResultType()) + XCTAssertEqual(tex1ResultType.getKind(), .scalar) + XCTAssertEqual(tex1ResultType.getScalarType(), .float32) + + // RWTexture2D + let texH4Param = try XCTUnwrap(params.first { $0.name == "rwTexH4" }) + let texH4TypeLayout = try XCTUnwrap(texH4Param.typeLayout) + let texH4ResultType = try XCTUnwrap(texH4TypeLayout.getResourceResultType()) + XCTAssertEqual(texH4ResultType.getKind(), .vector) + XCTAssertEqual(texH4ResultType.getScalarType(), .float16) + XCTAssertEqual(texH4ResultType.getElementCount(), 4) + } + func testTextureAndSamplerParameters() throws { let globalSession = try SLGlobalSession.create() let profile = globalSession.findProfile("sm_5_0")