1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

Implement Add/Sub and CanonicalLinComb

This commit is contained in:
Eugene P 2019-04-16 12:12:51 +03:00
parent dacf63c1e9
commit e3f5cf047d

View file

@ -3,7 +3,7 @@ use num::Zero;
use std::fmt;
use std::ops::{Add, Div, Mul, Sub};
use zokrates_field::field::Field;
use std::collections::btree_set::BTreeSet;
use std::collections::btree_map::{BTreeMap, Entry};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct QuadComb<T: Field> {
@ -56,6 +56,9 @@ impl<T: Field> fmt::Display for QuadComb<T> {
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)]
pub struct LinComb<T: Field>(pub Vec<(FlatVariable, T)>);
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)]
pub struct CanonicalLinComb<T: Field>(BTreeMap<FlatVariable, T>);
impl<T: Field> LinComb<T> {
pub fn summand<U: Into<T>>(mult: U, var: FlatVariable) -> LinComb<T> {
let res = vec![(var, mult.into())];
@ -74,6 +77,34 @@ impl<T: Field> LinComb<T> {
None
}
fn as_canonical(&self) -> CanonicalLinComb<T> {
CanonicalLinComb(self.0.clone().into_iter().fold(
BTreeMap::new(),
|mut acc, (val, coeff)| {
// if we're adding 0 times some variable, we can ignore this term
if coeff != T::zero() {
match acc.entry(val) {
Entry::Occupied(o) => {
// if the new value is non zero, update, else remove the term entirely
if o.get().clone() + coeff.clone() != T::zero() {
*o.into_mut() = o.get().clone() + coeff;
} else {
o.remove();
}
}
Entry::Vacant(v) => {
// We checked earlier but let's make sure we're not creating zero-coeff terms
assert!(coeff != T::zero());
v.insert(coeff);
}
}
}
acc
},
))
}
}
impl<T: Field> fmt::Display for LinComb<T> {
@ -83,7 +114,9 @@ impl<T: Field> fmt::Display for LinComb<T> {
false => write!(
f,
"{}",
self.0
self
.as_canonical()
.0
.iter()
.map(|(k, v)| format!("{} * {}", v.to_compact_dec_string(), k))
.collect::<Vec<_>>()
@ -104,10 +137,7 @@ impl<T: Field> Add<LinComb<T>> for LinComb<T> {
type Output = LinComb<T>;
fn add(self, other: LinComb<T>) -> LinComb<T> {
let mut res = self.0;
res.extend(other.0);
LinComb(res)
LinComb(self.0.into_iter().chain(other.0.into_iter()).collect())
}
}
@ -115,29 +145,13 @@ impl<T: Field> Sub<LinComb<T>> for LinComb<T> {
type Output = LinComb<T>;
fn sub(self, other: LinComb<T>) -> LinComb<T> {
let mut res = self.0;
let mut to_remove = BTreeSet::new();
for (k, v) in other.0 {
let var;
{
var = res
.iter()
.find(|(id, _)| *id == k)
.map(|(_, v)| v)
.cloned();
}
let new_val = T::zero() - v + var.unwrap_or(T::zero());
if new_val != T::zero() {
res.push((k, new_val));
} else {
to_remove.insert(k);
}
}
res.retain(|(id, _)| !to_remove.contains(id));
LinComb(res)
// Concatenate with second vector that have negative coeffs
LinComb(
self.0
.into_iter()
.chain(other.0.into_iter().map(|(var, coeff)| (var, T::zero() - coeff)))
.collect(),
)
}
}
@ -195,8 +209,8 @@ mod tests {
let expected_vec = vec![
(FlatVariable::new(42), FieldPrime::from(1)),
(FlatVariable::new(42), FieldPrime::from(1)),
];
assert_eq!(c, LinComb(expected_vec));
}
#[test]
@ -204,14 +218,20 @@ mod tests {
let a: LinComb<FieldPrime> = FlatVariable::new(42).into();
let b: LinComb<FieldPrime> = FlatVariable::new(42).into();
let c = a - b.clone();
assert_eq!(LinComb::zero(), c);
let expected_vec = vec![
(FlatVariable::new(42), FieldPrime::from(1)),
(FlatVariable::new(42), FieldPrime::from(-1)),
];
assert_eq!(c, LinComb(expected_vec));
}
#[test]
fn display() {
let a: LinComb<FieldPrime> =
LinComb::from(FlatVariable::new(42)) + LinComb::summand(3, FlatVariable::new(21));
assert_eq!(&a.to_string(), "1 * _42 + 3 * _21");
assert_eq!(&a.to_string(), "3 * _21 + 1 * _42");
let zero: LinComb<FieldPrime> = LinComb::zero();
assert_eq!(&zero.to_string(), "0");
}
@ -247,7 +267,7 @@ mod tests {
+ LinComb::summand(4, FlatVariable::new(33)),
right: LinComb::summand(1, FlatVariable::new(21)),
};
assert_eq!(&a.to_string(), "(3 * _42 + 4 * _33) * (1 * _21)");
assert_eq!(&a.to_string(), "(4 * _33 + 3 * _42) * (1 * _21)");
let a: QuadComb<FieldPrime> = QuadComb {
left: LinComb::zero(),
right: LinComb::summand(1, FlatVariable::new(21)),