diff --git a/lib/mpsgraphs/MPSGraphs.jl b/lib/mpsgraphs/MPSGraphs.jl
index 2a06ae140..af168e45c 100644
--- a/lib/mpsgraphs/MPSGraphs.jl
+++ b/lib/mpsgraphs/MPSGraphs.jl
@@ -47,5 +47,6 @@
 include("operations.jl")
 include("random.jl")
 include("matmul.jl")
+include("conv.jl")
 
 end
diff --git a/lib/mpsgraphs/conv.jl b/lib/mpsgraphs/conv.jl
new file mode 100644
index 000000000..58e7485bd
--- /dev/null
+++ b/lib/mpsgraphs/conv.jl
@@ -0,0 +1,111 @@
+struct Conv2DGraphKey
+    size_x::Tuple{Vararg{Int}}
+    size_w::Tuple{Vararg{Int}}
+    size_y::Tuple{Vararg{Int}}
+    eltype_xw::DataType
+    eltype_y::DataType
+    stride::NTuple{2, Int}
+    dilation::NTuple{2, Int}
+    padding::NTuple{4, Int}
+    groups::Int
+end
+
+function Conv2DGraphKey(
+        x::MtlArray{Tx, 4}, w::MtlArray{Tx, 4}, y::MtlArray{Ty, 4},
+        stride::NTuple{2, Int}, dilation::NTuple{2, Int},
+        padding::NTuple{4, Int}, groups::Integer
+    ) where {Tx, Ty}
+    return Conv2DGraphKey(size(x), size(w), size(y), Tx, Ty, stride, dilation, padding, Int(groups))
+end
+
+struct CachedConv2DGraph
+    graph::MPSGraph
+    place_y::MPSGraphTensor
+    place_x::MPSGraphTensor
+    place_w::MPSGraphTensor
+    result::MPSGraphTensor
+end
+
+function CachedConv2DGraph(key::Conv2DGraphKey)
+    graph = MPSGraph()
+
+    placeX = placeholderTensor(graph, key.size_x, key.eltype_xw)
+    placeW = placeholderTensor(graph, key.size_w, key.eltype_xw)
+    placeY = placeholderTensor(graph, key.size_y, key.eltype_y)
+
+    castT = key.eltype_xw <: Integer ? key.eltype_y : key.eltype_xw
+    castX = castTensor(graph, placeX, castT, "castX")
+    castW = castTensor(graph, placeW, castT, "castW")
+
+    conv_desc = MPSGraphConvolution2DOpDescriptor(;
+        stride = key.stride,
+        dilation = key.dilation,
+        padding = key.padding,
+        groups = key.groups,
+        dataLayout = MPSGraphTensorNamedDataLayoutNCHW,
+        weightsLayout = MPSGraphTensorNamedDataLayoutOIHW,
+        paddingStyle = MPSGraphPaddingStyleExplicit,
+    )
+
+    conv = convolution2DWithSourceTensor(graph, castX, castW, conv_desc)
+    castY = castTensor(graph, conv, key.eltype_y, "castY")
+
+    return CachedConv2DGraph(graph, placeY, placeX, placeW, castY)
+end
+
+function _get_cached_graph!(graph_cache_lock, graph_cache, key::Conv2DGraphKey)
+    # NOTE: `Dict` is not thread-safe, so the lookup must happen under the
+    # same lock as the insertion; an unlocked fast-path read would race
+    # with concurrent `get!` mutations from other threads.
+    return @lock graph_cache_lock begin
+        get!(graph_cache, key) do
+            CachedConv2DGraph(key)
+        end
+    end
+end
+
+const _conv2d_graph_cache = Dict{Conv2DGraphKey, CachedConv2DGraph}()
+const _conv2d_graph_cache_lock = ReentrantLock()
+
+@inline _conv2d_padding(padding::Integer) = (Int(padding), Int(padding), Int(padding), Int(padding))
+@inline _conv2d_padding(padding::NTuple{2, <:Integer}) = (Int(padding[1]), Int(padding[1]), Int(padding[2]), Int(padding[2]))
+@inline _conv2d_padding(padding::NTuple{4, <:Integer}) = (Int(padding[1]), Int(padding[2]), Int(padding[3]), Int(padding[4]))
+
+@autoreleasepool function _conv2d!(
+        y::MtlArray{Ty, 4}, x::MtlArray{Tx, 4}, w::MtlArray{Tx, 4},
+        stride::NTuple{2, Int}, dilation::NTuple{2, Int},
+        padding::NTuple{4, Int}, groups::Integer
+    ) where {Ty, Tx}
+    key = Conv2DGraphKey(x, w, y, stride, dilation, padding, groups)
+    cached = _get_cached_graph!(_conv2d_graph_cache_lock, _conv2d_graph_cache, key)
+
+    feeds = Dict{MPSGraphTensor, MPSGraphTensorData}(
+        cached.place_x => MPSGraphTensorData(x),
+        cached.place_w => MPSGraphTensorData(w),
+        cached.place_y => MPSGraphTensorData(y),
+    )
+
+    resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}(
+        cached.result => feeds[cached.place_y],
+    )
+
+    cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
+    encode!(cmdbuf, cached.graph, NSDictionary(feeds), NSDictionary(resultdict), nil, default_exec_desc())
+    commit!(cmdbuf)
+    wait_completed(cmdbuf)
+
+    y
+end
+
+function graph_conv!(
+        y::MtlArray{Ty, 4}, x::MtlArray{Tx, 4}, w::MtlArray{Tx, 4};
+        stride::NTuple{2, <:Integer} = (1, 1),
+        dilation::NTuple{2, <:Integer} = (1, 1),
+        padding::Union{Integer, NTuple{2, <:Integer}, NTuple{4, <:Integer}} = (0, 0, 0, 0),
+        groups::Integer = 1
+    ) where {Ty, Tx}
+    return _conv2d!(
+        y, x, w, (Int(stride[1]), Int(stride[2])), (Int(dilation[1]), Int(dilation[2])),
+        _conv2d_padding(padding), groups
+    )
+end
diff --git a/lib/mpsgraphs/operations.jl b/lib/mpsgraphs/operations.jl
index 107c9ae31..8aa880c6f 100644
--- a/lib/mpsgraphs/operations.jl
+++ b/lib/mpsgraphs/operations.jl
@@ -1,68 +1,125 @@
-
-function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shape::MPSShape, name="broadcast")
-    obj = @objc [graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
-                    toShape:shape::id{MPSShape}
-                    name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shape::MPSShape, name = "broadcast")
+    obj = @objc [
+        graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
+        toShape:shape::id{MPSShape}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
-function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shapeTensor::MPSGraphTensor, name="broadcast")
-    obj = @objc [graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
-                    toShapeTensor:shapeTensor::id{MPSGraphTensor}
-                    name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shapeTensor::MPSGraphTensor, name = "broadcast")
+    obj = @objc [
+        graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
+        toShapeTensor:shapeTensor::id{MPSGraphTensor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function castTensor(graph::MPSGraph, tensor::MPSGraphTensor, toType, name = "cast")
-    obj = @objc [graph::id{MPSGraph} castTensor:tensor::id{MPSGraphTensor}
-                    toType:toType::MPSDataType
-                    name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} castTensor:tensor::id{MPSGraphTensor}
+        toType:toType::MPSDataType
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function constantWithScalar(graph::MPSGraph, scalar::Number, dataType)
-    obj = @objc [graph::id{MPSGraph} constantWithScalar:scalar::Float64
-                    dataType:dataType::MPSDataType]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} constantWithScalar:scalar::Float64
+        dataType:dataType::MPSDataType
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function matrixMultiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "matmul")
-    obj = @objc [graph::id{MPSGraph} matrixMultiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
-                    secondaryTensor:secondary::id{MPSGraphTensor}
-                    name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} matrixMultiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
+        secondaryTensor:secondary::id{MPSGraphTensor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function multiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "mul")
-    obj = @objc [graph::id{MPSGraph} multiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
-                    secondaryTensor:secondary::id{MPSGraphTensor}
-                    name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} multiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
+        secondaryTensor:secondary::id{MPSGraphTensor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function additionWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "add")
-    obj = @objc [graph::id{MPSGraph} additionWithPrimaryTensor:primary::id{MPSGraphTensor}
-                    secondaryTensor:secondary::id{MPSGraphTensor}
-                    name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} additionWithPrimaryTensor:primary::id{MPSGraphTensor}
+        secondaryTensor:secondary::id{MPSGraphTensor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function transposeTensor(graph::MPSGraph, tensor::MPSGraphTensor, dimension, withDimension, name = "transpose")
-    obj = @objc [graph::id{MPSGraph} transposeTensor:tensor::id{MPSGraphTensor}
-                    dimension:dimension::NSUInteger
-                    withDimension:withDimension::NSUInteger
-                    name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} transposeTensor:tensor::id{MPSGraphTensor}
+        dimension:dimension::NSUInteger
+        withDimension:withDimension::NSUInteger
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function shapeOfTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "shapeOfTensor")
-    obj = @objc [graph::id{MPSGraph} shapeOfTensor:tensor::id{MPSGraphTensor}
-                    name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} shapeOfTensor:tensor::id{MPSGraphTensor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function identityWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "identity")
-    obj = @objc [graph::id{MPSGraph} identityWithTensor:tensor::id{MPSGraphTensor}
-                    name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} identityWithTensor:tensor::id{MPSGraphTensor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
+end
+
+function MPSGraphConvolution2DOpDescriptor(;
+        stride::NTuple{2, <:Integer} = (1, 1),
+        dilation::NTuple{2, <:Integer} = (1, 1),
+        padding::NTuple{4, <:Integer} = (0, 0, 0, 0),
+        groups::Integer = 1,
+        dataLayout = MPSGraphTensorNamedDataLayoutNCHW,
+        weightsLayout = MPSGraphTensorNamedDataLayoutOIHW,
+        paddingStyle = MPSGraphPaddingStyleExplicit
+    )
+    desc = MPSGraphConvolution2DOpDescriptor(@objc [MPSGraphConvolution2DOpDescriptor new]::id{MPSGraphConvolution2DOpDescriptor})
+    desc.strideInX = UInt64(stride[1])
+    desc.strideInY = UInt64(stride[2])
+    desc.dilationRateInX = UInt64(dilation[1])
+    desc.dilationRateInY = UInt64(dilation[2])
+    desc.paddingLeft = UInt64(padding[1])
+    desc.paddingRight = UInt64(padding[2])
+    desc.paddingTop = UInt64(padding[3])
+    desc.paddingBottom = UInt64(padding[4])
+    desc.paddingStyle = paddingStyle
+    desc.dataLayout = dataLayout
+    desc.weightsLayout = weightsLayout
+    desc.groups = UInt64(groups)
+    return desc
+end
+
+function convolution2DWithSourceTensor(
+        graph::MPSGraph, source::MPSGraphTensor, weights::MPSGraphTensor,
+        descriptor::MPSGraphConvolution2DOpDescriptor, name = "conv2d"
+    )
+    obj = @objc [
+        graph::id{MPSGraph} convolution2DWithSourceTensor:source::id{MPSGraphTensor}
+        weightsTensor:weights::id{MPSGraphTensor}
+        descriptor:descriptor::id{MPSGraphConvolution2DOpDescriptor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 """
diff --git a/test/mpsgraphs/conv.jl b/test/mpsgraphs/conv.jl
new file mode 100644
index 000000000..4c461d181
--- /dev/null
+++ b/test/mpsgraphs/conv.jl
@@ -0,0 +1,69 @@
+if MPS.is_supported(device())
+
+    function conv2d_reference(x, w; stride = (1, 1), dilation = (1, 1), padding = (0, 0, 0, 0), groups = 1)
+        W, H, Cin, N = size(x)
+        Kw, Kh, CinPerGroup, Cout = size(w)
+
+        sx, sy = stride
+        dx, dy = dilation
+        pl, pr, pt, pb = padding
+
+        @assert CinPerGroup * groups == Cin
+        @assert Cout % groups == 0
+
+        Wout = fld(W + pl + pr - dx * (Kw - 1) - 1, sx) + 1
+        Hout = fld(H + pt + pb - dy * (Kh - 1) - 1, sy) + 1
+
+        Tout = promote_type(eltype(x), eltype(w))
+        y = zeros(Tout, Wout, Hout, Cout, N)
+
+        CoutPerGroup = div(Cout, groups)
+        for n in 1:N, co in 1:Cout, oh in 1:Hout, ow in 1:Wout
+            g = div(co - 1, CoutPerGroup) + 1
+            acc = zero(Tout)
+
+            for ci_local in 1:CinPerGroup, kh in 1:Kh, kw in 1:Kw
+                ci = (g - 1) * CinPerGroup + ci_local
+                iw = (ow - 1) * sx - pl + (kw - 1) * dx + 1
+                ih = (oh - 1) * sy - pt + (kh - 1) * dy + 1
+                if 1 <= iw <= W && 1 <= ih <= H
+                    acc += x[iw, ih, ci, n] * w[kw, kh, ci_local, co]
+                end
+            end
+
+            y[ow, oh, co, n] = acc
+        end
+
+        return y
+    end
+
+    @testset "mpsgraph convolution 2D" begin
+        x = rand(Float32, 9, 8, 4, 2)
+        w = rand(Float32, 3, 2, 4, 6)
+        stride = (2, 1)
+        dilation = (1, 2)
+        padding = (1, 0, 2, 1)
+
+        y_ref = conv2d_reference(x, w; stride, dilation, padding, groups = 1)
+        y = Metal.zeros(Float32, size(y_ref))
+
+        MPSGraphs.graph_conv!(y, MtlArray(x), MtlArray(w); stride, dilation, padding, groups = 1)
+        @test Array(y) ≈ y_ref rtol = 1.0f-5 atol = 1.0f-5
+    end
+
+    @testset "mpsgraph grouped convolution 2D" begin
+        x = rand(Float32, 10, 7, 4, 2)
+        w = rand(Float32, 3, 3, 2, 6)
+        stride = (1, 2)
+        dilation = (1, 1)
+        padding = (1, 1, 1, 1)
+        groups = 2
+
+        y_ref = conv2d_reference(x, w; stride, dilation, padding, groups)
+        y = Metal.zeros(Float32, size(y_ref))
+
+        MPSGraphs.graph_conv!(y, MtlArray(x), MtlArray(w); stride, dilation, padding, groups)
+        @test Array(y) ≈ y_ref rtol = 1.0f-5 atol = 1.0f-5
+    end
+
+end # MPS.is_supported(device())