Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Bit shift is restricted to u8 right operand #4907

Merged
merged 10 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/publish-nargo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ permissions:

jobs:
build-apple-darwin:
runs-on: macos-latest
runs-on: macos-12
env:
CROSS_CONFIG: ${{ github.workspace }}/.github/Cross.toml
NIGHTLY_RELEASE: ${{ inputs.tag == '' }}
Expand Down
4 changes: 3 additions & 1 deletion acvm-repo/brillig_vm/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ pub(crate) fn evaluate_binary_int_op(
}
}
})?;
let rhs = rhs.expect_integer_with_bit_size(bit_size).map_err(|err| match err {
let rhs_bit_size =
if op == &BinaryIntOp::Shl || op == &BinaryIntOp::Shr { 8 } else { bit_size };
let rhs = rhs.expect_integer_with_bit_size(rhs_bit_size).map_err(|err| match err {
MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } => {
BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: value_bit_size,
Expand Down
14 changes: 10 additions & 4 deletions compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1279,8 +1279,11 @@ impl<'block> BrilligBlock<'block> {
dfg: &DataFlowGraph,
result_variable: SingleAddrVariable,
) {
let binary_type =
type_of_binary_operation(dfg[binary.lhs].get_type(), dfg[binary.rhs].get_type());
let binary_type = type_of_binary_operation(
dfg[binary.lhs].get_type(),
dfg[binary.rhs].get_type(),
binary.operator,
);

let left = self.convert_ssa_single_addr_value(binary.lhs, dfg);
let right = self.convert_ssa_single_addr_value(binary.rhs, dfg);
Expand Down Expand Up @@ -1766,7 +1769,7 @@ impl<'block> BrilligBlock<'block> {
}

/// Returns the type of the operation considering the types of the operands
pub(crate) fn type_of_binary_operation(lhs_type: &Type, rhs_type: &Type) -> Type {
pub(crate) fn type_of_binary_operation(lhs_type: &Type, rhs_type: &Type, op: BinaryOp) -> Type {
match (lhs_type, rhs_type) {
(_, Type::Function) | (Type::Function, _) => {
unreachable!("Functions are invalid in binary operations")
Expand All @@ -1782,12 +1785,15 @@ pub(crate) fn type_of_binary_operation(lhs_type: &Type, rhs_type: &Type) -> Type
}
// If both sides are numeric type, then we expect their types to be
// the same.
(Type::Numeric(lhs_type), Type::Numeric(rhs_type)) => {
(Type::Numeric(lhs_type), Type::Numeric(rhs_type))
if op != BinaryOp::Shl && op != BinaryOp::Shr =>
{
assert_eq!(
lhs_type, rhs_type,
"lhs and rhs types in a binary operation are always the same but got {lhs_type} and {rhs_type}"
);
Type::Numeric(*lhs_type)
}
_ => lhs_type.clone(),
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,6 @@
result: SingleAddrVariable,
operation: BrilligBinaryOp,
) {
assert!(
lhs.bit_size == rhs.bit_size,
"Not equal bit size for lhs and rhs: lhs {}, rhs {}",
lhs.bit_size,
rhs.bit_size
);
let is_field_op = lhs.bit_size == FieldElement::max_num_bits();
let expected_result_bit_size =
BrilligContext::binary_result_bit_size(operation, lhs.bit_size);
Expand Down Expand Up @@ -129,7 +123,7 @@
) {
assert!(
left.bit_size == right.bit_size,
"Not equal bitsize: lhs {}, rhs {}",

Check warning on line 126 in compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (bitsize)
left.bit_size,
right.bit_size
);
Expand Down
7 changes: 6 additions & 1 deletion compiler/noirc_evaluator/src/ssa/function_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,12 @@ impl FunctionBuilder {
) -> ValueId {
let lhs_type = self.type_of_value(lhs);
let rhs_type = self.type_of_value(rhs);
assert_eq!(lhs_type, rhs_type, "ICE - Binary instruction operands must have the same type");
if operator != BinaryOp::Shl && operator != BinaryOp::Shr {
assert_eq!(
lhs_type, rhs_type,
"ICE - Binary instruction operands must have the same type"
);
}
let instruction = Instruction::Binary(Binary { lhs, rhs, operator });
self.insert_instruction(instruction, None).first()
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ impl Context<'_> {
} else {
// we use a predicate to nullify the result in case of overflow
let bit_size_var =
self.numeric_constant(FieldElement::from(bit_size as u128), typ.clone());
self.numeric_constant(FieldElement::from(bit_size as u128), Type::unsigned(8));
let overflow = self.insert_binary(rhs, BinaryOp::Lt, bit_size_var);
let predicate = self.insert_cast(overflow, typ.clone());
// we can safely cast to unsigned because overflow_checks prevent bit-shift with a negative value
Expand Down
30 changes: 5 additions & 25 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ impl<'a> FunctionContext<'a> {
self.insert_safe_cast(result, result_type, location)
}
BinaryOpKind::ShiftLeft | BinaryOpKind::ShiftRight => {
self.check_shift_overflow(result, rhs, bit_size, location, true)
self.check_shift_overflow(result, rhs, bit_size, location)
}
_ => unreachable!("operator {} should not overflow", operator),
}
Expand Down Expand Up @@ -408,7 +408,7 @@ impl<'a> FunctionContext<'a> {
}
}

self.check_shift_overflow(result, rhs, bit_size, location, false);
self.check_shift_overflow(result, rhs, bit_size, location);
}

_ => unreachable!("operator {} should not overflow", operator),
Expand All @@ -430,32 +430,12 @@ impl<'a> FunctionContext<'a> {
rhs: ValueId,
bit_size: u32,
location: Location,
is_signed: bool,
) -> ValueId {
let one = self.builder.numeric_constant(FieldElement::one(), Type::bool());
let rhs = if is_signed {
self.insert_safe_cast(rhs, Type::unsigned(bit_size), location)
} else {
rhs
};
// Bit-shift with a negative number is an overflow
if is_signed {
// We compute the sign of rhs.
let half_width = self.builder.numeric_constant(
FieldElement::from(2_i128.pow(bit_size - 1)),
Type::unsigned(bit_size),
);
let sign = self.builder.insert_binary(rhs, BinaryOp::Lt, half_width);
self.builder.set_location(location).insert_constrain(
sign,
one,
Some("attempt to bit-shift with overflow".to_owned().into()),
);
}
assert!(self.builder.current_function.dfg.type_of_value(rhs) == Type::unsigned(8));

let max = self
.builder
.numeric_constant(FieldElement::from(bit_size as i128), Type::unsigned(bit_size));
let max =
self.builder.numeric_constant(FieldElement::from(bit_size as i128), Type::unsigned(8));
let overflow = self.builder.insert_binary(rhs, BinaryOp::Lt, max);
self.builder.set_location(location).insert_constrain(
overflow,
Expand Down
5 changes: 4 additions & 1 deletion compiler/noirc_frontend/src/hir/type_check/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ pub enum TypeCheckError {
FieldModulo { span: Span },
#[error("Fields cannot be compared, try casting to an integer first")]
FieldComparison { span: Span },
#[error("The bit count in a bit-shift operation must fit in a u8, try casting the right hand side into a u8 first")]
InvalidShiftSize { span: Span },
#[error("The number of bits to use for this bitwise operation is ambiguous. Either the operand's type or return type should be specified")]
AmbiguousBitWidth { span: Span },
#[error("Error with additional context")]
Expand Down Expand Up @@ -234,7 +236,8 @@ impl From<TypeCheckError> for Diagnostic {
| TypeCheckError::UnconstrainedReferenceToConstrained { span }
| TypeCheckError::UnconstrainedSliceReturnToConstrained { span }
| TypeCheckError::NonConstantSliceLength { span }
| TypeCheckError::StringIndexAssign { span } => {
| TypeCheckError::StringIndexAssign { span }
| TypeCheckError::InvalidShiftSize { span } => {
Diagnostic::simple_error(error.to_string(), String::new(), span)
}
TypeCheckError::PublicReturnType { typ, span } => Diagnostic::simple_error(
Expand Down
30 changes: 28 additions & 2 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use iter_extended::vecmap;
use noirc_errors::Span;

use crate::ast::{BinaryOpKind, UnaryOp};
use crate::ast::{BinaryOpKind, IntegerBitSize, UnaryOp};
use crate::macros_api::Signedness;
use crate::{
hir::{resolution::resolver::verify_mutable_reference, type_check::errors::Source},
hir_def::{
Expand Down Expand Up @@ -1129,11 +1130,30 @@ impl<'interner> TypeChecker<'interner> {
if let TypeBinding::Bound(binding) = &*int.borrow() {
return self.infix_operand_type_rules(binding, op, other, span);
}

if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight {
self.unify(
rhs_type,
&Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight),
|| TypeCheckError::InvalidShiftSize { span },
);
let use_impl = if lhs_type.is_numeric() {
let integer_type = Type::polymorphic_integer(self.interner);
self.bind_type_variables_for_infix(lhs_type, op, &integer_type, span)
} else {
true
};
return Ok((lhs_type.clone(), use_impl));
}
let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span);
Ok((other.clone(), use_impl))
}
(Integer(sign_x, bit_width_x), Integer(sign_y, bit_width_y)) => {
if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight {
if *sign_y != Signedness::Unsigned || *bit_width_y != IntegerBitSize::Eight {
return Err(TypeCheckError::InvalidShiftSize { span });
}
return Ok((Integer(*sign_x, *bit_width_x), false));
}
if sign_x != sign_y {
return Err(TypeCheckError::IntegerSignedness {
sign_x: *sign_x,
Expand Down Expand Up @@ -1165,6 +1185,12 @@ impl<'interner> TypeChecker<'interner> {
(Bool, Bool) => Ok((Bool, false)),

(lhs, rhs) => {
if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight {
if rhs == &Type::Integer(Signedness::Unsigned, IntegerBitSize::Eight) {
return Ok((lhs.clone(), true));
}
return Err(TypeCheckError::InvalidShiftSize { span });
}
self.unify(lhs, rhs, || TypeCheckError::TypeMismatchWithSource {
expected: lhs.clone(),
actual: rhs.clone(),
Expand Down
4 changes: 2 additions & 2 deletions docs/docs/noir/concepts/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ sidebar_position: 3
| ^ | XOR two private input types together | Types must be integer |
| & | AND two private input types together | Types must be integer |
| \| | OR two private input types together | Types must be integer |
| \<\< | Left shift an integer by another integer amount | Types must be integer |
| >> | Right shift an integer by another integer amount | Types must be integer |
| \<\< | Left shift an integer by another integer amount | Types must be integer, shift must be u8 |
| >> | Right shift an integer by another integer amount | Types must be integer, shift must be u8 |
| ! | Bitwise not of a value | Type must be integer or boolean |
| \< | returns a bool if one value is less than the other | Upper bound must have a known bit size |
| \<= | returns a bool if one value is less than or equal to the other | Upper bound must have a known bit size |
Expand Down
28 changes: 14 additions & 14 deletions noir_stdlib/src/ops.nr
Original file line number Diff line number Diff line change
Expand Up @@ -126,30 +126,30 @@ impl BitXor for i64 { fn bitxor(self, other: i64) -> i64 { self ^ other } }

// docs:start:shl-trait
trait Shl {
fn shl(self, other: Self) -> Self;
fn shl(self, other: u8) -> Self;
}
// docs:end:shl-trait

impl Shl for u32 { fn shl(self, other: u32) -> u32 { self << other } }
impl Shl for u64 { fn shl(self, other: u64) -> u64 { self << other } }
impl Shl for u32 { fn shl(self, other: u8) -> u32 { self << other } }
impl Shl for u64 { fn shl(self, other: u8) -> u64 { self << other } }
impl Shl for u8 { fn shl(self, other: u8) -> u8 { self << other } }
impl Shl for u1 { fn shl(self, other: u1) -> u1 { self << other } }
impl Shl for u1 { fn shl(self, other: u8) -> u1 { self << other } }

impl Shl for i8 { fn shl(self, other: i8) -> i8 { self << other } }
impl Shl for i32 { fn shl(self, other: i32) -> i32 { self << other } }
impl Shl for i64 { fn shl(self, other: i64) -> i64 { self << other } }
impl Shl for i8 { fn shl(self, other: u8) -> i8 { self << other } }
impl Shl for i32 { fn shl(self, other: u8) -> i32 { self << other } }
impl Shl for i64 { fn shl(self, other: u8) -> i64 { self << other } }

// docs:start:shr-trait
trait Shr {
fn shr(self, other: Self) -> Self;
fn shr(self, other: u8) -> Self;
}
// docs:end:shr-trait

impl Shr for u64 { fn shr(self, other: u64) -> u64 { self >> other } }
impl Shr for u32 { fn shr(self, other: u32) -> u32 { self >> other } }
impl Shr for u64 { fn shr(self, other: u8) -> u64 { self >> other } }
impl Shr for u32 { fn shr(self, other: u8) -> u32 { self >> other } }
impl Shr for u8 { fn shr(self, other: u8) -> u8 { self >> other } }
impl Shr for u1 { fn shr(self, other: u1) -> u1 { self >> other } }
impl Shr for u1 { fn shr(self, other: u8) -> u1 { self >> other } }

impl Shr for i8 { fn shr(self, other: i8) -> i8 { self >> other } }
impl Shr for i32 { fn shr(self, other: i32) -> i32 { self >> other } }
impl Shr for i64 { fn shr(self, other: i64) -> i64 { self >> other } }
impl Shr for i8 { fn shr(self, other: u8) -> i8 { self >> other } }
impl Shr for i32 { fn shr(self, other: u8) -> i32 { self >> other } }
impl Shr for i64 { fn shr(self, other: u8) -> i64 { self >> other } }
2 changes: 1 addition & 1 deletion noir_stdlib/src/sha512.nr
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// 64 bytes.
// Internal functions act on 64-bit unsigned integers for simplicity.
// Auxiliary mappings; names as in FIPS PUB 180-4
fn rotr64(a: u64, b: u64) -> u64 // 64-bit right rotation
fn rotr64(a: u64, b: u8) -> u64 // 64-bit right rotation
{
// None of the bits overlap between `(a >> b)` and `(a << (64 - b))`
// Addition is then equivalent to OR, with fewer constraints.
Expand Down
12 changes: 6 additions & 6 deletions noir_stdlib/src/uint128.nr
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,9 @@ impl BitXor for U128 {
}

impl Shl for U128 {
fn shl(self, other: U128) -> U128 {
assert(other < U128::from_u64s_le(128,0), "attempt to shift left with overflow");
let exp_bits = other.lo.to_be_bits(7);
fn shl(self, other: u8) -> U128 {
assert(other < 128, "attempt to shift left with overflow");
let exp_bits = (other as Field).to_be_bits(7);

let mut r: Field = 2;
let mut y: Field = 1;
Expand All @@ -271,9 +271,9 @@ impl Shl for U128 {
}

impl Shr for U128 {
fn shr(self, other: U128) -> U128 {
assert(other < U128::from_u64s_le(128,0), "attempt to shift right with overflow");
let exp_bits = other.lo.to_be_bits(7);
fn shr(self, other: u8) -> U128 {
assert(other < 128, "attempt to shift right with overflow");
let exp_bits = (other as Field).to_be_bits(7);

let mut r: Field = 2;
let mut y: Field = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ fn main() {
assert((x | y) == or(x, y));
// TODO SSA => ACIR has some issues with xor ops
assert(check_xor(x, y, 4));
assert((x >> y) == shr(x, y));
assert((x << y) == shl(x, y));
assert((x >> y as u8) == shr(x, y as u8));
assert((x << y as u8) == shl(x, y as u8));
}

unconstrained fn add(x: u32, y: u32) -> u32 {
Expand Down Expand Up @@ -67,11 +67,11 @@ unconstrained fn check_xor(x: u32, y: u32, result: u32) -> bool {
(x ^ y) == result
}

unconstrained fn shr(x: u32, y: u32) -> u32 {
unconstrained fn shr(x: u32, y: u8) -> u32 {
x >> y
}

unconstrained fn shl(x: u32, y: u32) -> u32 {
unconstrained fn shl(x: u32, y: u8) -> u32 {
x << y
}

Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn main(x: u64) {
//regression for 3481
assert(x << 63 == 0);

assert_eq((1 as u64) << (32 as u64), 0x0100000000);
assert_eq((1 as u64) << 32, 0x0100000000);
}

fn regression_2250() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
fn main(x: u64, y: u64) {
fn main(x: u64, y: u8) {
// runtime shifts on compile-time known values
assert(64 << y == 128);
assert(64 >> y == 32);
Expand All @@ -11,10 +11,10 @@ fn main(x: u64, y: u64) {
let mut b: i8 = x as i8;
assert(b << 1 == -128);
assert(b >> 2 == 16);
assert(b >> a == 32);
assert(b >> y == 32);
a = -a;
assert(a << 7 == -128);
assert(a << -a == -2);
assert(a << y == -2);

assert(x >> x == 0);
assert(x >> (x as u8) == 0);
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
unconstrained fn main(x: u64, y: u64) {
unconstrained fn main(x: u64, y: u8) {
// runtime shifts on compile-time known values
assert(64 << y == 128);
assert(64 >> y == 32);
assert(64 as u32 << y == 128);
assert(64 as u32 >> y == 32);
// runtime shifts on runtime values
assert(x << y == 128);
assert(x >> y == 32);
Expand All @@ -11,10 +11,10 @@ unconstrained fn main(x: u64, y: u64) {
let mut b: i8 = x as i8;
assert(b << 1 == -128);
assert(b >> 2 == 16);
assert(b >> a == 32);
assert(b >> y == 32);
a = -a;
assert(a << 7 == -128);
assert(a << -a == -2);
assert(a << y == -2);

assert(x >> x == 0);
assert(x >> (x as u8) == 0);
}
Loading
Loading