Implement Add/Sub and CanonicalLinComb
This commit is contained in:
parent
dacf63c1e9
commit
e3f5cf047d
1 changed files with 53 additions and 33 deletions
|
@ -3,7 +3,7 @@ use num::Zero;
|
|||
use std::fmt;
|
||||
use std::ops::{Add, Div, Mul, Sub};
|
||||
use zokrates_field::field::Field;
|
||||
use std::collections::btree_set::BTreeSet;
|
||||
use std::collections::btree_map::{BTreeMap, Entry};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct QuadComb<T: Field> {
|
||||
|
@ -56,6 +56,9 @@ impl<T: Field> fmt::Display for QuadComb<T> {
|
|||
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)]
|
||||
pub struct LinComb<T: Field>(pub Vec<(FlatVariable, T)>);
|
||||
|
||||
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)]
|
||||
pub struct CanonicalLinComb<T: Field>(BTreeMap<FlatVariable, T>);
|
||||
|
||||
impl<T: Field> LinComb<T> {
|
||||
pub fn summand<U: Into<T>>(mult: U, var: FlatVariable) -> LinComb<T> {
|
||||
let res = vec![(var, mult.into())];
|
||||
|
@ -74,6 +77,34 @@ impl<T: Field> LinComb<T> {
|
|||
|
||||
None
|
||||
}
|
||||
|
||||
fn as_canonical(&self) -> CanonicalLinComb<T> {
|
||||
CanonicalLinComb(self.0.clone().into_iter().fold(
|
||||
BTreeMap::new(),
|
||||
|mut acc, (val, coeff)| {
|
||||
// if we're adding 0 times some variable, we can ignore this term
|
||||
if coeff != T::zero() {
|
||||
match acc.entry(val) {
|
||||
Entry::Occupied(o) => {
|
||||
// if the new value is non zero, update, else remove the term entirely
|
||||
if o.get().clone() + coeff.clone() != T::zero() {
|
||||
*o.into_mut() = o.get().clone() + coeff;
|
||||
} else {
|
||||
o.remove();
|
||||
}
|
||||
}
|
||||
Entry::Vacant(v) => {
|
||||
// We checked earlier but let's make sure we're not creating zero-coeff terms
|
||||
assert!(coeff != T::zero());
|
||||
v.insert(coeff);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
acc
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Display for LinComb<T> {
|
||||
|
@ -83,7 +114,9 @@ impl<T: Field> fmt::Display for LinComb<T> {
|
|||
false => write!(
|
||||
f,
|
||||
"{}",
|
||||
self.0
|
||||
self
|
||||
.as_canonical()
|
||||
.0
|
||||
.iter()
|
||||
.map(|(k, v)| format!("{} * {}", v.to_compact_dec_string(), k))
|
||||
.collect::<Vec<_>>()
|
||||
|
@ -104,10 +137,7 @@ impl<T: Field> Add<LinComb<T>> for LinComb<T> {
|
|||
type Output = LinComb<T>;
|
||||
|
||||
fn add(self, other: LinComb<T>) -> LinComb<T> {
|
||||
let mut res = self.0;
|
||||
res.extend(other.0);
|
||||
|
||||
LinComb(res)
|
||||
LinComb(self.0.into_iter().chain(other.0.into_iter()).collect())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -115,29 +145,13 @@ impl<T: Field> Sub<LinComb<T>> for LinComb<T> {
|
|||
type Output = LinComb<T>;
|
||||
|
||||
fn sub(self, other: LinComb<T>) -> LinComb<T> {
|
||||
let mut res = self.0;
|
||||
let mut to_remove = BTreeSet::new();
|
||||
|
||||
for (k, v) in other.0 {
|
||||
let var;
|
||||
{
|
||||
var = res
|
||||
.iter()
|
||||
.find(|(id, _)| *id == k)
|
||||
.map(|(_, v)| v)
|
||||
.cloned();
|
||||
}
|
||||
|
||||
let new_val = T::zero() - v + var.unwrap_or(T::zero());
|
||||
if new_val != T::zero() {
|
||||
res.push((k, new_val));
|
||||
} else {
|
||||
to_remove.insert(k);
|
||||
}
|
||||
}
|
||||
|
||||
res.retain(|(id, _)| !to_remove.contains(id));
|
||||
LinComb(res)
|
||||
// Concatenate with second vector that have negative coeffs
|
||||
LinComb(
|
||||
self.0
|
||||
.into_iter()
|
||||
.chain(other.0.into_iter().map(|(var, coeff)| (var, T::zero() - coeff)))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -195,8 +209,8 @@ mod tests {
|
|||
let expected_vec = vec![
|
||||
(FlatVariable::new(42), FieldPrime::from(1)),
|
||||
(FlatVariable::new(42), FieldPrime::from(1)),
|
||||
|
||||
];
|
||||
|
||||
assert_eq!(c, LinComb(expected_vec));
|
||||
}
|
||||
#[test]
|
||||
|
@ -204,14 +218,20 @@ mod tests {
|
|||
let a: LinComb<FieldPrime> = FlatVariable::new(42).into();
|
||||
let b: LinComb<FieldPrime> = FlatVariable::new(42).into();
|
||||
let c = a - b.clone();
|
||||
assert_eq!(LinComb::zero(), c);
|
||||
|
||||
let expected_vec = vec![
|
||||
(FlatVariable::new(42), FieldPrime::from(1)),
|
||||
(FlatVariable::new(42), FieldPrime::from(-1)),
|
||||
];
|
||||
|
||||
assert_eq!(c, LinComb(expected_vec));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let a: LinComb<FieldPrime> =
|
||||
LinComb::from(FlatVariable::new(42)) + LinComb::summand(3, FlatVariable::new(21));
|
||||
assert_eq!(&a.to_string(), "1 * _42 + 3 * _21");
|
||||
assert_eq!(&a.to_string(), "3 * _21 + 1 * _42");
|
||||
let zero: LinComb<FieldPrime> = LinComb::zero();
|
||||
assert_eq!(&zero.to_string(), "0");
|
||||
}
|
||||
|
@ -247,7 +267,7 @@ mod tests {
|
|||
+ LinComb::summand(4, FlatVariable::new(33)),
|
||||
right: LinComb::summand(1, FlatVariable::new(21)),
|
||||
};
|
||||
assert_eq!(&a.to_string(), "(3 * _42 + 4 * _33) * (1 * _21)");
|
||||
assert_eq!(&a.to_string(), "(4 * _33 + 3 * _42) * (1 * _21)");
|
||||
let a: QuadComb<FieldPrime> = QuadComb {
|
||||
left: LinComb::zero(),
|
||||
right: LinComb::summand(1, FlatVariable::new(21)),
|
||||
|
|
Loading…
Reference in a new issue