1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

Merge branch 'develop' into constraint-analyser

This commit is contained in:
dark64 2020-04-03 14:41:44 +02:00
commit 49297bf1db
3 changed files with 51 additions and 46 deletions

View file

@ -203,7 +203,7 @@ fn use_variable(
/// * the return value of the `FlatFunction` is not deterministic: as we decompose over log_2(p) + 1 bits, some
/// elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)`
pub fn unpack<T: Field>() -> FlatFunction<T> {
let nbits = T::get_required_bits();
let bit_width = T::get_required_bits();
let mut counter = 0;
@ -221,23 +221,23 @@ pub fn unpack<T: Field>() -> FlatFunction<T> {
format!("i0"),
&mut counter,
))];
let directive_outputs: Vec<FlatVariable> = (0..T::get_required_bits())
let directive_outputs: Vec<FlatVariable> = (0..bit_width)
.map(|index| use_variable(&mut layout, format!("o{}", index), &mut counter))
.collect();
let solver = Solver::bits();
let solver = Solver::bits(bit_width);
let outputs = directive_outputs
.iter()
.enumerate()
.filter(|(index, _)| *index >= T::get_required_bits() - nbits)
.filter(|(index, _)| *index >= T::get_required_bits() - bit_width)
.map(|(_, o)| FlatExpression::Identifier(o.clone()))
.collect();
// o253, o252, ... o{253 - (nbits - 1)} are bits
let mut statements: Vec<FlatStatement<T>> = (0..nbits)
// o253, o252, ... o{253 - (bit_width - 1)} are bits
let mut statements: Vec<FlatStatement<T>> = (0..bit_width)
.map(|index| {
let bit = FlatExpression::Identifier(FlatVariable::new(T::get_required_bits() - index));
let bit = FlatExpression::Identifier(FlatVariable::new(bit_width - index));
FlatStatement::Condition(
bit.clone(),
FlatExpression::Mult(box bit.clone(), box bit.clone()),
@ -245,14 +245,14 @@ pub fn unpack<T: Field>() -> FlatFunction<T> {
})
.collect();
// sum check: o253 + o252 * 2 + ... + o{253 - (nbits - 1)} * 2**(nbits - 1)
// sum check: o253 + o252 * 2 + ... + o{253 - (bit_width - 1)} * 2**(bit_width - 1)
let mut lhs_sum = FlatExpression::Number(T::from(0));
for i in 0..nbits {
for i in 0..bit_width {
lhs_sum = FlatExpression::Add(
box lhs_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(FlatVariable::new(T::get_required_bits() - i)),
box FlatExpression::Identifier(FlatVariable::new(bit_width - i)),
box FlatExpression::Number(T::from(2).pow(i)),
),
);
@ -312,7 +312,7 @@ mod tests {
(0..FieldPrime::get_required_bits())
.map(|i| FlatVariable::new(i + 1))
.collect(),
Solver::bits(),
Solver::bits(FieldPrime::get_required_bits()),
vec![FlatVariable::new(0)]
))
);

View file

@ -566,8 +566,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FlatExpression::Identifier(self.layout.get(&x).unwrap().clone()[0])
}
BooleanExpression::Lt(box lhs, box rhs) => {
// Get the bitwidth to know the size of the binary decompsitions for this Field
let bitwidth = T::get_required_bits();
// Get the bit width to know the size of the binary decompositions for this Field
let bit_width = T::get_required_bits();
let safe_width = bit_width - 2; // making sure we don't overflow, assert here?
// We know from semantic checking that lhs and rhs have the same type
// What the expression will flatten to depends on that type
@ -587,22 +588,22 @@ impl<'ast, T: Field> Flattener<'ast, T> {
{
// define variables for the bits
let lhs_bits_be: Vec<FlatVariable> =
(0..bitwidth).map(|_| self.use_sym()).collect();
(0..safe_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
lhs_bits_be.clone(),
Solver::bits(),
Solver::bits(safe_width),
vec![lhs_id],
)));
// bitness checks
for i in 0..bitwidth - 2 {
for i in 0..safe_width {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(lhs_bits_be[i + 2]),
FlatExpression::Identifier(lhs_bits_be[i]),
FlatExpression::Mult(
box FlatExpression::Identifier(lhs_bits_be[i + 2]),
box FlatExpression::Identifier(lhs_bits_be[i + 2]),
box FlatExpression::Identifier(lhs_bits_be[i]),
box FlatExpression::Identifier(lhs_bits_be[i]),
),
));
}
@ -610,12 +611,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// bit decomposition check
let mut lhs_sum = FlatExpression::Number(T::from(0));
for i in 0..bitwidth - 2 {
for i in 0..safe_width {
lhs_sum = FlatExpression::Add(
box lhs_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(lhs_bits_be[i + 2]),
box FlatExpression::Number(T::from(2).pow(bitwidth - 2 - i - 1)),
box FlatExpression::Identifier(lhs_bits_be[i]),
box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)),
),
);
}
@ -634,22 +635,22 @@ impl<'ast, T: Field> Flattener<'ast, T> {
{
// define variables for the bits
let rhs_bits_be: Vec<FlatVariable> =
(0..bitwidth).map(|_| self.use_sym()).collect();
(0..safe_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
rhs_bits_be.clone(),
Solver::bits(),
Solver::bits(safe_width),
vec![rhs_id],
)));
// bitness checks
for i in 0..bitwidth - 2 {
for i in 0..safe_width {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(rhs_bits_be[i + 2]),
FlatExpression::Identifier(rhs_bits_be[i]),
FlatExpression::Mult(
box FlatExpression::Identifier(rhs_bits_be[i + 2]),
box FlatExpression::Identifier(rhs_bits_be[i + 2]),
box FlatExpression::Identifier(rhs_bits_be[i]),
box FlatExpression::Identifier(rhs_bits_be[i]),
),
));
}
@ -657,12 +658,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// bit decomposition check
let mut rhs_sum = FlatExpression::Number(T::from(0));
for i in 0..bitwidth - 2 {
for i in 0..safe_width {
rhs_sum = FlatExpression::Add(
box rhs_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(rhs_bits_be[i + 2]),
box FlatExpression::Number(T::from(2).pow(bitwidth - 2 - i - 1)),
box FlatExpression::Identifier(rhs_bits_be[i]),
box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)),
),
);
}
@ -687,17 +688,17 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// define variables for the bits
let sub_bits_be: Vec<FlatVariable> =
(0..bitwidth).map(|_| self.use_sym()).collect();
(0..bit_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
sub_bits_be.clone(),
Solver::bits(),
Solver::bits(bit_width),
vec![subtraction_result.clone()],
)));
// bitness checks
for i in 0..bitwidth {
for i in 0..bit_width {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(sub_bits_be[i]),
FlatExpression::Mult(
@ -710,19 +711,19 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// sum(sym_b{i} * 2**i)
let mut expr = FlatExpression::Number(T::from(0));
for i in 0..bitwidth {
for i in 0..bit_width {
expr = FlatExpression::Add(
box expr,
box FlatExpression::Mult(
box FlatExpression::Identifier(sub_bits_be[i]),
box FlatExpression::Number(T::from(2).pow(bitwidth - i - 1)),
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
),
);
}
statements_flattened.push(FlatStatement::Condition(subtraction_result, expr));
FlatExpression::Identifier(sub_bits_be[bitwidth - 1])
FlatExpression::Identifier(sub_bits_be[bit_width - 1])
}
BooleanExpression::BoolEq(box lhs, box rhs) => {
// lhs and rhs are booleans, they flatten to 0 or 1

View file

@ -5,7 +5,7 @@ use zokrates_field::field::Field;
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
pub enum Solver {
ConditionEq,
Bits,
Bits(usize),
Div,
Sha256Round,
}
@ -20,7 +20,7 @@ impl Signed for Solver {
fn get_signature(&self) -> (usize, usize) {
match self {
Solver::ConditionEq => (1, 2),
Solver::Bits => (1, 254),
Solver::Bits(bit_width) => (1, *bit_width),
Solver::Div => (2, 1),
Solver::Sha256Round => (768, 26935),
}
@ -37,11 +37,11 @@ impl<T: Field> Executable<T> for Solver {
true => vec![T::zero(), T::one()],
false => vec![T::one(), T::one() / inputs[0].clone()],
},
Solver::Bits => {
Solver::Bits(bit_width) => {
let mut num = inputs[0].clone();
let mut res = vec![];
let bits = 254;
for i in (0..bits).rev() {
for i in (0..*bit_width).rev() {
if T::from(2).pow(i) <= num {
num = num - T::from(2).pow(i);
res.push(T::one());
@ -73,8 +73,8 @@ impl<T: Field> Executable<T> for Solver {
}
impl Solver {
pub fn bits() -> Self {
Solver::Bits
pub fn bits(width: usize) -> Self {
Solver::Bits(width)
}
}
@ -125,7 +125,9 @@ mod tests {
#[test]
fn bits_of_one() {
let inputs = vec![FieldPrime::from(1)];
let res = Solver::Bits.execute(&inputs).unwrap();
let res = Solver::Bits(FieldPrime::get_required_bits())
.execute(&inputs)
.unwrap();
assert_eq!(res[253], FieldPrime::from(1));
for i in 0..252 {
assert_eq!(res[i], FieldPrime::from(0));
@ -135,7 +137,9 @@ mod tests {
#[test]
fn bits_of_42() {
let inputs = vec![FieldPrime::from(42)];
let res = Solver::Bits.execute(&inputs).unwrap();
let res = Solver::Bits(FieldPrime::get_required_bits())
.execute(&inputs)
.unwrap();
assert_eq!(res[253], FieldPrime::from(0));
assert_eq!(res[252], FieldPrime::from(1));
assert_eq!(res[251], FieldPrime::from(0));