1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

refactor flat expression utils to reduce expression tree depth

This commit is contained in:
schaeff 2022-05-31 15:30:18 +02:00
parent 748b341c43
commit 11c8a5d272
8 changed files with 90 additions and 99 deletions

View file

@ -3,8 +3,8 @@ use crate::absy::{
ConstantGenericNode, Expression,
};
use crate::flat_absy::{
FlatDirective, FlatExpression, FlatFunctionIterator, FlatParameter, FlatStatement,
FlatVariable, RuntimeError,
flat_expression_from_bits, flat_expression_from_variable_summands, FlatDirective,
FlatExpression, FlatFunctionIterator, FlatParameter, FlatStatement, FlatVariable, RuntimeError,
};
use crate::solvers::Solver;
use crate::typed_absy::types::{
@ -314,29 +314,6 @@ impl FlatEmbed {
}
}
// util to convert a vector of `(variable_id, coefficient)` to a flat_expression
// we build a binary tree of additions by splitting the vector recursively
#[cfg(any(feature = "ark", feature = "bellman"))]
fn flat_expression_from_vec<T: Field>(v: &[(usize, T)]) -> FlatExpression<T> {
match v.len() {
0 => FlatExpression::Number(T::zero()),
1 => {
let (key, val) = v[0].clone();
FlatExpression::Mult(
box FlatExpression::Number(val),
box FlatExpression::Identifier(FlatVariable::new(key)),
)
}
n => {
let (u, v) = v.split_at(n / 2);
FlatExpression::Add(
box flat_expression_from_vec::<T>(u),
box flat_expression_from_vec::<T>(v),
)
}
}
}
/// Returns a flat function which computes a sha256 round
///
/// # Remarks
@ -406,9 +383,9 @@ pub fn sha256_round<T: Field>(
// insert flattened statements to represent constraints
let constraint_statements = r1cs.constraints.into_iter().map(|c| {
let c = from_bellman::<T, Bn256>(c);
let rhs_a = flat_expression_from_vec::<T>(c.a.as_slice());
let rhs_b = flat_expression_from_vec::<T>(c.b.as_slice());
let lhs = flat_expression_from_vec::<T>(c.c.as_slice());
let rhs_a = flat_expression_from_variable_summands::<T>(c.a.as_slice());
let rhs_b = flat_expression_from_variable_summands::<T>(c.b.as_slice());
let lhs = flat_expression_from_variable_summands::<T>(c.c.as_slice());
FlatStatement::Condition(
lhs,
@ -514,9 +491,9 @@ pub fn snark_verify_bls12_377<T: Field>(
.into_iter()
.map(|c| {
let c = from_ark::<T, Bls12_377>(c);
let rhs_a = flat_expression_from_vec::<T>(c.a.as_slice());
let rhs_b = flat_expression_from_vec::<T>(c.b.as_slice());
let lhs = flat_expression_from_vec::<T>(c.c.as_slice());
let rhs_a = flat_expression_from_variable_summands::<T>(c.a.as_slice());
let rhs_b = flat_expression_from_variable_summands::<T>(c.b.as_slice());
let lhs = flat_expression_from_variable_summands::<T>(c.c.as_slice());
FlatStatement::Condition(
lhs,
@ -620,17 +597,11 @@ pub fn unpack_to_bitwidth<T: Field>(
.collect();
// sum check: o253 + o252 * 2 + ... + o{253 - (bit_width - 1)} * 2**(bit_width - 1)
let mut lhs_sum = FlatExpression::Number(T::from(0));
for i in 0..bit_width {
lhs_sum = FlatExpression::Add(
box lhs_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(FlatVariable::new(bit_width - i)),
box FlatExpression::Number(T::from(2).pow(i)),
),
);
}
let lhs_sum = flat_expression_from_bits(
(0..bit_width)
.map(|i| FlatExpression::Identifier(FlatVariable::new(bit_width - i)))
.collect(),
);
statements.push(FlatStatement::Condition(
lhs_sum,

View file

@ -7,9 +7,14 @@
pub mod flat_parameter;
pub mod flat_variable;
pub mod utils;
pub use self::flat_parameter::FlatParameter;
pub use self::flat_variable::FlatVariable;
pub use utils::{
flat_expression_from_bits, flat_expression_from_expression_summands,
flat_expression_from_variable_summands,
};
use serde::{Deserialize, Serialize};

View file

@ -0,0 +1,53 @@
use crate::flat_absy::*;
use zokrates_field::Field;
// util to convert a vector of `(coefficient, expression)` to a flat_expression
// we build a binary tree of additions by splitting the vector recursively
pub fn flat_expression_from_expression_summands<T: Field, U: Clone + Into<FlatExpression<T>>>(
v: &[(T, U)],
) -> FlatExpression<T> {
match v.len() {
0 => FlatExpression::Number(T::zero()),
1 => {
let (val, var) = v[0].clone();
FlatExpression::Mult(box FlatExpression::Number(val), box var.into())
}
n => {
let (u, v) = v.split_at(n / 2);
FlatExpression::Add(
box flat_expression_from_expression_summands(u),
box flat_expression_from_expression_summands(v),
)
}
}
}
pub fn flat_expression_from_bits<T: Field>(v: Vec<FlatExpression<T>>) -> FlatExpression<T> {
flat_expression_from_expression_summands(
&v.into_iter()
.rev()
.enumerate()
.map(|(index, var)| (T::from(2).pow(index), var))
.collect::<Vec<_>>(),
)
}
pub fn flat_expression_from_variable_summands<T: Field>(v: &[(T, usize)]) -> FlatExpression<T> {
match v.len() {
0 => FlatExpression::Number(T::zero()),
1 => {
let (val, var) = v[0].clone();
FlatExpression::Mult(
box FlatExpression::Number(val),
box FlatExpression::Identifier(FlatVariable::new(var)),
)
}
n => {
let (u, v) = v.split_at(n / 2);
FlatExpression::Add(
box flat_expression_from_variable_summands(u),
box flat_expression_from_variable_summands(v),
)
}
}
}

View file

@ -5,9 +5,6 @@
//! @author Jacob Eberhardt <jacob.eberhardt@tu-berlin.de>
//! @date 2017
mod utils;
use self::utils::flat_expression_from_bits;
use crate::ir::Interpreter;
use crate::compile::CompileConfig;

View file

@ -1,31 +0,0 @@
use crate::flat_absy::*;
use zokrates_field::Field;
pub fn flat_expression_from_bits<T: Field>(v: Vec<FlatExpression<T>>) -> FlatExpression<T> {
fn flat_expression_from_bits_aux<T: Field>(
v: Vec<(T, FlatExpression<T>)>,
) -> FlatExpression<T> {
match v.len() {
0 => FlatExpression::Number(T::zero()),
1 => {
let (coeff, var) = v[0].clone();
FlatExpression::Mult(box FlatExpression::Number(coeff), box var)
}
n => {
let (u, v) = v.split_at(n / 2);
FlatExpression::Add(
box flat_expression_from_bits_aux(u.to_vec()),
box flat_expression_from_bits_aux(v.to_vec()),
)
}
}
}
flat_expression_from_bits_aux(
v.into_iter()
.rev()
.enumerate()
.map(|(index, var)| (T::from(2).pow(index), var))
.collect::<Vec<_>>(),
)
}

View file

@ -165,11 +165,7 @@ pub fn generate_verify_constraints(
.into_iter()
.zip(matrices.b.into_iter())
.zip(matrices.c.into_iter())
.map(|((a, b), c)| Constraint {
a: a.into_iter().map(|(f, index)| (index, f)).collect(),
b: b.into_iter().map(|(f, index)| (index, f)).collect(),
c: c.into_iter().map(|(f, index)| (index, f)).collect(),
})
.map(|((a, b), c)| Constraint { a, b, c })
.collect();
(
@ -309,26 +305,26 @@ pub fn from_ark<T: zokrates_field::Field, E: PairingEngine>(c: Constraint<E::Fq>
Constraint {
a: c.a
.into_iter()
.map(|(index, fq)| {
.map(|(fq, index)| {
let mut res: Vec<u8> = vec![];
fq.into_repr().write_le(&mut res).unwrap();
(index, T::from_byte_vector(res))
(T::from_byte_vector(res), index)
})
.collect(),
b: c.b
.into_iter()
.map(|(index, fq)| {
.map(|(fq, index)| {
let mut res: Vec<u8> = vec![];
fq.into_repr().write_le(&mut res).unwrap();
(index, T::from_byte_vector(res))
(T::from_byte_vector(res), index)
})
.collect(),
c: c.c
.into_iter()
.map(|(index, fq)| {
.map(|(fq, index)| {
let mut res: Vec<u8> = vec![];
fq.into_repr().write_le(&mut res).unwrap();
(index, T::from_byte_vector(res))
(T::from_byte_vector(res), index)
})
.collect(),
}

View file

@ -180,17 +180,17 @@ impl<E: Engine> ConstraintSystem<E> for R1CS<E::Fr> {
let a = a
.as_ref()
.iter()
.map(|(variable, coefficient)| (var_to_index(*variable), *coefficient))
.map(|(variable, coefficient)| (*coefficient, var_to_index(*variable)))
.collect();
let b = b
.as_ref()
.iter()
.map(|(variable, coefficient)| (var_to_index(*variable), *coefficient))
.map(|(variable, coefficient)| (*coefficient, var_to_index(*variable)))
.collect();
let c = c
.as_ref()
.iter()
.map(|(variable, coefficient)| (var_to_index(*variable), *coefficient))
.map(|(variable, coefficient)| (*coefficient, var_to_index(*variable)))
.collect();
self.constraints.push(Constraint { a, b, c });
@ -257,26 +257,26 @@ pub fn from_bellman<T: zokrates_field::Field, E: Engine>(c: Constraint<E::Fr>) -
Constraint {
a: c.a
.into_iter()
.map(|(index, fq)| {
.map(|(fq, index)| {
let mut res: Vec<u8> = vec![];
fq.into_repr().write_le(&mut res).unwrap();
(index, T::from_byte_vector(res))
(T::from_byte_vector(res), index)
})
.collect(),
b: c.b
.into_iter()
.map(|(index, fq)| {
.map(|(fq, index)| {
let mut res: Vec<u8> = vec![];
fq.into_repr().write_le(&mut res).unwrap();
(index, T::from_byte_vector(res))
(T::from_byte_vector(res), index)
})
.collect(),
c: c.c
.into_iter()
.map(|(index, fq)| {
.map(|(fq, index)| {
let mut res: Vec<u8> = vec![];
fq.into_repr().write_le(&mut res).unwrap();
(index, T::from_byte_vector(res))
(T::from_byte_vector(res), index)
})
.collect(),
}

View file

@ -24,7 +24,7 @@ pub struct Witness<T> {
#[derive(Default, Debug, PartialEq, Clone)]
pub struct Constraint<T> {
pub a: Vec<(usize, T)>,
pub b: Vec<(usize, T)>,
pub c: Vec<(usize, T)>,
pub a: Vec<(T, usize)>,
pub b: Vec<(T, usize)>,
pub c: Vec<(T, usize)>,
}