Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix rawbigints OOB issues #55917

Merged
merged 1 commit into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions base/rawbigints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,21 @@ reversed_index(n::Int, i::Int) = n - i - 1
reversed_index(x, i::Int, v::Val) = reversed_index(elem_count(x, v), i)::Int
split_bit_index(x::RawBigInt, i::Int) = divrem(i, word_length(x), RoundToZero)

function get_elem_words_raw(x::RawBigInt{T}, i::Int) where {T}
@boundscheck if (i < 0) || (elem_count(x, Val(:words)) ≤ i)
throw(BoundsError(x, i))
end
d = x.d
j = i + 1
(GC.@preserve d unsafe_load(Ptr{T}(pointer(d)), j))::T
end

"""
`i` is the zero-based index of the wanted word in `x`, starting from
the less significant words.
"""
function get_elem(x::RawBigInt{T}, i::Int, ::Val{:words}, ::Val{:ascending}) where {T}
# `i` must be non-negative and less than `x.word_count`
d = x.d
(GC.@preserve d unsafe_load(Ptr{T}(pointer(d)), i + 1))::T
function get_elem(x::RawBigInt, i::Int, ::Val{:words}, ::Val{:ascending})
@inbounds @inline get_elem_words_raw(x, i)
end

function get_elem(x, i::Int, v::Val, ::Val{:descending})
Expand Down Expand Up @@ -96,25 +103,31 @@ end

"""
Returns an integer of type `R`, consisting of the `len` most
significant bits of `x`.
significant bits of `x`. If there are less than `len` bits in `x`,
the least significant bits are zeroed.
"""
function truncated(::Type{R}, x::RawBigInt, len::Int) where {R<:Integer}
ret = zero(R)
if 0 < len
word_count, bit_count_in_word = split_bit_index(x, len)
k = word_length(x)
vals = (Val(:words), Val(:descending))
lenx = elem_count(x, first(vals))

for w ∈ 0:(word_count - 1)
ret <<= k
word = get_elem(x, w, vals...)
ret |= R(word)
if w < lenx
word = get_elem(x, w, vals...)
ret |= R(word)
end
end

if !iszero(bit_count_in_word)
ret <<= bit_count_in_word
wrd = get_elem(x, word_count, vals...)
ret |= R(wrd >>> (k - bit_count_in_word))
if word_count < lenx
wrd = get_elem(x, word_count, vals...)
ret |= R(wrd >>> (k - bit_count_in_word))
end
end
end
ret::R
Expand Down
9 changes: 9 additions & 0 deletions test/mpfr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1088,3 +1088,12 @@ end
clear_flags()
end
end

@testset "RawBigInt truncation OOB read" begin
@testset "T: $T" for T ∈ (UInt8, UInt16, UInt32, UInt64, UInt128)
v = Base.RawBigInt{T}("a"^sizeof(T), 1)
@testset "bit_count: $bit_count" for bit_count ∈ (0:10:80)
@test Base.truncated(UInt128, v, bit_count) isa Any
end
end
end