Skip to content

Bad performance for sort!(CoSorter(os._data, os.scalars); alg=ThreadsX.QuickSort) #42

@xiangjianqian

Description

@xiangjianqian

I noticed that the call sort!(CoSorter(os._data, os.scalars); alg=ThreadsX.QuickSort) consumes a significant amount of runtime. I've found a potential way to improve this, but I'm unsure whether my approach is correct. Could someone help me verify it? Below is a short demonstration of the modified code (inside the MPOGraph function):

    # Sort strategy switch: `true` sorts packed integer keys, `false` sorts the
    # entries of `os._data` directly.
    use_perm = true
    t_sort = @elapsed begin
        # Trim both arrays to the length reported by `length(os)`.
        n = length(os)
        resize!(os._data, n)
        resize!(os.scalars, n)

        # Decide whether every term can be packed into one UInt128 without
        # wrap-around. The radix is derived from the values actually present, so
        # every digit is strictly below `base` and distinct terms get distinct keys.
        packable = false
        stride = one(UInt128)
        base = one(UInt128)
        if use_perm && n > 0
            maxn, maxid = 0, 0
            for term in os._data, op in term
                maxn = max(maxn, Int(op.n))
                maxid = max(maxid, Int(op.id))
            end
            stride = UInt128(maxid) + 1
            base = (UInt128(maxn) + 1) * stride
            # Bits needed for the largest digit (`base - 1`). Requiring
            # `ops_per_term * digit_bits <= 128` is a sufficient (slightly
            # conservative) condition for `base^ops_per_term` to fit in a UInt128.
            digit_bits = 8 * sizeof(UInt128) - leading_zeros(base - 1)
            packable = length(first(os._data)) * digit_bits <= 8 * sizeof(UInt128)
        end

        if packable
            # Mixed-radix encoding, most significant operator first; each operator
            # contributes the digit `n * stride + id`.
            # NOTE(review): the resulting order is lexicographic in (n, id) per
            # operator. Confirm this matches `isless` on the term type if later
            # code relies on the exact sorted order and not only on equal terms
            # being adjacent.
            sort_keys = Vector{UInt128}(undef, n)
            for i in eachindex(os._data, sort_keys)
                key = zero(UInt128)
                for op in os._data[i]
                    key = key * base + UInt128(Int(op.n)) * stride + UInt128(Int(op.id))
                end
                sort_keys[i] = key
            end
            perm = sortperm(sort_keys)
        else
            # Keys would overflow (or key sorting is disabled, or there is nothing
            # to sort): sort the terms themselves.
            perm = sortperm(os._data; alg=Base.Sort.QuickSort)
            # sort!(ITensorMPOConstruction.CoSorter(os._data, os.scalars); alg=Base.Sort.QuickSort)
        end
    end

    local nnz = 0
    t_combine = @elapsed begin
        # Walk the entries in the order given by `perm`. While the next entry
        # compares equal, fold the current scalar into it; the last entry of each
        # run is kept only if its accumulated scalar exceeds `os.abs_tol` in
        # magnitude. (The earlier variant applied `permute!` to both arrays first
        # and compacted them in the same loop.)
        kept = sizehint!(Int[], length(perm))
        for i in eachindex(perm)
            cur = perm[i]
            if i < lastindex(perm) && os._data[cur] == os._data[perm[i + 1]]
                os.scalars[perm[i + 1]] += os.scalars[cur]
            elseif abs(os.scalars[cur]) > os.abs_tol
                push!(kept, cur)
            end
        end
        nnz = length(kept)

        # Gather the survivors into temporaries first: `kept` follows sorted order,
        # not index order, so compacting directly in place could overwrite entries
        # that are still needed.
        kept_scalars = os.scalars[kept]
        kept_data = os._data[kept]

        # Shrink and overwrite the existing arrays instead of rebinding the fields
        # of `os`, so this also works if `os` is an immutable struct and keeps the
        # same array objects that were resized before sorting.
        resize!(os.scalars, nnz)
        copyto!(os.scalars, kept_scalars)
        resize!(os._data, nnz)
        copyto!(os._data, kept_data)
    end

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions