Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/mpsgraphs/MPSGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,6 @@ include("operations.jl")
include("random.jl")

include("matmul.jl")
include("conv.jl")

end
111 changes: 111 additions & 0 deletions lib/mpsgraphs/conv.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
struct Conv2DGraphKey
size_x::Tuple{Vararg{Int}}
size_w::Tuple{Vararg{Int}}
size_y::Tuple{Vararg{Int}}
eltype_xw::DataType
eltype_y::DataType
stride::NTuple{2, Int}
dilation::NTuple{2, Int}
padding::NTuple{4, Int}
groups::Int
end

function Conv2DGraphKey(
x::MtlArray{Tx, 4}, w::MtlArray{Tx, 4}, y::MtlArray{Ty, 4},
stride::NTuple{2, Int}, dilation::NTuple{2, Int},
padding::NTuple{4, Int}, groups::Integer
) where {Tx, Ty}
return Conv2DGraphKey(size(x), size(w), size(y), Tx, Ty, stride, dilation, padding, Int(groups))
end

struct CachedConv2DGraph
graph::MPSGraph
place_y::MPSGraphTensor
place_x::MPSGraphTensor
place_w::MPSGraphTensor
result::MPSGraphTensor
end

function CachedConv2DGraph(key::Conv2DGraphKey)
graph = MPSGraph()

placeX = placeholderTensor(graph, key.size_x, key.eltype_xw)
placeW = placeholderTensor(graph, key.size_w, key.eltype_xw)
placeY = placeholderTensor(graph, key.size_y, key.eltype_y)

castT = key.eltype_xw <: Integer ? key.eltype_y : key.eltype_xw
castX = castTensor(graph, placeX, castT, "castX")
castW = castTensor(graph, placeW, castT, "castW")

conv_desc = MPSGraphConvolution2DOpDescriptor(;
stride = key.stride,
dilation = key.dilation,
padding = key.padding,
groups = key.groups,
dataLayout = MPSGraphTensorNamedDataLayoutNCHW,
weightsLayout = MPSGraphTensorNamedDataLayoutOIHW,
paddingStyle = MPSGraphPaddingStyleExplicit,
)

conv = convolution2DWithSourceTensor(graph, castX, castW, conv_desc)
castY = castTensor(graph, conv, key.eltype_y, "castY")

return CachedConv2DGraph(graph, placeY, placeX, placeW, castY)
end

function _get_cached_graph!(graph_cache_lock, graph_cache, key::Conv2DGraphKey)
cached = get(graph_cache, key, nothing)
if cached !== nothing
return cached
end

return @lock graph_cache_lock get!(graph_cache, key) do
CachedConv2DGraph(key)
end
end

const _conv2d_graph_cache = Dict{Conv2DGraphKey, CachedConv2DGraph}()
const _conv2d_graph_cache_lock = ReentrantLock()

@inline _conv2d_padding(padding::Integer) = (Int(padding), Int(padding), Int(padding), Int(padding))
@inline _conv2d_padding(padding::NTuple{2, <:Integer}) = (Int(padding[1]), Int(padding[1]), Int(padding[2]), Int(padding[2]))
@inline _conv2d_padding(padding::NTuple{4, <:Integer}) = (Int(padding[1]), Int(padding[2]), Int(padding[3]), Int(padding[4]))

@autoreleasepool function _conv2d!(
y::MtlArray{Ty, 4}, x::MtlArray{Tx, 4}, w::MtlArray{Tx, 4},
stride::NTuple{2, Int}, dilation::NTuple{2, Int},
padding::NTuple{4, Int}, groups::Integer
) where {Ty, Tx}
key = Conv2DGraphKey(x, w, y, stride, dilation, padding, groups)
cached = _get_cached_graph!(_conv2d_graph_cache_lock, _conv2d_graph_cache, key)

feeds = Dict{MPSGraphTensor, MPSGraphTensorData}(
cached.place_x => MPSGraphTensorData(x),
cached.place_w => MPSGraphTensorData(w),
cached.place_y => MPSGraphTensorData(y),
)

resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}(
cached.result => feeds[cached.place_y],
)

cmdbuf = MPSCommandBuffer(Metal.global_queue(device()))
encode!(cmdbuf, cached.graph, NSDictionary(feeds), NSDictionary(resultdict), nil, default_exec_desc())
commit!(cmdbuf)
wait_completed(cmdbuf)

y
end

function graph_conv!(
y::MtlArray{Ty, 4}, x::MtlArray{Tx, 4}, w::MtlArray{Tx, 4};
stride::NTuple{2, <:Integer} = (1, 1),
dilation::NTuple{2, <:Integer} = (1, 1),
padding::Union{Integer, NTuple{2, <:Integer}, NTuple{4, <:Integer}} = (0, 0, 0, 0),
groups::Integer = 1
) where {Ty, Tx}
return _conv2d!(
y, x, w, (Int(stride[1]), Int(stride[2])), (Int(dilation[1]), Int(dilation[2])),
_conv2d_padding(padding), groups
)
end
139 changes: 98 additions & 41 deletions lib/mpsgraphs/operations.jl
Original file line number Diff line number Diff line change
@@ -1,68 +1,125 @@

function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shape::MPSShape, name="broadcast")
obj = @objc [graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
toShape:shape::id{MPSShape}
name:name::id{NSString}]::id{MPSGraphTensor}
MPSGraphTensor(obj)
function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shape::MPSShape, name = "broadcast")
obj = @objc [
graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
toShape:shape::id{MPSShape}
name:name::id{NSString}
]::id{MPSGraphTensor}
return MPSGraphTensor(obj)
end
function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shapeTensor::MPSGraphTensor, name="broadcast")
obj = @objc [graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
toShapeTensor:shapeTensor::id{MPSGraphTensor}
name:name::id{NSString}]::id{MPSGraphTensor}
MPSGraphTensor(obj)
function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shapeTensor::MPSGraphTensor, name = "broadcast")
obj = @objc [
graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
toShapeTensor:shapeTensor::id{MPSGraphTensor}
name:name::id{NSString}
]::id{MPSGraphTensor}
return MPSGraphTensor(obj)
end

function castTensor(graph::MPSGraph, tensor::MPSGraphTensor, toType, name = "cast")
obj = @objc [graph::id{MPSGraph} castTensor:tensor::id{MPSGraphTensor}
toType:toType::MPSDataType
name:name::id{NSString}]::id{MPSGraphTensor}
MPSGraphTensor(obj)
obj = @objc [
graph::id{MPSGraph} castTensor:tensor::id{MPSGraphTensor}
toType:toType::MPSDataType
name:name::id{NSString}
]::id{MPSGraphTensor}
return MPSGraphTensor(obj)
end

function constantWithScalar(graph::MPSGraph, scalar::Number, dataType)
obj = @objc [graph::id{MPSGraph} constantWithScalar:scalar::Float64
dataType:dataType::MPSDataType]::id{MPSGraphTensor}
MPSGraphTensor(obj)
obj = @objc [
graph::id{MPSGraph} constantWithScalar:scalar::Float64
dataType:dataType::MPSDataType
]::id{MPSGraphTensor}
return MPSGraphTensor(obj)
end

function matrixMultiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "matmul")
obj = @objc [graph::id{MPSGraph} matrixMultiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
secondaryTensor:secondary::id{MPSGraphTensor}
name:name::id{NSString}]::id{MPSGraphTensor}
MPSGraphTensor(obj)
obj = @objc [
graph::id{MPSGraph} matrixMultiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
secondaryTensor:secondary::id{MPSGraphTensor}
name:name::id{NSString}
]::id{MPSGraphTensor}
return MPSGraphTensor(obj)
end

function multiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "mul")
obj = @objc [graph::id{MPSGraph} multiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
secondaryTensor:secondary::id{MPSGraphTensor}
name:name::id{NSString}]::id{MPSGraphTensor}
MPSGraphTensor(obj)
obj = @objc [
graph::id{MPSGraph} multiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
secondaryTensor:secondary::id{MPSGraphTensor}
name:name::id{NSString}
]::id{MPSGraphTensor}
return MPSGraphTensor(obj)
end
function additionWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "add")
obj = @objc [graph::id{MPSGraph} additionWithPrimaryTensor:primary::id{MPSGraphTensor}
secondaryTensor:secondary::id{MPSGraphTensor}
name:name::id{NSString}]::id{MPSGraphTensor}
MPSGraphTensor(obj)
obj = @objc [
graph::id{MPSGraph} additionWithPrimaryTensor:primary::id{MPSGraphTensor}
secondaryTensor:secondary::id{MPSGraphTensor}
name:name::id{NSString}
]::id{MPSGraphTensor}
return MPSGraphTensor(obj)
end

function transposeTensor(graph::MPSGraph, tensor::MPSGraphTensor, dimension, withDimension, name = "transpose")
obj = @objc [graph::id{MPSGraph} transposeTensor:tensor::id{MPSGraphTensor}
dimension:dimension::NSUInteger
withDimension:withDimension::NSUInteger
name:name::id{NSString}]::id{MPSGraphTensor}
MPSGraphTensor(obj)
obj = @objc [
graph::id{MPSGraph} transposeTensor:tensor::id{MPSGraphTensor}
dimension:dimension::NSUInteger
withDimension:withDimension::NSUInteger
name:name::id{NSString}
]::id{MPSGraphTensor}
return MPSGraphTensor(obj)
end

function shapeOfTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "shapeOfTensor")
obj = @objc [graph::id{MPSGraph} shapeOfTensor:tensor::id{MPSGraphTensor}
name:name::id{NSString}]::id{MPSGraphTensor}
MPSGraphTensor(obj)
obj = @objc [
graph::id{MPSGraph} shapeOfTensor:tensor::id{MPSGraphTensor}
name:name::id{NSString}
]::id{MPSGraphTensor}
return MPSGraphTensor(obj)
end

function identityWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "identity")
obj = @objc [graph::id{MPSGraph} identityWithTensor:tensor::id{MPSGraphTensor}
name:name::id{NSString}]::id{MPSGraphTensor}
MPSGraphTensor(obj)
obj = @objc [
graph::id{MPSGraph} identityWithTensor:tensor::id{MPSGraphTensor}
name:name::id{NSString}
]::id{MPSGraphTensor}
return MPSGraphTensor(obj)
end

function MPSGraphConvolution2DOpDescriptor(;
stride::NTuple{2, <:Integer} = (1, 1),
dilation::NTuple{2, <:Integer} = (1, 1),
padding::NTuple{4, <:Integer} = (0, 0, 0, 0),
groups::Integer = 1,
dataLayout = MPSGraphTensorNamedDataLayoutNCHW,
weightsLayout = MPSGraphTensorNamedDataLayoutOIHW,
paddingStyle = MPSGraphPaddingStyleExplicit
)
desc = MPSGraphConvolution2DOpDescriptor(@objc [MPSGraphConvolution2DOpDescriptor new]::id{MPSGraphConvolution2DOpDescriptor})
desc.strideInX = UInt64(stride[1])
desc.strideInY = UInt64(stride[2])
desc.dilationRateInX = UInt64(dilation[1])
desc.dilationRateInY = UInt64(dilation[2])
desc.paddingLeft = UInt64(padding[1])
desc.paddingRight = UInt64(padding[2])
desc.paddingTop = UInt64(padding[3])
desc.paddingBottom = UInt64(padding[4])
desc.paddingStyle = paddingStyle
desc.dataLayout = dataLayout
desc.weightsLayout = weightsLayout
desc.groups = UInt64(groups)
return desc
end

function convolution2DWithSourceTensor(
graph::MPSGraph, source::MPSGraphTensor, weights::MPSGraphTensor,
descriptor::MPSGraphConvolution2DOpDescriptor, name = "conv2d"
)
obj = @objc [
graph::id{MPSGraph} convolution2DWithSourceTensor:source::id{MPSGraphTensor}
weightsTensor:weights::id{MPSGraphTensor}
descriptor:descriptor::id{MPSGraphConvolution2DOpDescriptor}
name:name::id{NSString}
]::id{MPSGraphTensor}
return MPSGraphTensor(obj)
end

"""
Expand Down
69 changes: 69 additions & 0 deletions test/mpsgraphs/conv.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
if MPS.is_supported(device())

function conv2d_reference(x, w; stride = (1, 1), dilation = (1, 1), padding = (0, 0, 0, 0), groups = 1)
W, H, Cin, N = size(x)
Kw, Kh, CinPerGroup, Cout = size(w)

sx, sy = stride
dx, dy = dilation
pl, pr, pt, pb = padding

@assert CinPerGroup * groups == Cin
@assert Cout % groups == 0

Wout = fld(W + pl + pr - dx * (Kw - 1) - 1, sx) + 1
Hout = fld(H + pt + pb - dy * (Kh - 1) - 1, sy) + 1

Tout = promote_type(eltype(x), eltype(w))
y = zeros(Tout, Wout, Hout, Cout, N)

CoutPerGroup = div(Cout, groups)
for n in 1:N, co in 1:Cout, oh in 1:Hout, ow in 1:Wout
g = div(co - 1, CoutPerGroup) + 1
acc = zero(Tout)

for ci_local in 1:CinPerGroup, kh in 1:Kh, kw in 1:Kw
ci = (g - 1) * CinPerGroup + ci_local
iw = (ow - 1) * sx - pl + (kw - 1) * dx + 1
ih = (oh - 1) * sy - pt + (kh - 1) * dy + 1
if 1 <= iw <= W && 1 <= ih <= H
acc += x[iw, ih, ci, n] * w[kw, kh, ci_local, co]
end
end

y[ow, oh, co, n] = acc
end

return y
end

@testset "mpsgraph convolution 2D" begin
x = rand(Float32, 9, 8, 4, 2)
w = rand(Float32, 3, 2, 4, 6)
stride = (2, 1)
dilation = (1, 2)
padding = (1, 0, 2, 1)

y_ref = conv2d_reference(x, w; stride, dilation, padding, groups = 1)
y = Metal.zeros(Float32, size(y_ref))

MPSGraphs.graph_conv!(y, MtlArray(x), MtlArray(w); stride, dilation, padding, groups = 1)
@test Array(y) ≈ y_ref rtol = 1.0f-5 atol = 1.0f-5
end

@testset "mpsgraph grouped convolution 2D" begin
x = rand(Float32, 10, 7, 4, 2)
w = rand(Float32, 3, 3, 2, 6)
stride = (1, 2)
dilation = (1, 1)
padding = (1, 1, 1, 1)
groups = 2

y_ref = conv2d_reference(x, w; stride, dilation, padding, groups)
y = Metal.zeros(Float32, size(y_ref))

MPSGraphs.graph_conv!(y, MtlArray(x), MtlArray(w); stride, dilation, padding, groups)
@test Array(y) ≈ y_ref rtol = 1.0f-5 atol = 1.0f-5
end

end # MPS.is_supported(device())