1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

Implement division for field; Add cases to r1cs; Bugfix in flattening for Mult; Bugfix for absy is_flattened

This commit is contained in:
Dennis Kuhnert 2017-02-20 16:35:50 +01:00
parent 4cb6f351e1
commit b01020f9cb
5 changed files with 142 additions and 85 deletions

View file

@ -7,5 +7,6 @@ def qeval(x):
y = x**3
b = x**5
c = x / 2
return y + x + y + c
d = (2 * x + 3 * b) * (x - b)
return y + x + y + c + d
// comment

View file

@ -180,7 +180,11 @@ impl<T: Field> Expression<T> {
Expression::Add(ref x, ref y) |
Expression::Sub(ref x, ref y) => x.is_linear() && y.is_linear(),
Expression::Mult(ref x, ref y) |
Expression::Div(ref x, ref y) => x.is_linear() && y.is_linear(),
Expression::Div(ref x, ref y) => match (x.clone(), y.clone()) {
(box Expression::Sub(..), _) |
(_, box Expression::Sub(..)) => false,
(box x, box y) => x.is_linear() && y.is_linear()
},
_ => false,
}
}

View file

@ -29,6 +29,8 @@ pub trait Field : From<i32> + From<u32> + From<usize> + for<'a> From<&'a str>
+ Div<Self, Output=Self> + for<'a> Div<&'a Self, Output=Self>
+ Pow<usize, Output=Self> + Pow<Self, Output=Self> + for<'a> Pow<&'a Self, Output=Self>
{
/// Returns a byte slice of this `Field`'s contents in decimal `String` representation.
fn into_dec_bytes(&self) -> Vec<u8>;
/// Returns the smallest value that can be represented by this field type.
fn min_value() -> Self;
/// Returns the largest value that can be represented by this field type.
@ -47,6 +49,9 @@ impl Field for FieldPrime {
fn max_value() -> FieldPrime {
FieldPrime{ value: &*P - ToBigInt::to_bigint(&1).unwrap() }
}
fn into_dec_bytes(&self) -> Vec<u8> {
self.value.to_str_radix(10).to_string().into_bytes()
}
}
impl Display for FieldPrime {
@ -154,36 +159,7 @@ impl<'a> Mul<&'a FieldPrime> for FieldPrime {
}
}
// extended_euclid(a,b)
// 1 wenn b = 0
// 2 dann return (a,1,0)
// 3 (d',s',t') = extended_euclid(b, a mod b)
// 4 (d,s,t) = (d',t',s' - (a div b)t')
// 5 return (d,s,t)
/// Calculates the gcd using a recursive implementation of the extended euclidian algorithm.
/// Returns (gcd(a,b), s, t) where gcd(a,b) = s * a + t * b
// fn extended_euclid(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
// if b.is_zero() {
// return (a.clone(), BigInt::one(), BigInt::zero());
// }
// let (d2, s2, t2) = extended_euclid(b, &(a % b));
// (d2, t2.clone(), s2 - (a / b) * t2)
// }
// function extended_gcd(a, b)
// s := 0; old_s := 1
// t := 1; old_t := 0
// r := b; old_r := a
// while r ≠ 0
// quotient := old_r div r
// (old_r, r) := (r, old_r - quotient * r)
// (old_s, s) := (s, old_s - quotient * s)
// (old_t, t) := (t, old_t - quotient * t)
// output "Bézout coefficients:", (old_s, old_t)
// output "greatest common divisor:", old_r
// output "quotients by the gcd:", (t, s)
/// Calculates the gcd using a iterative implementation of the extended euclidian algorithm.
fn extended_euclid(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
let (mut s, mut old_s) = (BigInt::zero(), BigInt::one());
let (mut t, mut old_t) = (BigInt::one(), BigInt::zero());
@ -203,36 +179,13 @@ fn extended_euclid(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
return (old_r, old_s, old_t)
}
// function inverse(a, n)
// t := 0; newt := 1;
// r := n; newr := a;
// while newr ≠ 0
// quotient := r div newr
// (t, newt) := (newt, t - quotient * newt)
// (r, newr) := (newr, r - quotient * newr)
// if r > 1 then return "a is not invertible"
// if t < 0 then t := t + n
// return t
impl Div<FieldPrime> for FieldPrime {
type Output = FieldPrime;
fn div(self, other: FieldPrime) -> FieldPrime {
// let (mut t, mut newt, mut r, mut newr) = (FieldPrime::zero(), FieldPrime::one(), other, self);
// while !&newr.is_zero() {
// let quotient = r.value.clone() / newr.value.clone();
// t = newt.clone();
// newt = t.clone() - FieldPrime{ value: quotient.clone() } * &newt;
// r = newr.clone();
// newr = r.clone() - FieldPrime{ value: quotient } * &newr;
// }
// if r > FieldPrime::one() {
// panic!("a is not invertible");
// }
// t
// (a * b^(p-2)) % p
// self * other.pow(FieldPrime::max_value() - FieldPrime::one())
FieldPrime{ value: self.value / other.value }
let (b, s, _) = extended_euclid(&other.value, &*P);
assert_eq!(b, BigInt::one());
FieldPrime{ value: &s - s.div_floor(&*P) * &*P } * self
}
}
@ -435,25 +388,25 @@ mod tests {
#[test]
fn division() {
assert_eq!(
"4".parse::<BigInt>().unwrap(),
(FieldPrime::from("54") / FieldPrime::from("12")).value
FieldPrime::from(4),
FieldPrime::from(48) / FieldPrime::from(12)
);
assert_eq!(
"4".parse::<BigInt>().unwrap(),
(FieldPrime::from("54") / &FieldPrime::from("12")).value
FieldPrime::from(4),
FieldPrime::from(48) / &FieldPrime::from(12)
);
}
#[test]
fn division_negative() {
let res = FieldPrime::from("-54") / FieldPrime::from("12");
assert_eq!(FieldPrime::from("-54"), FieldPrime::from("12") * res);
let res = FieldPrime::from(-54) / FieldPrime::from(12);
assert_eq!(FieldPrime::from(-54), FieldPrime::from(12) * res);
}
#[test]
fn division_two_negative() {
let res = FieldPrime::from("-12") / FieldPrime::from("-85");
assert_eq!(FieldPrime::from("-12"), FieldPrime::from("-85") * res);
let res = FieldPrime::from(-12) / FieldPrime::from(-85);
assert_eq!(FieldPrime::from(-12), FieldPrime::from(-85) * res);
}
#[test]
@ -523,5 +476,12 @@ mod tests {
(ToBigInt::to_bigint(&2).unwrap(), ToBigInt::to_bigint(&-9).unwrap(), ToBigInt::to_bigint(&47).unwrap()),
extended_euclid(&ToBigInt::to_bigint(&240).unwrap(), &ToBigInt::to_bigint(&46).unwrap())
);
let (b, s, _) = extended_euclid(&ToBigInt::to_bigint(&253).unwrap(), &*P);
assert_eq!(b, BigInt::one());
let s_field = FieldPrime{ value: &s - s.div_floor(&*P) * &*P };
assert_eq!(
FieldPrime::from("20071432198682655539767455070749754231531795211435158457485599966631195574669"),
s_field
);
}
}

View file

@ -135,7 +135,14 @@ impl Flattener {
let left_flattened = self.flatten_expression(statements_flattened, left);
let right_flattened = self.flatten_expression(statements_flattened, right);
let new_left = if left_flattened.is_linear() {
left_flattened
if let Sub(..) = left_flattened {
let new_name = format!("sym_{}", self.next_var_idx);
self.next_var_idx += 1;
statements_flattened.push(Statement::Definition(new_name.to_string(), left_flattened));
VariableReference(new_name)
} else {
left_flattened
}
} else {
let new_name = format!("sym_{}", self.next_var_idx);
self.next_var_idx += 1;
@ -143,7 +150,14 @@ impl Flattener {
VariableReference(new_name)
};
let new_right = if right_flattened.is_linear() {
right_flattened
if let Sub(..) = right_flattened {
let new_name = format!("sym_{}", self.next_var_idx);
self.next_var_idx += 1;
statements_flattened.push(Statement::Definition(new_name.to_string(), right_flattened));
VariableReference(new_name)
} else {
right_flattened
}
} else {
let new_name = format!("sym_{}", self.next_var_idx);
self.next_var_idx += 1;

View file

@ -8,9 +8,9 @@ use absy::Expression::*;
use std::collections::HashMap;
use field::Field;
fn count_variables_add<T: Field>(expr: Expression<T>) -> HashMap<String, T> {
fn count_variables_add<T: Field>(expr: &Expression<T>) -> HashMap<String, T> {
let mut count = HashMap::new();
match expr {
match expr.clone() {
NumberLiteral(x) => { count.insert("~one".to_string(), x); },
VariableReference(var) => { count.insert(var, T::one()); },
Add(box lhs, box rhs) => {
@ -42,7 +42,7 @@ fn count_variables_add<T: Field>(expr: Expression<T>) -> HashMap<String, T> {
let num = count.entry("~one".to_string()).or_insert(T::zero());
*num = num.clone() + x;
}
let vars = count_variables_add(e);
let vars = count_variables_add(&e);
for (key, value) in &vars {
let val = count.entry(key.to_string()).or_insert(T::zero());
*val = val.clone() + value;
@ -54,7 +54,7 @@ fn count_variables_add<T: Field>(expr: Expression<T>) -> HashMap<String, T> {
let var = count.entry(v).or_insert(T::zero());
*var = var.clone() + T::one();
}
let vars = count_variables_add(e);
let vars = count_variables_add(&e);
for (key, value) in &vars {
let val = count.entry(key.to_string()).or_insert(T::zero());
*val = val.clone() + value;
@ -90,12 +90,15 @@ fn count_variables_add<T: Field>(expr: Expression<T>) -> HashMap<String, T> {
let var = count.entry(v).or_insert(T::zero());
*var = var.clone() + n;
}
let vars = count_variables_add(e);
let vars = count_variables_add(&e);
for (key, value) in &vars {
let val = count.entry(key.to_string()).or_insert(T::zero());
*val = val.clone() + value;
}
},
(Mult(box NumberLiteral(n1), box VariableReference(v1)), Mult(box NumberLiteral(n2), box VariableReference(v2))) |
(Mult(box VariableReference(v1), box NumberLiteral(n1)), Mult(box NumberLiteral(n2), box VariableReference(v2))) |
(Mult(box NumberLiteral(n1), box VariableReference(v1)), Mult(box VariableReference(v2), box NumberLiteral(n2))) |
(Mult(box VariableReference(v1), box NumberLiteral(n1)), Mult(box VariableReference(v2), box NumberLiteral(n2))) => {
{
let var = count.entry(v1).or_insert(T::zero());
@ -135,7 +138,10 @@ fn swap_sub<T: Field>(lhs: &Expression<T>, rhs: &Expression<T>) -> (Expression<T
(Sub(box left, box right), v @ VariableReference(_)) |
(v @ VariableReference(_), Sub(box left, box right)) |
(Sub(box left, box right), v @ NumberLiteral(_)) |
(v @ NumberLiteral(_), Sub(box left, box right)) => {
(v @ NumberLiteral(_), Sub(box left, box right)) |
(Sub(box left, box right), v @ Mult(..)) |
(v @ Mult(..), Sub(box left, box right)) => {
assert!(v.is_linear());
let (l, r) = swap_sub(&left, &right);
(Add(box v, box r), l)
},
@ -149,11 +155,11 @@ pub fn r1cs_expression<T: Field>(linear_expr: Expression<T>, expr: Expression<T>
e @ Add(..) |
e @ Sub(..) => { // a - b = c --> b + c = a
let (lhs, rhs) = swap_sub(&linear_expr, &e);
for (key, value) in count_variables_add(rhs) {
for (key, value) in count_variables_add(&rhs) {
a_row.push((variables.iter().position(|r| r == &key).unwrap(), value));
}
b_row.push((0, T::one()));
for (key, value) in count_variables_add(lhs) {
for (key, value) in count_variables_add(&lhs) {
c_row.push((variables.iter().position(|r| r == &key).unwrap(), value));
}
},
@ -161,7 +167,7 @@ pub fn r1cs_expression<T: Field>(linear_expr: Expression<T>, expr: Expression<T>
match lhs {
box NumberLiteral(x) => a_row.push((0, x)),
box VariableReference(x) => a_row.push((variables.iter().position(|r| r == &x).unwrap(), T::one())),
box e @ Add(..) => for (key, value) in count_variables_add(e) {
box e @ Add(..) => for (key, value) in count_variables_add(&e) {
a_row.push((variables.iter().position(|r| r == &key).unwrap(), value));
},
e @ _ => panic!("Not flattened: {}", e),
@ -169,43 +175,47 @@ pub fn r1cs_expression<T: Field>(linear_expr: Expression<T>, expr: Expression<T>
match rhs {
box NumberLiteral(x) => b_row.push((0, x)),
box VariableReference(x) => b_row.push((variables.iter().position(|r| r == &x).unwrap(), T::one())),
box e @ Add(..) => for (key, value) in count_variables_add(e) {
box e @ Add(..) => for (key, value) in count_variables_add(&e) {
b_row.push((variables.iter().position(|r| r == &key).unwrap(), value));
},
e @ _ => panic!("Not flattened: {}", e),
};
for (key, value) in count_variables_add(linear_expr) {
for (key, value) in count_variables_add(&linear_expr) {
c_row.push((variables.iter().position(|r| r == &key).unwrap(), value));
}
},
Div(lhs, rhs) => { // a / b = c --> c * b = a
for (key, value) in count_variables_add(linear_expr) {
a_row.push((variables.iter().position(|r| r == &key).unwrap(), value));
}
match lhs {
box NumberLiteral(x) => c_row.push((0, x)),
box VariableReference(x) => c_row.push((variables.iter().position(|r| r == &x).unwrap(), T::one())),
_ => unimplemented!(),
box e @ Add(..) => for (key, value) in count_variables_add(&e) {
c_row.push((variables.iter().position(|r| r == &key).unwrap(), value));
},
box e @ Sub(..) => return r1cs_expression(Mult(box linear_expr, rhs), e, variables, a_row, b_row, c_row),
e @ _ => panic!("not implemented yet: {:?}", e),
};
match rhs {
box NumberLiteral(x) => b_row.push((0, x)),
box VariableReference(x) => b_row.push((variables.iter().position(|r| r == &x).unwrap(), T::one())),
_ => unimplemented!(),
};
for (key, value) in count_variables_add(&linear_expr) {
a_row.push((variables.iter().position(|r| r == &key).unwrap(), value));
}
},
Pow(_, _) => panic!("Pow not flattened"),
IfElse(_, _, _) => panic!("IfElse not flattened"),
VariableReference(var) => {
a_row.push((variables.iter().position(|r| r == &var).unwrap(), T::one()));
b_row.push((0, T::one()));
for (key, value) in count_variables_add(linear_expr) {
for (key, value) in count_variables_add(&linear_expr) {
c_row.push((variables.iter().position(|r| r == &key).unwrap(), value));
}
},
NumberLiteral(x) => {
a_row.push((0, x));
b_row.push((0, T::one()));
for (key, value) in count_variables_add(linear_expr) {
for (key, value) in count_variables_add(&linear_expr) {
c_row.push((variables.iter().position(|r| r == &key).unwrap(), value));
}
},
@ -272,5 +282,73 @@ mod tests {
row_eq(vec![(0, FieldPrime::from(1))], b_row);
row_eq(vec![(1, FieldPrime::from(1))], c_row);
}
#[test]
fn add_mult() {
// 4 * b + 3 * a + 3 * c == (3 * a + 6 * b + 4 * c) * (31 * a + 4 * c)
let lhs = Add(
box Add(
box Mult(box NumberLiteral(FieldPrime::from(4)), box VariableReference(String::from("b"))),
box Mult(box NumberLiteral(FieldPrime::from(3)), box VariableReference(String::from("a")))
),
box Mult(box NumberLiteral(FieldPrime::from(3)), box VariableReference(String::from("c")))
);
let rhs = Mult(
box Add(
box Add(
box Mult(box NumberLiteral(FieldPrime::from(3)), box VariableReference(String::from("a"))),
box Mult(box NumberLiteral(FieldPrime::from(6)), box VariableReference(String::from("b")))
),
box Mult(box NumberLiteral(FieldPrime::from(4)), box VariableReference(String::from("c")))
),
box Add(
box Mult(box NumberLiteral(FieldPrime::from(31)), box VariableReference(String::from("a"))),
box Mult(box NumberLiteral(FieldPrime::from(4)), box VariableReference(String::from("c")))
)
);
let mut variables: Vec<String> = vec!["~one", "a", "b", "c"].iter().map(|&x| String::from(x)).collect();
let mut a_row: Vec<(usize, FieldPrime)> = Vec::new();
let mut b_row: Vec<(usize, FieldPrime)> = Vec::new();
let mut c_row: Vec<(usize, FieldPrime)> = Vec::new();
r1cs_expression(lhs, rhs, &mut variables, &mut a_row, &mut b_row, &mut c_row);
row_eq(vec![(1, FieldPrime::from(3)), (2, FieldPrime::from(6)), (3, FieldPrime::from(4))], a_row);
row_eq(vec![(1, FieldPrime::from(31)), (3, FieldPrime::from(4))], b_row);
row_eq(vec![(1, FieldPrime::from(3)), (2, FieldPrime::from(4)), (3, FieldPrime::from(3))], c_row);
}
#[test]
fn add_mult() {
// 4 * b + 3 * a - 3 * c == (3 * a + 6 * b + 4 * c) * (31 * a + 4 * c)
let lhs = Add(
box Add(
box Mult(box NumberLiteral(FieldPrime::from(4)), box VariableReference(String::from("b"))),
box Mult(box NumberLiteral(FieldPrime::from(3)), box VariableReference(String::from("a")))
),
box Mult(box NumberLiteral(FieldPrime::from(3)), box VariableReference(String::from("c")))
);
let rhs = Mult(
box Add(
box Add(
box Mult(box NumberLiteral(FieldPrime::from(3)), box VariableReference(String::from("a"))),
box Mult(box NumberLiteral(FieldPrime::from(6)), box VariableReference(String::from("b")))
),
box Mult(box NumberLiteral(FieldPrime::from(4)), box VariableReference(String::from("c")))
),
box Add(
box Mult(box NumberLiteral(FieldPrime::from(31)), box VariableReference(String::from("a"))),
box Mult(box NumberLiteral(FieldPrime::from(4)), box VariableReference(String::from("c")))
)
);
let mut variables: Vec<String> = vec!["~one", "a", "b", "c"].iter().map(|&x| String::from(x)).collect();
let mut a_row: Vec<(usize, FieldPrime)> = Vec::new();
let mut b_row: Vec<(usize, FieldPrime)> = Vec::new();
let mut c_row: Vec<(usize, FieldPrime)> = Vec::new();
r1cs_expression(lhs, rhs, &mut variables, &mut a_row, &mut b_row, &mut c_row);
row_eq(vec![(1, FieldPrime::from(3)), (2, FieldPrime::from(6)), (3, FieldPrime::from(4))], a_row);
row_eq(vec![(1, FieldPrime::from(31)), (3, FieldPrime::from(4))], b_row);
row_eq(vec![(1, FieldPrime::from(3)), (2, FieldPrime::from(4)), (3, FieldPrime::from(3))], c_row);
}
}
}