use bit_cache in range check
This commit is contained in:
parent
85ca126e03
commit
d80730a65a
1 changed files with 91 additions and 129 deletions
|
@ -16,10 +16,7 @@ use crate::flat_absy::{RuntimeError, *};
|
|||
use crate::solvers::Solver;
|
||||
use crate::zir::types::{Type, UBitwidth};
|
||||
use crate::zir::*;
|
||||
use std::collections::{
|
||||
hash_map::{Entry, HashMap},
|
||||
VecDeque,
|
||||
};
|
||||
use std::collections::{hash_map::HashMap, VecDeque};
|
||||
use std::convert::TryFrom;
|
||||
use zokrates_field::Field;
|
||||
|
||||
|
@ -274,7 +271,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
fn constant_le_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
a: &[FlatVariable],
|
||||
a: &[FlatExpression<T>],
|
||||
b: &[bool],
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
let len = b.len();
|
||||
|
@ -300,7 +297,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
if *b {
|
||||
statements_flattened.push_back(FlatStatement::Definition(
|
||||
is_not_smaller_run[i],
|
||||
a[i].into(),
|
||||
a[i].clone(),
|
||||
));
|
||||
|
||||
// don't need to update size_unknown in the last round
|
||||
|
@ -329,7 +326,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
box size_unknown[i].into(),
|
||||
);
|
||||
let or_right: FlatExpression<_> =
|
||||
FlatExpression::Sub(box FlatExpression::Number(T::from(1)), box a[i].into());
|
||||
FlatExpression::Sub(box FlatExpression::Number(T::from(1)), box a[i].clone());
|
||||
|
||||
let and_name = self.use_sym();
|
||||
let and = FlatExpression::Mult(box or_left.clone(), box or_right.clone());
|
||||
|
@ -412,7 +409,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
fn enforce_constant_le_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
a: &[FlatVariable],
|
||||
a: &[FlatExpression<T>],
|
||||
b: &[bool],
|
||||
) {
|
||||
let conditions = self.constant_le_check(statements_flattened, a, b);
|
||||
|
@ -620,50 +617,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
e: FlatExpression<T>,
|
||||
c: T,
|
||||
) -> FlatExpression<T> {
|
||||
let bit_width = T::get_required_bits();
|
||||
// decompose e to bits
|
||||
let e_id = self.define(e, statements_flattened);
|
||||
let e_var = self.define(e, statements_flattened);
|
||||
let e_id = FlatExpression::Identifier(e_var);
|
||||
|
||||
// define variables for the bits
|
||||
let e_bits_be: Vec<FlatVariable> = (0..bit_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||
e_bits_be.clone(),
|
||||
Solver::bits(bit_width),
|
||||
vec![e_id],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for bit in e_bits_be.iter().take(bit_width) {
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*bit),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Identifier(*bit),
|
||||
),
|
||||
RuntimeError::ConstantLtBitness,
|
||||
));
|
||||
}
|
||||
|
||||
// bit decomposition check
|
||||
let mut e_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for (i, bit) in e_bits_be.iter().take(bit_width).enumerate() {
|
||||
e_sum = FlatExpression::Add(
|
||||
box e_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(e_id),
|
||||
e_sum,
|
||||
RuntimeError::ConstantLtSum,
|
||||
));
|
||||
let width = T::get_required_bits();
|
||||
let e_bits_be = self.get_bits(&e_id, width, width, statements_flattened);
|
||||
|
||||
// check that this decomposition does not overflow the field
|
||||
self.enforce_constant_le_check(
|
||||
|
@ -672,7 +630,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
&T::max_value().to_bits_be(),
|
||||
);
|
||||
|
||||
let conditions = self.constant_le_check(statements_flattened, &e_bits_be, &c.to_bits_be());
|
||||
let c_bits_be: Vec<bool> = c.to_bits_be();
|
||||
let conditions = self.constant_le_check(statements_flattened, &e_bits_be, &c_bits_be);
|
||||
|
||||
// return `len(conditions) == sum(conditions)`
|
||||
self.eq_check(
|
||||
|
@ -1068,7 +1027,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// define variables for the variable bits
|
||||
let variables: Vec<_> = expressions
|
||||
.into_iter()
|
||||
.map(|e| self.define(e.get_field_unchecked(), statements_flattened))
|
||||
.map(|e| {
|
||||
FlatExpression::Identifier(
|
||||
self.define(e.get_field_unchecked(), statements_flattened),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// get constants for the constant bits
|
||||
|
@ -1324,29 +1287,31 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
solver: Solver::EuclideanDiv,
|
||||
}));
|
||||
|
||||
let target_bitwidth = target_bitwidth.to_usize();
|
||||
|
||||
// q in range
|
||||
let _ = self.get_bits(
|
||||
&FlatUExpression::with_field(FlatExpression::from(q)),
|
||||
target_bitwidth.to_usize(),
|
||||
&FlatExpression::from(q),
|
||||
target_bitwidth,
|
||||
target_bitwidth,
|
||||
statements_flattened,
|
||||
);
|
||||
|
||||
// r in range
|
||||
let _ = self.get_bits(
|
||||
&FlatUExpression::with_field(FlatExpression::from(r)),
|
||||
target_bitwidth.to_usize(),
|
||||
&FlatExpression::from(r),
|
||||
target_bitwidth,
|
||||
target_bitwidth,
|
||||
statements_flattened,
|
||||
);
|
||||
|
||||
// r < d <=> r - d + 2**w < 2**w
|
||||
let _ = self.get_bits(
|
||||
&FlatUExpression::with_field(FlatExpression::Add(
|
||||
&FlatExpression::Add(
|
||||
box FlatExpression::Sub(box r.into(), box d.clone()),
|
||||
box FlatExpression::Number(T::from(2_u128.pow(target_bitwidth.to_usize() as u32))),
|
||||
)),
|
||||
target_bitwidth.to_usize(),
|
||||
box FlatExpression::Number(T::from(2_u128.pow(target_bitwidth as u32))),
|
||||
),
|
||||
target_bitwidth,
|
||||
target_bitwidth,
|
||||
statements_flattened,
|
||||
);
|
||||
|
@ -1815,8 +1780,15 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
let res = match should_reduce {
|
||||
true => {
|
||||
let bits =
|
||||
self.get_bits(&res, actual_bitwidth, target_bitwidth, statements_flattened);
|
||||
let bits = match &res.bits {
|
||||
Some(bits) => bits.clone(),
|
||||
None => self.get_bits(
|
||||
res.field.as_ref().unwrap(),
|
||||
actual_bitwidth,
|
||||
target_bitwidth.to_usize(),
|
||||
statements_flattened,
|
||||
),
|
||||
};
|
||||
|
||||
let field = if actual_bitwidth > target_bitwidth.to_usize() {
|
||||
bits.iter().enumerate().fold(
|
||||
|
@ -1847,18 +1819,16 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
fn get_bits(
|
||||
&mut self,
|
||||
e: &FlatUExpression<T>,
|
||||
e: &FlatExpression<T>,
|
||||
from: usize,
|
||||
to: UBitwidth,
|
||||
to: usize,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
let to = to.to_usize();
|
||||
|
||||
assert!(from < T::get_required_bits());
|
||||
assert!(to < T::get_required_bits());
|
||||
assert!(from <= T::get_required_bits());
|
||||
assert!(to <= T::get_required_bits());
|
||||
|
||||
// constants do not require directives
|
||||
if let Some(FlatExpression::Number(ref x)) = e.field {
|
||||
if let FlatExpression::Number(ref x) = e {
|
||||
let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()])
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
|
@ -1867,67 +1837,59 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
assert_eq!(bits.len(), to);
|
||||
|
||||
self.bits_cache
|
||||
.insert(e.field.clone().unwrap(), bits.clone());
|
||||
self.bits_cache.insert(e.clone(), bits.clone());
|
||||
return bits;
|
||||
};
|
||||
|
||||
e.bits.clone().unwrap_or_else(|| {
|
||||
// we are not reducing a constant, therefore the result should always have a smaller bitwidth:
|
||||
// `to` is the target bitwidth, and `from` cannot be smaller than that unless we're looking at a
|
||||
// constant
|
||||
|
||||
let from = std::cmp::max(from, to);
|
||||
match self.bits_cache.entry(e.field.clone().unwrap()) {
|
||||
Entry::Occupied(entry) => {
|
||||
let res: Vec<_> = entry.get().clone();
|
||||
// if we already know a decomposition, it has to be of the size of the target bitwidth
|
||||
assert_eq!(res.len(), to);
|
||||
res
|
||||
}
|
||||
Entry::Vacant(_) => {
|
||||
let bits = (0..from).map(|_| self.use_sym()).collect::<Vec<_>>();
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||
bits.clone(),
|
||||
Solver::Bits(from),
|
||||
vec![e.field.clone().unwrap()],
|
||||
)));
|
||||
|
||||
let bits: Vec<_> = bits.into_iter().map(FlatExpression::Identifier).collect();
|
||||
|
||||
// decompose to the actual bitwidth
|
||||
|
||||
// bit checks
|
||||
statements_flattened.extend(bits.iter().take(from).map(|bit| {
|
||||
FlatStatement::Condition(
|
||||
bit.clone(),
|
||||
FlatExpression::Mult(box bit.clone(), box bit.clone()),
|
||||
RuntimeError::Bitness,
|
||||
)
|
||||
}));
|
||||
|
||||
let sum = flat_expression_from_bits(bits.clone());
|
||||
|
||||
// sum check
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
e.field.clone().unwrap(),
|
||||
sum.clone(),
|
||||
RuntimeError::Sum,
|
||||
));
|
||||
|
||||
// truncate to the `to` lowest bits
|
||||
let bits = bits[from - to..].to_vec();
|
||||
|
||||
assert_eq!(bits.len(), to);
|
||||
|
||||
self.bits_cache
|
||||
.insert(e.field.clone().unwrap(), bits.clone());
|
||||
self.bits_cache.insert(sum, bits.clone());
|
||||
|
||||
bits
|
||||
}
|
||||
match self.bits_cache.get(e) {
|
||||
Some(bits) => {
|
||||
// if we already know a decomposition, it has to be of the size of the target bitwidth
|
||||
assert_eq!(bits.len(), to);
|
||||
bits.clone()
|
||||
}
|
||||
})
|
||||
None => {
|
||||
let from = std::cmp::max(from, to);
|
||||
|
||||
let bits = (0..from).map(|_| self.use_sym()).collect::<Vec<_>>();
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||
bits.clone(),
|
||||
Solver::Bits(from),
|
||||
vec![e.clone()],
|
||||
)));
|
||||
|
||||
let bits: Vec<_> = bits.into_iter().map(FlatExpression::Identifier).collect();
|
||||
|
||||
// decompose to the actual bitwidth
|
||||
|
||||
// bit checks
|
||||
statements_flattened.extend(bits.iter().take(from).map(|bit| {
|
||||
FlatStatement::Condition(
|
||||
bit.clone(),
|
||||
FlatExpression::Mult(box bit.clone(), box bit.clone()),
|
||||
RuntimeError::Bitness,
|
||||
)
|
||||
}));
|
||||
|
||||
let sum = flat_expression_from_bits(bits.clone());
|
||||
|
||||
// sum check
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
e.clone(),
|
||||
sum.clone(),
|
||||
RuntimeError::Sum,
|
||||
));
|
||||
|
||||
// truncate to the `to` lowest bits
|
||||
let bits = bits[from - to..].to_vec();
|
||||
|
||||
assert_eq!(bits.len(), to);
|
||||
|
||||
self.bits_cache.insert(e.clone(), bits.clone());
|
||||
self.bits_cache.insert(sum, bits.clone());
|
||||
|
||||
bits
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn flatten_select_expression<U: Flatten<'ast, T>>(
|
||||
|
@ -2510,9 +2472,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// to constrain unsigned integer inputs to be in range, we get their bit decomposition.
|
||||
// it will be cached
|
||||
self.get_bits(
|
||||
&FlatUExpression::with_field(FlatExpression::Identifier(variable)),
|
||||
&FlatExpression::Identifier(variable),
|
||||
bitwidth.to_usize(),
|
||||
bitwidth.to_usize(),
|
||||
bitwidth,
|
||||
statements_flattened,
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue