refactor flat expression utils to reduce expression tree depth
This commit is contained in:
parent
748b341c43
commit
11c8a5d272
8 changed files with 90 additions and 99 deletions
|
@ -3,8 +3,8 @@ use crate::absy::{
|
|||
ConstantGenericNode, Expression,
|
||||
};
|
||||
use crate::flat_absy::{
|
||||
FlatDirective, FlatExpression, FlatFunctionIterator, FlatParameter, FlatStatement,
|
||||
FlatVariable, RuntimeError,
|
||||
flat_expression_from_bits, flat_expression_from_variable_summands, FlatDirective,
|
||||
FlatExpression, FlatFunctionIterator, FlatParameter, FlatStatement, FlatVariable, RuntimeError,
|
||||
};
|
||||
use crate::solvers::Solver;
|
||||
use crate::typed_absy::types::{
|
||||
|
@ -314,29 +314,6 @@ impl FlatEmbed {
|
|||
}
|
||||
}
|
||||
|
||||
// util to convert a vector of `(variable_id, coefficient)` to a flat_expression
|
||||
// we build a binary tree of additions by splitting the vector recursively
|
||||
#[cfg(any(feature = "ark", feature = "bellman"))]
|
||||
fn flat_expression_from_vec<T: Field>(v: &[(usize, T)]) -> FlatExpression<T> {
|
||||
match v.len() {
|
||||
0 => FlatExpression::Number(T::zero()),
|
||||
1 => {
|
||||
let (key, val) = v[0].clone();
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Number(val),
|
||||
box FlatExpression::Identifier(FlatVariable::new(key)),
|
||||
)
|
||||
}
|
||||
n => {
|
||||
let (u, v) = v.split_at(n / 2);
|
||||
FlatExpression::Add(
|
||||
box flat_expression_from_vec::<T>(u),
|
||||
box flat_expression_from_vec::<T>(v),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a flat function which computes a sha256 round
|
||||
///
|
||||
/// # Remarks
|
||||
|
@ -406,9 +383,9 @@ pub fn sha256_round<T: Field>(
|
|||
// insert flattened statements to represent constraints
|
||||
let constraint_statements = r1cs.constraints.into_iter().map(|c| {
|
||||
let c = from_bellman::<T, Bn256>(c);
|
||||
let rhs_a = flat_expression_from_vec::<T>(c.a.as_slice());
|
||||
let rhs_b = flat_expression_from_vec::<T>(c.b.as_slice());
|
||||
let lhs = flat_expression_from_vec::<T>(c.c.as_slice());
|
||||
let rhs_a = flat_expression_from_variable_summands::<T>(c.a.as_slice());
|
||||
let rhs_b = flat_expression_from_variable_summands::<T>(c.b.as_slice());
|
||||
let lhs = flat_expression_from_variable_summands::<T>(c.c.as_slice());
|
||||
|
||||
FlatStatement::Condition(
|
||||
lhs,
|
||||
|
@ -514,9 +491,9 @@ pub fn snark_verify_bls12_377<T: Field>(
|
|||
.into_iter()
|
||||
.map(|c| {
|
||||
let c = from_ark::<T, Bls12_377>(c);
|
||||
let rhs_a = flat_expression_from_vec::<T>(c.a.as_slice());
|
||||
let rhs_b = flat_expression_from_vec::<T>(c.b.as_slice());
|
||||
let lhs = flat_expression_from_vec::<T>(c.c.as_slice());
|
||||
let rhs_a = flat_expression_from_variable_summands::<T>(c.a.as_slice());
|
||||
let rhs_b = flat_expression_from_variable_summands::<T>(c.b.as_slice());
|
||||
let lhs = flat_expression_from_variable_summands::<T>(c.c.as_slice());
|
||||
|
||||
FlatStatement::Condition(
|
||||
lhs,
|
||||
|
@ -620,17 +597,11 @@ pub fn unpack_to_bitwidth<T: Field>(
|
|||
.collect();
|
||||
|
||||
// sum check: o253 + o252 * 2 + ... + o{253 - (bit_width - 1)} * 2**(bit_width - 1)
|
||||
let mut lhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..bit_width {
|
||||
lhs_sum = FlatExpression::Add(
|
||||
box lhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(FlatVariable::new(bit_width - i)),
|
||||
box FlatExpression::Number(T::from(2).pow(i)),
|
||||
),
|
||||
);
|
||||
}
|
||||
let lhs_sum = flat_expression_from_bits(
|
||||
(0..bit_width)
|
||||
.map(|i| FlatExpression::Identifier(FlatVariable::new(bit_width - i)))
|
||||
.collect(),
|
||||
);
|
||||
|
||||
statements.push(FlatStatement::Condition(
|
||||
lhs_sum,
|
||||
|
|
|
@ -7,9 +7,14 @@
|
|||
|
||||
pub mod flat_parameter;
|
||||
pub mod flat_variable;
|
||||
pub mod utils;
|
||||
|
||||
pub use self::flat_parameter::FlatParameter;
|
||||
pub use self::flat_variable::FlatVariable;
|
||||
pub use utils::{
|
||||
flat_expression_from_bits, flat_expression_from_expression_summands,
|
||||
flat_expression_from_variable_summands,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
|
53
zokrates_core/src/flat_absy/utils.rs
Normal file
53
zokrates_core/src/flat_absy/utils.rs
Normal file
|
@ -0,0 +1,53 @@
|
|||
use crate::flat_absy::*;
|
||||
use zokrates_field::Field;
|
||||
|
||||
// util to convert a vector of `(coefficient, expression)` to a flat_expression
|
||||
// we build a binary tree of additions by splitting the vector recursively
|
||||
pub fn flat_expression_from_expression_summands<T: Field, U: Clone + Into<FlatExpression<T>>>(
|
||||
v: &[(T, U)],
|
||||
) -> FlatExpression<T> {
|
||||
match v.len() {
|
||||
0 => FlatExpression::Number(T::zero()),
|
||||
1 => {
|
||||
let (val, var) = v[0].clone();
|
||||
FlatExpression::Mult(box FlatExpression::Number(val), box var.into())
|
||||
}
|
||||
n => {
|
||||
let (u, v) = v.split_at(n / 2);
|
||||
FlatExpression::Add(
|
||||
box flat_expression_from_expression_summands(u),
|
||||
box flat_expression_from_expression_summands(v),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flat_expression_from_bits<T: Field>(v: Vec<FlatExpression<T>>) -> FlatExpression<T> {
|
||||
flat_expression_from_expression_summands(
|
||||
&v.into_iter()
|
||||
.rev()
|
||||
.enumerate()
|
||||
.map(|(index, var)| (T::from(2).pow(index), var))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn flat_expression_from_variable_summands<T: Field>(v: &[(T, usize)]) -> FlatExpression<T> {
|
||||
match v.len() {
|
||||
0 => FlatExpression::Number(T::zero()),
|
||||
1 => {
|
||||
let (val, var) = v[0].clone();
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Number(val),
|
||||
box FlatExpression::Identifier(FlatVariable::new(var)),
|
||||
)
|
||||
}
|
||||
n => {
|
||||
let (u, v) = v.split_at(n / 2);
|
||||
FlatExpression::Add(
|
||||
box flat_expression_from_variable_summands(u),
|
||||
box flat_expression_from_variable_summands(v),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -5,9 +5,6 @@
|
|||
//! @author Jacob Eberhardt <jacob.eberhardt@tu-berlin.de>
|
||||
//! @date 2017
|
||||
|
||||
mod utils;
|
||||
|
||||
use self::utils::flat_expression_from_bits;
|
||||
use crate::ir::Interpreter;
|
||||
|
||||
use crate::compile::CompileConfig;
|
||||
|
|
|
@ -1,31 +0,0 @@
|
|||
use crate::flat_absy::*;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub fn flat_expression_from_bits<T: Field>(v: Vec<FlatExpression<T>>) -> FlatExpression<T> {
|
||||
fn flat_expression_from_bits_aux<T: Field>(
|
||||
v: Vec<(T, FlatExpression<T>)>,
|
||||
) -> FlatExpression<T> {
|
||||
match v.len() {
|
||||
0 => FlatExpression::Number(T::zero()),
|
||||
1 => {
|
||||
let (coeff, var) = v[0].clone();
|
||||
FlatExpression::Mult(box FlatExpression::Number(coeff), box var)
|
||||
}
|
||||
n => {
|
||||
let (u, v) = v.split_at(n / 2);
|
||||
FlatExpression::Add(
|
||||
box flat_expression_from_bits_aux(u.to_vec()),
|
||||
box flat_expression_from_bits_aux(v.to_vec()),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
flat_expression_from_bits_aux(
|
||||
v.into_iter()
|
||||
.rev()
|
||||
.enumerate()
|
||||
.map(|(index, var)| (T::from(2).pow(index), var))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
|
@ -165,11 +165,7 @@ pub fn generate_verify_constraints(
|
|||
.into_iter()
|
||||
.zip(matrices.b.into_iter())
|
||||
.zip(matrices.c.into_iter())
|
||||
.map(|((a, b), c)| Constraint {
|
||||
a: a.into_iter().map(|(f, index)| (index, f)).collect(),
|
||||
b: b.into_iter().map(|(f, index)| (index, f)).collect(),
|
||||
c: c.into_iter().map(|(f, index)| (index, f)).collect(),
|
||||
})
|
||||
.map(|((a, b), c)| Constraint { a, b, c })
|
||||
.collect();
|
||||
|
||||
(
|
||||
|
@ -309,26 +305,26 @@ pub fn from_ark<T: zokrates_field::Field, E: PairingEngine>(c: Constraint<E::Fq>
|
|||
Constraint {
|
||||
a: c.a
|
||||
.into_iter()
|
||||
.map(|(index, fq)| {
|
||||
.map(|(fq, index)| {
|
||||
let mut res: Vec<u8> = vec![];
|
||||
fq.into_repr().write_le(&mut res).unwrap();
|
||||
(index, T::from_byte_vector(res))
|
||||
(T::from_byte_vector(res), index)
|
||||
})
|
||||
.collect(),
|
||||
b: c.b
|
||||
.into_iter()
|
||||
.map(|(index, fq)| {
|
||||
.map(|(fq, index)| {
|
||||
let mut res: Vec<u8> = vec![];
|
||||
fq.into_repr().write_le(&mut res).unwrap();
|
||||
(index, T::from_byte_vector(res))
|
||||
(T::from_byte_vector(res), index)
|
||||
})
|
||||
.collect(),
|
||||
c: c.c
|
||||
.into_iter()
|
||||
.map(|(index, fq)| {
|
||||
.map(|(fq, index)| {
|
||||
let mut res: Vec<u8> = vec![];
|
||||
fq.into_repr().write_le(&mut res).unwrap();
|
||||
(index, T::from_byte_vector(res))
|
||||
(T::from_byte_vector(res), index)
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
|
|
|
@ -180,17 +180,17 @@ impl<E: Engine> ConstraintSystem<E> for R1CS<E::Fr> {
|
|||
let a = a
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|(variable, coefficient)| (var_to_index(*variable), *coefficient))
|
||||
.map(|(variable, coefficient)| (*coefficient, var_to_index(*variable)))
|
||||
.collect();
|
||||
let b = b
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|(variable, coefficient)| (var_to_index(*variable), *coefficient))
|
||||
.map(|(variable, coefficient)| (*coefficient, var_to_index(*variable)))
|
||||
.collect();
|
||||
let c = c
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|(variable, coefficient)| (var_to_index(*variable), *coefficient))
|
||||
.map(|(variable, coefficient)| (*coefficient, var_to_index(*variable)))
|
||||
.collect();
|
||||
|
||||
self.constraints.push(Constraint { a, b, c });
|
||||
|
@ -257,26 +257,26 @@ pub fn from_bellman<T: zokrates_field::Field, E: Engine>(c: Constraint<E::Fr>) -
|
|||
Constraint {
|
||||
a: c.a
|
||||
.into_iter()
|
||||
.map(|(index, fq)| {
|
||||
.map(|(fq, index)| {
|
||||
let mut res: Vec<u8> = vec![];
|
||||
fq.into_repr().write_le(&mut res).unwrap();
|
||||
(index, T::from_byte_vector(res))
|
||||
(T::from_byte_vector(res), index)
|
||||
})
|
||||
.collect(),
|
||||
b: c.b
|
||||
.into_iter()
|
||||
.map(|(index, fq)| {
|
||||
.map(|(fq, index)| {
|
||||
let mut res: Vec<u8> = vec![];
|
||||
fq.into_repr().write_le(&mut res).unwrap();
|
||||
(index, T::from_byte_vector(res))
|
||||
(T::from_byte_vector(res), index)
|
||||
})
|
||||
.collect(),
|
||||
c: c.c
|
||||
.into_iter()
|
||||
.map(|(index, fq)| {
|
||||
.map(|(fq, index)| {
|
||||
let mut res: Vec<u8> = vec![];
|
||||
fq.into_repr().write_le(&mut res).unwrap();
|
||||
(index, T::from_byte_vector(res))
|
||||
(T::from_byte_vector(res), index)
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ pub struct Witness<T> {
|
|||
|
||||
#[derive(Default, Debug, PartialEq, Clone)]
|
||||
pub struct Constraint<T> {
|
||||
pub a: Vec<(usize, T)>,
|
||||
pub b: Vec<(usize, T)>,
|
||||
pub c: Vec<(usize, T)>,
|
||||
pub a: Vec<(T, usize)>,
|
||||
pub b: Vec<(T, usize)>,
|
||||
pub c: Vec<(T, usize)>,
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue