1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

use bit_cache in range check

This commit is contained in:
dark64 2021-12-28 13:41:21 +01:00
parent 85ca126e03
commit d80730a65a

View file

@ -16,10 +16,7 @@ use crate::flat_absy::{RuntimeError, *};
use crate::solvers::Solver;
use crate::zir::types::{Type, UBitwidth};
use crate::zir::*;
use std::collections::{
hash_map::{Entry, HashMap},
VecDeque,
};
use std::collections::{hash_map::HashMap, VecDeque};
use std::convert::TryFrom;
use zokrates_field::Field;
@ -274,7 +271,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
fn constant_le_check(
&mut self,
statements_flattened: &mut FlatStatements<T>,
a: &[FlatVariable],
a: &[FlatExpression<T>],
b: &[bool],
) -> Vec<FlatExpression<T>> {
let len = b.len();
@ -300,7 +297,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
if *b {
statements_flattened.push_back(FlatStatement::Definition(
is_not_smaller_run[i],
a[i].into(),
a[i].clone(),
));
// don't need to update size_unknown in the last round
@ -329,7 +326,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
box size_unknown[i].into(),
);
let or_right: FlatExpression<_> =
FlatExpression::Sub(box FlatExpression::Number(T::from(1)), box a[i].into());
FlatExpression::Sub(box FlatExpression::Number(T::from(1)), box a[i].clone());
let and_name = self.use_sym();
let and = FlatExpression::Mult(box or_left.clone(), box or_right.clone());
@ -412,7 +409,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
fn enforce_constant_le_check(
&mut self,
statements_flattened: &mut FlatStatements<T>,
a: &[FlatVariable],
a: &[FlatExpression<T>],
b: &[bool],
) {
let conditions = self.constant_le_check(statements_flattened, a, b);
@ -620,50 +617,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
e: FlatExpression<T>,
c: T,
) -> FlatExpression<T> {
let bit_width = T::get_required_bits();
// decompose e to bits
let e_id = self.define(e, statements_flattened);
let e_var = self.define(e, statements_flattened);
let e_id = FlatExpression::Identifier(e_var);
// define variables for the bits
let e_bits_be: Vec<FlatVariable> = (0..bit_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
e_bits_be.clone(),
Solver::bits(bit_width),
vec![e_id],
)));
// bitness checks
for bit in e_bits_be.iter().take(bit_width) {
statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Identifier(*bit),
FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Identifier(*bit),
),
RuntimeError::ConstantLtBitness,
));
}
// bit decomposition check
let mut e_sum = FlatExpression::Number(T::from(0));
for (i, bit) in e_bits_be.iter().take(bit_width).enumerate() {
e_sum = FlatExpression::Add(
box e_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
),
);
}
statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Identifier(e_id),
e_sum,
RuntimeError::ConstantLtSum,
));
let width = T::get_required_bits();
let e_bits_be = self.get_bits(&e_id, width, width, statements_flattened);
// check that this decomposition does not overflow the field
self.enforce_constant_le_check(
@ -672,7 +630,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
&T::max_value().to_bits_be(),
);
let conditions = self.constant_le_check(statements_flattened, &e_bits_be, &c.to_bits_be());
let c_bits_be: Vec<bool> = c.to_bits_be();
let conditions = self.constant_le_check(statements_flattened, &e_bits_be, &c_bits_be);
// return `len(conditions) == sum(conditions)`
self.eq_check(
@ -1068,7 +1027,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// define variables for the variable bits
let variables: Vec<_> = expressions
.into_iter()
.map(|e| self.define(e.get_field_unchecked(), statements_flattened))
.map(|e| {
FlatExpression::Identifier(
self.define(e.get_field_unchecked(), statements_flattened),
)
})
.collect();
// get constants for the constant bits
@ -1324,29 +1287,31 @@ impl<'ast, T: Field> Flattener<'ast, T> {
solver: Solver::EuclideanDiv,
}));
let target_bitwidth = target_bitwidth.to_usize();
// q in range
let _ = self.get_bits(
&FlatUExpression::with_field(FlatExpression::from(q)),
target_bitwidth.to_usize(),
&FlatExpression::from(q),
target_bitwidth,
target_bitwidth,
statements_flattened,
);
// r in range
let _ = self.get_bits(
&FlatUExpression::with_field(FlatExpression::from(r)),
target_bitwidth.to_usize(),
&FlatExpression::from(r),
target_bitwidth,
target_bitwidth,
statements_flattened,
);
// r < d <=> r - d + 2**w < 2**w
let _ = self.get_bits(
&FlatUExpression::with_field(FlatExpression::Add(
&FlatExpression::Add(
box FlatExpression::Sub(box r.into(), box d.clone()),
box FlatExpression::Number(T::from(2_u128.pow(target_bitwidth.to_usize() as u32))),
)),
target_bitwidth.to_usize(),
box FlatExpression::Number(T::from(2_u128.pow(target_bitwidth as u32))),
),
target_bitwidth,
target_bitwidth,
statements_flattened,
);
@ -1815,8 +1780,15 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let res = match should_reduce {
true => {
let bits =
self.get_bits(&res, actual_bitwidth, target_bitwidth, statements_flattened);
let bits = match &res.bits {
Some(bits) => bits.clone(),
None => self.get_bits(
res.field.as_ref().unwrap(),
actual_bitwidth,
target_bitwidth.to_usize(),
statements_flattened,
),
};
let field = if actual_bitwidth > target_bitwidth.to_usize() {
bits.iter().enumerate().fold(
@ -1847,18 +1819,16 @@ impl<'ast, T: Field> Flattener<'ast, T> {
fn get_bits(
&mut self,
e: &FlatUExpression<T>,
e: &FlatExpression<T>,
from: usize,
to: UBitwidth,
to: usize,
statements_flattened: &mut FlatStatements<T>,
) -> Vec<FlatExpression<T>> {
let to = to.to_usize();
assert!(from < T::get_required_bits());
assert!(to < T::get_required_bits());
assert!(from <= T::get_required_bits());
assert!(to <= T::get_required_bits());
// constants do not require directives
if let Some(FlatExpression::Number(ref x)) = e.field {
if let FlatExpression::Number(ref x) = e {
let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.clone()])
.unwrap()
.into_iter()
@ -1867,67 +1837,59 @@ impl<'ast, T: Field> Flattener<'ast, T> {
assert_eq!(bits.len(), to);
self.bits_cache
.insert(e.field.clone().unwrap(), bits.clone());
self.bits_cache.insert(e.clone(), bits.clone());
return bits;
};
e.bits.clone().unwrap_or_else(|| {
// we are not reducing a constant, therefore the result should always have a smaller bitwidth:
// `to` is the target bitwidth, and `from` cannot be smaller than that unless we're looking at a
// constant
let from = std::cmp::max(from, to);
match self.bits_cache.entry(e.field.clone().unwrap()) {
Entry::Occupied(entry) => {
let res: Vec<_> = entry.get().clone();
// if we already know a decomposition, it has to be of the size of the target bitwidth
assert_eq!(res.len(), to);
res
}
Entry::Vacant(_) => {
let bits = (0..from).map(|_| self.use_sym()).collect::<Vec<_>>();
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
bits.clone(),
Solver::Bits(from),
vec![e.field.clone().unwrap()],
)));
let bits: Vec<_> = bits.into_iter().map(FlatExpression::Identifier).collect();
// decompose to the actual bitwidth
// bit checks
statements_flattened.extend(bits.iter().take(from).map(|bit| {
FlatStatement::Condition(
bit.clone(),
FlatExpression::Mult(box bit.clone(), box bit.clone()),
RuntimeError::Bitness,
)
}));
let sum = flat_expression_from_bits(bits.clone());
// sum check
statements_flattened.push_back(FlatStatement::Condition(
e.field.clone().unwrap(),
sum.clone(),
RuntimeError::Sum,
));
// truncate to the `to` lowest bits
let bits = bits[from - to..].to_vec();
assert_eq!(bits.len(), to);
self.bits_cache
.insert(e.field.clone().unwrap(), bits.clone());
self.bits_cache.insert(sum, bits.clone());
bits
}
match self.bits_cache.get(e) {
Some(bits) => {
// if we already know a decomposition, it has to be of the size of the target bitwidth
assert_eq!(bits.len(), to);
bits.clone()
}
})
None => {
let from = std::cmp::max(from, to);
let bits = (0..from).map(|_| self.use_sym()).collect::<Vec<_>>();
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
bits.clone(),
Solver::Bits(from),
vec![e.clone()],
)));
let bits: Vec<_> = bits.into_iter().map(FlatExpression::Identifier).collect();
// decompose to the actual bitwidth
// bit checks
statements_flattened.extend(bits.iter().take(from).map(|bit| {
FlatStatement::Condition(
bit.clone(),
FlatExpression::Mult(box bit.clone(), box bit.clone()),
RuntimeError::Bitness,
)
}));
let sum = flat_expression_from_bits(bits.clone());
// sum check
statements_flattened.push_back(FlatStatement::Condition(
e.clone(),
sum.clone(),
RuntimeError::Sum,
));
// truncate to the `to` lowest bits
let bits = bits[from - to..].to_vec();
assert_eq!(bits.len(), to);
self.bits_cache.insert(e.clone(), bits.clone());
self.bits_cache.insert(sum, bits.clone());
bits
}
}
}
fn flatten_select_expression<U: Flatten<'ast, T>>(
@ -2510,9 +2472,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// to constrain unsigned integer inputs to be in range, we get their bit decomposition.
// it will be cached
self.get_bits(
&FlatUExpression::with_field(FlatExpression::Identifier(variable)),
&FlatExpression::Identifier(variable),
bitwidth.to_usize(),
bitwidth.to_usize(),
bitwidth,
statements_flattened,
);
}