Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions src/datadeps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ function is_writedep(arg, deps, task::DTask)
end

# Aliasing state setup
function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask, proc::Processor)
# Populate task dependencies
dependencies_to_add = Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}()

Expand All @@ -278,6 +278,11 @@ function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
# Skip non-aliasing arguments
type_may_alias(typeof(arg)) || continue

# Unwrap Shards
if arg isa Shard
arg = shard_unwrap(arg, proc)
end

# Add all aliasing dependencies
for (dep_mod, readdep, writedep) in deps
if state.aliasing
Expand Down Expand Up @@ -592,9 +597,6 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
pressures = Dict{Processor,Int}()
proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024)
for (spec, task) in queue.seen_tasks[task_order]
# Populate all task dependencies
populate_task_info!(state, spec, task)

task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope())
scheduler = queue.scheduler
if scheduler == :naive
Expand Down Expand Up @@ -737,6 +739,9 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
@assert our_proc in all_procs
our_space = only(memory_spaces(our_proc))

# Populate all task dependencies
populate_task_info!(state, spec, task, our_proc)

# Find the scope for this task (and its copies)
if task_scope == scope
# Optimize for the common case, cache the proc=>scope mapping
Expand Down Expand Up @@ -776,6 +781,11 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
continue
end

# Unwrap Shards
if arg isa Shard
arg = shard_unwrap(arg, our_proc)
end

# Is the source of truth elsewhere?
arg_remote = get!(get!(IdDict{Any,Any}, state.remote_args, our_space), arg) do
generate_slot!(state, our_space, arg)
Expand Down Expand Up @@ -851,6 +861,12 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
arg = arg isa DTask ? fetch(arg; raw=true) : arg
type_may_alias(typeof(arg)) || continue
supports_inplace_move(state, arg) || continue

# Unwrap Shards
if arg isa Shard
arg = shard_unwrap(arg, our_proc)
end

if queue.aliasing
for (dep_mod, _, writedep) in deps
ainfo = aliasing(astate, arg, dep_mod)
Expand Down Expand Up @@ -884,6 +900,12 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
arg, deps = unwrap_inout(arg)
arg = arg isa DTask ? fetch(arg; raw=true) : arg
type_may_alias(typeof(arg)) || continue

# Unwrap Shards
if arg isa Shard
arg = shard_unwrap(arg, our_proc)
end

if queue.aliasing
for (dep_mod, _, writedep) in deps
ainfo = aliasing(astate, arg, dep_mod)
Expand Down
2 changes: 2 additions & 0 deletions src/memory-spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ end
type_may_alias(::Type{String}) = false
type_may_alias(::Type{Symbol}) = false
type_may_alias(::Type{<:Type}) = false
type_may_alias(::Type{Shard}) = true
type_may_alias(::Type{C}) where C<:Chunk{T} where T = type_may_alias(T)
function type_may_alias(::Type{T}) where T
if isbitstype(T)
Expand Down Expand Up @@ -213,6 +214,7 @@ end
aliasing(::String) = NoAliasing() # FIXME: Not necessarily true
aliasing(::Symbol) = NoAliasing()
aliasing(::Type) = NoAliasing()
aliasing(::Shard) = throw(ArgumentError("Cannot resolve aliasing for Shard"))
aliasing(x::Chunk, T) = remotecall_fetch(root_worker_id(x.processor), x, T) do x, T
aliasing(unwrap(x), T)
end
Expand Down
12 changes: 7 additions & 5 deletions src/utils/chunks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,25 @@ macro shard(exs...)
end
end

function move(from_proc::Processor, to_proc::Processor, shard::Shard)
function shard_unwrap(shard::Shard, proc::Processor)
# Match either this proc or some ancestor
# N.B. This behavior may bypass the piece's scope restriction
proc = to_proc
if haskey(shard.chunks, proc)
return move(from_proc, to_proc, shard.chunks[proc])
return shard.chunks[proc]
end
parent = Dagger.get_parent(proc)
while parent != proc
proc = parent
parent = Dagger.get_parent(proc)
if haskey(shard.chunks, proc)
return move(from_proc, to_proc, shard.chunks[proc])
return shard.chunks[proc]
end
end

throw(KeyError(to_proc))
throw(KeyError(proc))
end
function move(from_proc::Processor, to_proc::Processor, shard::Shard)
return move(from_proc, to_proc, shard_unwrap(shard, to_proc))
end
Base.iterate(s::Shard) = iterate(values(s.chunks))
Base.iterate(s::Shard, state) = iterate(values(s.chunks), state)
Expand Down
Loading