1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

handle main specific behaviour in prog, introduce quadcomb for return statements

This commit is contained in:
schaeff 2018-12-05 00:08:36 +01:00
parent 7a46f15e74
commit dfea1fe578
5 changed files with 267 additions and 169 deletions

View file

@ -0,0 +1,173 @@
use field::Field;
use flat_absy::FlatVariable;
use num::Zero;
use std::collections::BTreeMap;
use std::fmt;
use std::ops::{Add, Sub};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct QuadComb<T: Field> {
pub left: LinComb<T>,
pub right: LinComb<T>,
}
impl<T: Field> QuadComb<T> {
pub fn from_linear_combinations(left: LinComb<T>, right: LinComb<T>) -> Self {
QuadComb { left, right }
}
}
impl<T: Field> From<FlatVariable> for QuadComb<T> {
fn from(v: FlatVariable) -> QuadComb<T> {
LinComb::from(v).into()
}
}
impl<T: Field> fmt::Display for QuadComb<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "({}) * ({})", self.left, self.right,)
}
}
impl<T: Field> From<LinComb<T>> for QuadComb<T> {
fn from(lc: LinComb<T>) -> QuadComb<T> {
QuadComb::from_linear_combinations(LinComb::one(), lc)
}
}
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)]
pub struct LinComb<T: Field>(pub BTreeMap<FlatVariable, T>);
impl<T: Field> LinComb<T> {
pub fn summand<U: Into<T>>(mult: U, var: FlatVariable) -> LinComb<T> {
let mut res = BTreeMap::new();
res.insert(var, mult.into());
LinComb(res)
}
pub fn one() -> LinComb<T> {
Self::summand(1, FlatVariable::one())
}
}
impl<T: Field> fmt::Display for LinComb<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}",
self.0
.iter()
.map(|(k, v)| format!("{} * {}", v, k))
.collect::<Vec<_>>()
.join(" + ")
)
}
}
impl<T: Field> From<FlatVariable> for LinComb<T> {
fn from(v: FlatVariable) -> LinComb<T> {
let mut r = BTreeMap::new();
r.insert(v, T::one());
LinComb(r)
}
}
impl<T: Field> Add<LinComb<T>> for LinComb<T> {
type Output = LinComb<T>;
fn add(self, other: LinComb<T>) -> LinComb<T> {
let mut res = self.0.clone();
for (k, v) in other.0 {
let new_val = v + res.get(&k).unwrap_or(&T::zero());
if new_val == T::zero() {
res.remove(&k)
} else {
res.insert(k, new_val)
};
}
LinComb(res)
}
}
impl<T: Field> Sub<LinComb<T>> for LinComb<T> {
type Output = LinComb<T>;
fn sub(self, other: LinComb<T>) -> LinComb<T> {
let mut res = self.0.clone();
for (k, v) in other.0 {
let new_val = T::zero() - v + res.get(&k).unwrap_or(&T::zero());
if new_val == T::zero() {
res.remove(&k)
} else {
res.insert(k, new_val)
};
}
LinComb(res)
}
}
impl<T: Field> Zero for LinComb<T> {
fn zero() -> LinComb<T> {
LinComb(BTreeMap::new())
}
fn is_zero(&self) -> bool {
self.0.len() == 0
}
}
#[cfg(test)]
mod tests {
use super::*;
use field::FieldPrime;
mod linear {
use super::*;
#[test]
fn add_zero() {
let a: LinComb<FieldPrime> = LinComb::zero();
let b: LinComb<FieldPrime> = FlatVariable::new(42).into();
let c = a + b.clone();
assert_eq!(c, b);
}
#[test]
fn add() {
let a: LinComb<FieldPrime> = FlatVariable::new(42).into();
let b: LinComb<FieldPrime> = FlatVariable::new(42).into();
let c = a + b.clone();
let mut expected_map = BTreeMap::new();
expected_map.insert(FlatVariable::new(42), FieldPrime::from(2));
assert_eq!(c, LinComb(expected_map));
}
#[test]
fn sub() {
let a: LinComb<FieldPrime> = FlatVariable::new(42).into();
let b: LinComb<FieldPrime> = FlatVariable::new(42).into();
let c = a - b.clone();
assert_eq!(c, LinComb::zero());
}
}
mod quadratic {
use super::*;
#[test]
fn from_linear() {
let a: LinComb<FieldPrime> = LinComb::summand(3, FlatVariable::new(42)) + LinComb::summand(4, FlatVariable::new(33));
let expected = QuadComb {
left: LinComb::one(),
right: a.clone()
};
assert_eq!(QuadComb::from(a), expected);
}
#[test]
fn zero() {
let a: LinComb<FieldPrime> = LinComb::zero();
let expected: QuadComb<FieldPrime> = QuadComb {
left: LinComb::one(),
right: LinComb::zero(),
};
assert_eq!(QuadComb::from(a), expected);
}
}
}

View file

@ -1,12 +1,11 @@
use field::Field;
use flat_absy::{FlatExpression, FlatFunction, FlatProg, FlatStatement, FlatVariable};
use helpers;
use ir::{DirectiveStatement, Function, LinComb, Prog, Statement};
use ir::{DirectiveStatement, Function, LinComb, Prog, QuadComb, Statement};
use num::Zero;
impl<T: Field> From<FlatFunction<T>> for Function<T> {
fn from(flat_function: FlatFunction<T>) -> Function<T> {
// we need to make sure that return values are identifiers, so we define new (public) variables
let return_expressions: Vec<FlatExpression<T>> = flat_function
.statements
.iter()
@ -16,24 +15,13 @@ impl<T: Field> From<FlatFunction<T>> for Function<T> {
})
.next()
.unwrap();
let return_count = return_expressions.len();
let definitions = return_expressions
.into_iter()
.enumerate()
.map(|(index, e)| FlatStatement::Definition(FlatVariable::public(index), e));
Function {
id: flat_function.id,
arguments: flat_function.arguments.into_iter().map(|p| p.id).collect(),
// return the public variables we just defined
return_wires: (0..return_count)
.map(|i| FlatVariable::public(i).into())
.collect(),
// statements are the function statements, followed by definitions of outputs
returns: return_expressions.into_iter().map(|e| e.into()).collect(),
statements: flat_function
.statements
.into_iter()
.chain(definitions)
.filter_map(|s| match s {
FlatStatement::Return(..) => None,
s => Some(s.into()),
@ -45,19 +33,57 @@ impl<T: Field> From<FlatFunction<T>> for Function<T> {
impl<T: Field> From<FlatProg<T>> for Prog<T> {
fn from(flat_prog: FlatProg<T>) -> Prog<T> {
println!("{}", flat_prog);
// get the main function as all calls have been resolved
let main = flat_prog
.functions
.into_iter()
.find(|f| f.id == "main")
.unwrap();
// get the interface of the program, ie which inputs are private and public
let private = main.arguments.iter().map(|p| p.private).collect();
// convert the main function to this IR for functions
let main: Function<T> = main.into();
// contrary to other functions, we need to make sure that return values are identifiers, so we define new (public) variables
let definitions =
main.returns.iter().enumerate().map(|(index, e)| {
Statement::Constraint(e.clone(), FlatVariable::public(index).into())
});
// update the main function with the extra definition statements and replace the return values
let main = Function {
returns: (0..main.returns.len())
.map(|i| FlatVariable::public(i).into())
.collect(),
statements: main.statements.into_iter().chain(definitions).collect(),
..main
};
let main = Function::from(main);
Prog { private, main }
}
}
impl<T: Field> From<FlatExpression<T>> for QuadComb<T> {
fn from(flat_expression: FlatExpression<T>) -> QuadComb<T> {
match flat_expression.is_linear() {
true => LinComb::from(flat_expression).into(),
false => match flat_expression {
FlatExpression::Mult(box e1, box e2) => {
QuadComb::from_linear_combinations(e1.into(), e2.into())
}
e => unimplemented!("{}", e),
},
}
}
}
impl<T: Field> From<FlatExpression<T>> for LinComb<T> {
fn from(flat_expression: FlatExpression<T>) -> LinComb<T> {
assert!(flat_expression.is_linear());
match flat_expression {
FlatExpression::Number(ref n) if *n == T::from(0) => LinComb::zero(),
FlatExpression::Number(n) => LinComb::summand(n, FlatVariable::one()),
@ -81,16 +107,18 @@ impl<T: Field> From<FlatStatement<T>> for Statement<T> {
fn from(flat_statement: FlatStatement<T>) -> Statement<T> {
match flat_statement {
FlatStatement::Condition(linear, quadratic) => match quadratic {
FlatExpression::Mult(box lhs, box rhs) => {
Statement::Constraint(lhs.into(), rhs.into(), linear.into())
}
e => Statement::Constraint(LinComb::one(), e.into(), linear.into()),
FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint(
QuadComb::from_linear_combinations(lhs.into(), rhs.into()),
linear.into(),
),
e => Statement::Constraint(LinComb::from(e).into(), linear.into()),
},
FlatStatement::Definition(var, quadratic) => match quadratic {
FlatExpression::Mult(box lhs, box rhs) => {
Statement::Constraint(lhs.into(), rhs.into(), var.into())
}
e => Statement::Constraint(LinComb::one(), e.into(), var.into()),
FlatExpression::Mult(box lhs, box rhs) => Statement::Constraint(
QuadComb::from_linear_combinations(lhs.into(), rhs.into()),
var.into(),
),
e => Statement::Constraint(LinComb::from(e).into(), var.into()),
},
FlatStatement::Directive(ds) => Statement::Directive(ds.into()),
_ => panic!("return should be handled at the function level"),

View file

@ -15,16 +15,16 @@ impl<T: Field> Prog<T> {
for statement in main.statements {
match statement {
Statement::Constraint(a, b, c) => match c.is_assignee(&witness) {
Statement::Constraint(quad, lin) => match lin.is_assignee(&witness) {
true => {
let val = a.evaluate(&witness) * b.evaluate(&witness);
witness.insert(c.0.iter().next().unwrap().0.clone(), val);
let val = quad.evaluate(&witness);
witness.insert(lin.0.iter().next().unwrap().0.clone(), val);
}
false => {
let lhs_value = a.evaluate(&witness) * b.evaluate(&witness);
let rhs_value = c.evaluate(&witness);
let lhs_value = quad.evaluate(&witness);
let rhs_value = lin.evaluate(&witness);
if lhs_value != rhs_value {
return Err(Error::Constraint(a, b, c, lhs_value, rhs_value));
return Err(Error::Constraint(quad, lin, lhs_value, rhs_value));
}
}
},
@ -63,19 +63,25 @@ impl<T: Field> LinComb<T> {
}
}
impl<T: Field> QuadComb<T> {
fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> T {
self.left.evaluate(&witness) * self.right.evaluate(&witness)
}
}
#[derive(PartialEq, Debug)]
pub enum Error<T: Field> {
Constraint(LinComb<T>, LinComb<T>, LinComb<T>, T, T),
Constraint(QuadComb<T>, LinComb<T>, T, T),
Solver,
}
impl<T: Field> fmt::Display for Error<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Error::Constraint(ref a, ref b, ref c, ref left_value, ref right_value) => write!(
Error::Constraint(ref quad, ref lin, ref left_value, ref right_value) => write!(
f,
"Expected ({}) * ({}) to equal ({}), but {} != {}",
a, b, c, left_value, right_value
"Expected {} to equal {}, but {} != {}",
quad, lin, left_value, right_value
),
Error::Solver => write!(f, ""),
}

View file

@ -1,116 +0,0 @@
use field::Field;
use flat_absy::FlatVariable;
use num::Zero;
use std::collections::BTreeMap;
use std::fmt;
use std::ops::{Add, Sub};
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)]
pub struct LinComb<T: Field>(pub BTreeMap<FlatVariable, T>);
impl<T: Field> LinComb<T> {
pub fn summand<U: Into<T>>(mult: U, var: FlatVariable) -> LinComb<T> {
let mut res = BTreeMap::new();
res.insert(var, mult.into());
LinComb(res)
}
pub fn one() -> LinComb<T> {
Self::summand(1, FlatVariable::one())
}
}
impl<T: Field> fmt::Display for LinComb<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}",
self.0
.iter()
.map(|(k, v)| format!("{} * {}", v, k))
.collect::<Vec<_>>()
.join(" + ")
)
}
}
impl<T: Field> From<FlatVariable> for LinComb<T> {
fn from(v: FlatVariable) -> LinComb<T> {
let mut r = BTreeMap::new();
r.insert(v, T::one());
LinComb(r)
}
}
impl<T: Field> Add<LinComb<T>> for LinComb<T> {
type Output = LinComb<T>;
fn add(self, other: LinComb<T>) -> LinComb<T> {
let mut res = self.0.clone();
for (k, v) in other.0 {
let new_val = v + res.get(&k).unwrap_or(&T::zero());
if new_val == T::zero() {
res.remove(&k)
} else {
res.insert(k, new_val)
};
}
LinComb(res)
}
}
impl<T: Field> Sub<LinComb<T>> for LinComb<T> {
type Output = LinComb<T>;
fn sub(self, other: LinComb<T>) -> LinComb<T> {
let mut res = self.0.clone();
for (k, v) in other.0 {
let new_val = T::zero() - v + res.get(&k).unwrap_or(&T::zero());
if new_val == T::zero() {
res.remove(&k)
} else {
res.insert(k, new_val)
};
}
LinComb(res)
}
}
impl<T: Field> Zero for LinComb<T> {
fn zero() -> LinComb<T> {
LinComb(BTreeMap::new())
}
fn is_zero(&self) -> bool {
self.0.len() == 0
}
}
#[cfg(test)]
mod tests {
use super::*;
use field::FieldPrime;
#[test]
fn add_zero() {
let a: LinComb<FieldPrime> = LinComb::zero();
let b: LinComb<FieldPrime> = FlatVariable::new(42).into();
let c = a + b.clone();
assert_eq!(c, b);
}
#[test]
fn add() {
let a: LinComb<FieldPrime> = FlatVariable::new(42).into();
let b: LinComb<FieldPrime> = FlatVariable::new(42).into();
let c = a + b.clone();
let mut expected_map = BTreeMap::new();
expected_map.insert(FlatVariable::new(42), FieldPrime::from(2));
assert_eq!(c, LinComb(expected_map));
}
#[test]
fn sub() {
let a: LinComb<FieldPrime> = FlatVariable::new(42).into();
let b: LinComb<FieldPrime> = FlatVariable::new(42).into();
let c = a - b.clone();
assert_eq!(c, LinComb::zero());
}
}

View file

@ -6,15 +6,16 @@ use std::collections::HashMap;
use std::fmt;
use std::mem;
mod expression;
mod from_flat;
mod interpreter;
mod linear_combination;
use self::linear_combination::LinComb;
use self::expression::LinComb;
use self::expression::QuadComb;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum Statement<T: Field> {
Constraint(LinComb<T>, LinComb<T>, LinComb<T>),
Constraint(QuadComb<T>, LinComb<T>),
Directive(DirectiveStatement<T>),
}
@ -48,7 +49,7 @@ impl<T: Field> fmt::Display for DirectiveStatement<T> {
impl<T: Field> fmt::Display for Statement<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Statement::Constraint(ref a, ref b, ref c) => write!(f, "({}) * ({}) == ({})", a, b, c),
Statement::Constraint(ref quad, ref lin) => write!(f, "{} == {}", quad, lin),
Statement::Directive(ref s) => write!(f, "{}", s),
}
}
@ -59,7 +60,7 @@ pub struct Function<T: Field> {
pub id: String,
pub statements: Vec<Statement<T>>,
pub arguments: Vec<FlatVariable>,
pub return_wires: Vec<LinComb<T>>,
pub returns: Vec<QuadComb<T>>,
}
impl<T: Field> fmt::Display for Function<T> {
@ -73,13 +74,13 @@ impl<T: Field> fmt::Display for Function<T> {
.map(|v| format!("{}", v))
.collect::<Vec<_>>()
.join(", "),
self.return_wires.len(),
self.returns.len(),
self.statements
.iter()
.map(|s| format!("\t{}", s))
.collect::<Vec<_>>()
.join("\n"),
self.return_wires
self.returns
.iter()
.map(|e| format!("{}", e))
.collect::<Vec<_>>()
@ -181,7 +182,7 @@ pub fn r1cs_program<T: Field>(
//~out are added after main's arguments as we want variables (columns)
//in the r1cs to be aligned like "public inputs | private inputs"
let main_return_count = main.return_wires.len();
let main_return_count = main.returns.len();
for i in 0..main_return_count {
provide_variable_idx(&mut variables, &FlatVariable::public(i));
@ -191,19 +192,19 @@ pub fn r1cs_program<T: Field>(
let private_inputs_offset = variables.len();
// first pass through statements to populate `variables`
for (aa, bb, cc) in main.statements.iter().filter_map(|s| match s {
Statement::Constraint(aa, bb, cc) => Some((aa, bb, cc)),
for (aa, bb) in main.statements.iter().filter_map(|s| match s {
Statement::Constraint(aa, bb) => Some((aa, bb)),
Statement::Directive(..) => None,
}) {
for (k, _) in &aa.0 {
for (k, _) in &aa.left.0 {
provide_variable_idx(&mut variables, &k);
}
for (k, _) in &aa.right.0 {
provide_variable_idx(&mut variables, &k);
}
for (k, _) in &bb.0 {
provide_variable_idx(&mut variables, &k);
}
for (k, _) in &cc.0 {
provide_variable_idx(&mut variables, &k);
}
}
let mut a = vec![];
@ -211,22 +212,26 @@ pub fn r1cs_program<T: Field>(
let mut c = vec![];
// second pass to convert program to raw sparse vectors
for (aa, bb, cc) in main.statements.into_iter().filter_map(|s| match s {
Statement::Constraint(aa, bb, cc) => Some((aa, bb, cc)),
for (aa, bb) in main.statements.into_iter().filter_map(|s| match s {
Statement::Constraint(aa, bb) => Some((aa, bb)),
Statement::Directive(..) => None,
}) {
a.push(
aa.0.into_iter()
aa.left
.0
.into_iter()
.map(|(k, v)| (variables.get(&k).unwrap().clone(), v))
.collect(),
);
b.push(
bb.0.into_iter()
aa.right
.0
.into_iter()
.map(|(k, v)| (variables.get(&k).unwrap().clone(), v))
.collect(),
);
c.push(
cc.0.into_iter()
bb.0.into_iter()
.map(|(k, v)| (variables.get(&k).unwrap().clone(), v))
.collect(),
);
@ -252,8 +257,10 @@ mod tests {
#[test]
fn print_constraint() {
let c: Statement<FieldPrime> = Statement::Constraint(
FlatVariable::new(42).into(),
FlatVariable::new(42).into(),
QuadComb::from_linear_combinations(
FlatVariable::new(42).into(),
FlatVariable::new(42).into(),
),
FlatVariable::new(42).into(),
);
assert_eq!(format!("{}", c), "(1 * _42) * (1 * _42) == (1 * _42)")