use field::Field; use flat_absy::FlatVariable; use num::Zero; use std::collections::BTreeMap; use std::fmt; use std::ops::{Add, Sub}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct QuadComb { pub left: LinComb, pub right: LinComb, } impl QuadComb { pub fn from_linear_combinations(left: LinComb, right: LinComb) -> Self { QuadComb { left, right } } } impl From for QuadComb { fn from(v: FlatVariable) -> QuadComb { LinComb::from(v).into() } } impl fmt::Display for QuadComb { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "({}) * ({})", self.left, self.right,) } } impl From> for QuadComb { fn from(lc: LinComb) -> QuadComb { QuadComb::from_linear_combinations(LinComb::one(), lc) } } #[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)] pub struct LinComb(pub BTreeMap); impl LinComb { pub fn summand>(mult: U, var: FlatVariable) -> LinComb { let mut res = BTreeMap::new(); res.insert(var, mult.into()); LinComb(res) } pub fn one() -> LinComb { Self::summand(1, FlatVariable::one()) } } impl fmt::Display for LinComb { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "{}", self.0 .iter() .map(|(k, v)| format!("{} * {}", v, k)) .collect::>() .join(" + ") ) } } impl From for LinComb { fn from(v: FlatVariable) -> LinComb { let mut r = BTreeMap::new(); r.insert(v, T::one()); LinComb(r) } } impl Add> for LinComb { type Output = LinComb; fn add(self, other: LinComb) -> LinComb { let mut res = self.0.clone(); for (k, v) in other.0 { let new_val = v + res.get(&k).unwrap_or(&T::zero()); if new_val == T::zero() { res.remove(&k) } else { res.insert(k, new_val) }; } LinComb(res) } } impl Sub> for LinComb { type Output = LinComb; fn sub(self, other: LinComb) -> LinComb { let mut res = self.0.clone(); for (k, v) in other.0 { let new_val = T::zero() - v + res.get(&k).unwrap_or(&T::zero()); if new_val == T::zero() { res.remove(&k) } else { res.insert(k, new_val) }; } LinComb(res) } } impl Zero for LinComb { fn zero() -> LinComb { LinComb(BTreeMap::new()) } fn is_zero(&self) -> bool { self.0.len() == 0 } } #[cfg(test)] mod tests { use super::*; use field::FieldPrime; mod linear { use super::*; #[test] fn add_zero() { let a: LinComb = LinComb::zero(); let b: LinComb = FlatVariable::new(42).into(); let c = a + b.clone(); assert_eq!(c, b); } #[test] fn add() { let a: LinComb = FlatVariable::new(42).into(); let b: LinComb = FlatVariable::new(42).into(); let c = a + b.clone(); let mut expected_map = BTreeMap::new(); expected_map.insert(FlatVariable::new(42), FieldPrime::from(2)); assert_eq!(c, LinComb(expected_map)); } #[test] fn sub() { let a: LinComb = FlatVariable::new(42).into(); let b: LinComb = FlatVariable::new(42).into(); let c = a - b.clone(); assert_eq!(c, LinComb::zero()); } } mod quadratic { use super::*; #[test] fn from_linear() { let a: LinComb = LinComb::summand(3, FlatVariable::new(42)) + LinComb::summand(4, FlatVariable::new(33)); let expected = QuadComb { left: LinComb::one(), right: a.clone(), }; assert_eq!(QuadComb::from(a), expected); } #[test] fn zero() { let a: LinComb = LinComb::zero(); let expected: QuadComb = QuadComb { left: LinComb::one(), right: LinComb::zero(), }; assert_eq!(QuadComb::from(a), expected); } } }