Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tweak order of operations to get
nnz
to infer as Int
return type
If the sparse array does not have a concrete index type, then union splitting occurs over the possible `<:Integer` types permitted by `SparseMatrixCSC`: ```julia julia> code_warntype(nnz, (SparseMatrixCSC{Float64,<:Integer},), optimize=true, debuginfo=:none) Variables #self#::Core.Const(SparseArrays.nnz) S::SparseMatrixCSC{Float64, var"#s96"} where var"#s96"<:Integer Body::Any 1 ── %1 = SparseArrays.getfield(S, :colptr)::Vector{var"#s96"} where var"#s96"<:Integer │ %2 = SparseArrays.getfield(S, :n)::Int64 │ %3 = Base.add_int(%2, 1)::Int64 │ %4 = Base.getindex(%1, %3)::Integer │ %5 = (isa)(%4, Int64)::Bool └─── goto JuliaLang#3 if not %5 2 ── %7 = π (%4, Int64) │ %8 = Base.sub_int(%7, 1)::Int64 └─── goto JuliaLang#15 3 ── %10 = (isa)(%4, BigInt)::Bool └─── goto JuliaLang#14 if not %10 4 ── %12 = π (%4, BigInt) │ %13 = Base.slt_int(1, 0)::Bool └─── goto JuliaLang#6 if not %13 5 ── %15 = Base.bitcast(UInt64, 1)::UInt64 │ %16 = Base.neg_int(%15)::UInt64 │ %17 = Base.GMP.MPZ.add_ui::typeof(Base.GMP.MPZ.add_ui) │ %18 = invoke %17(%12::BigInt, %16::UInt64)::BigInt └─── goto JuliaLang#13 6 ── %20 = Core.lshr_int(1, 63)::Int64 │ %21 = Core.trunc_int(Core.UInt8, %20)::UInt8 │ %22 = Core.eq_int(%21, 0x01)::Bool └─── goto JuliaLang#8 if not %22 7 ── invoke Core.throw_inexacterror(:check_top_bit::Symbol, UInt64::Type{UInt64}, 1::Int64) └─── unreachable 8 ── goto JuliaLang#9 9 ── %27 = Core.bitcast(Core.UInt64, 1)::UInt64 └─── goto JuliaLang#10 10 ─ goto JuliaLang#11 11 ─ goto JuliaLang#12 12 ─ %31 = Base.GMP.MPZ.sub_ui::typeof(Base.GMP.MPZ.sub_ui) │ %32 = invoke %31(%12::BigInt, %27::UInt64)::BigInt └─── goto JuliaLang#13 13 ┄ %34 = φ (JuliaLang#5 => %18, JuliaLang#12 => %32)::Any └─── goto JuliaLang#15 14 ─ %36 = (%4 - 1)::Any └─── goto JuliaLang#15 15 ┄ %38 = φ (JuliaLang#2 => %8, JuliaLang#13 => %34, JuliaLang#14 => %36)::Any │ %39 = SparseArrays.Int(%38)::Any └─── return %39 ``` It appears that union splitting over the subtraction by one includes an `Any` branch that widens the return type of `nnz`. By instead converting the index type to `Int` before subtracting, type inference is able to infer that all paths give an `Int` result: ```julia julia> code_warntype(nnz, (SparseMatrixCSC{Float64,<:Integer},), optimize=true, debuginfo=:none) Variables #self#::Core.Const(SparseArrays.nnz) S::SparseMatrixCSC{Float64, var"#s96"} where var"#s96"<:Integer Body::Int64 1 ── %1 = SparseArrays.getfield(S, :colptr)::Vector{var"#s96"} where var"#s96"<:Integer │ %2 = SparseArrays.getfield(S, :n)::Int64 │ %3 = Base.add_int(%2, 1)::Int64 │ %4 = Base.getindex(%1, %3)::Integer │ %5 = (isa)(%4, BigInt)::Bool └─── goto JuliaLang#14 if not %5 2 ── %7 = π (%4, BigInt) │ %8 = Base.getfield(%7, :size)::Int32 │ %9 = Base.flipsign_int(%8, %8)::Int32 │ %10 = Core.sext_int(Core.Int64, %9)::Int64 │ %11 = Base.sle_int(0, %10)::Bool └─── goto JuliaLang#4 if not %11 3 ── %13 = Core.sext_int(Core.Int64, %9)::Int64 │ %14 = Base.sle_int(%13, 1)::Bool └─── goto JuliaLang#5 4 ── nothing 5 ┄─ %17 = φ (JuliaLang#3 => %14, JuliaLang#4 => false)::Bool └─── goto JuliaLang#12 if not %17 6 ── %19 = Base.getfield(%7, :size)::Int32 │ %20 = Core.sext_int(Core.Int64, %19)::Int64 │ %21 = (%20 === 0)::Bool └─── goto JuliaLang#8 if not %21 7 ── goto JuliaLang#9 8 ── %24 = Base.getfield(%7, :d)::Ptr{UInt64} │ %25 = Base.pointerref(%24, 1, 1)::UInt64 │ %26 = Base.bitcast(Int64, %25)::Int64 │ %27 = Base.getfield(%7, :size)::Int32 │ %28 = Core.sext_int(Core.Int64, %27)::Int64 │ %29 = Base.flipsign_int(%26, %28)::Int64 └─── goto JuliaLang#9 9 ┄─ %31 = φ (JuliaLang#7 => 0, JuliaLang#8 => %29)::Int64 │ %32 = Base.getfield(%7, :size)::Int32 │ %33 = Core.sext_int(Core.Int64, %32)::Int64 │ %34 = Base.slt_int(0, %33)::Bool │ %35 = Base.slt_int(0, %31)::Bool │ %36 = (%34 === %35)::Bool │ %37 = Base.not_int(%36)::Bool └─── goto JuliaLang#11 if not %37 10 ─ %39 = Base.GMP.nameof(Int64)::Any │ %40 = Base.GMP.InexactError(%39, Int64, %7)::Any │ Base.GMP.throw(%40) └─── unreachable 11 ─ goto JuliaLang#13 12 ─ %44 = Base.GMP.nameof(Int64)::Any │ %45 = Base.GMP.InexactError(%44, Int64, %7)::Any │ Base.GMP.throw(%45) └─── unreachable 13 ─ goto JuliaLang#15 14 ─ %49 = SparseArrays.Int(%4)::Int64 └─── goto JuliaLang#15 15 ┄ %51 = φ (JuliaLang#13 => %31, JuliaLang#14 => %49)::Int64 │ %52 = Base.sub_int(%51, 1)::Int64 └─── return %52 ```
- Loading branch information