1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

pass prog as reference, fix tests

This commit is contained in:
dark64 2021-09-06 12:53:08 +02:00
parent 34b631b644
commit 7c248fd77d
2 changed files with 18 additions and 14 deletions

View file

@ -150,6 +150,7 @@ impl<T: Field> Analyse for Prog<T> {
fn analyse(self) -> Result<Self, Self::Error> { fn analyse(self) -> Result<Self, Self::Error> {
log::debug!("Static analyser: Detect unconstrained zir"); log::debug!("Static analyser: Detect unconstrained zir");
UnconstrainedVariableDetector::detect(self).map_err(Error::from) UnconstrainedVariableDetector::detect(&self).map_err(Error::from)?;
Ok(self)
} }
} }

View file

@ -11,7 +11,7 @@ pub struct UnconstrainedVariableDetector {
pub(self) variables: HashSet<FlatVariable>, pub(self) variables: HashSet<FlatVariable>,
} }
#[derive(Debug)] #[derive(Debug, PartialEq)]
pub struct Error(usize); pub struct Error(usize);
impl fmt::Display for Error { impl fmt::Display for Error {
@ -26,12 +26,12 @@ impl fmt::Display for Error {
} }
impl UnconstrainedVariableDetector { impl UnconstrainedVariableDetector {
pub fn detect<T: Field>(p: Prog<T>) -> Result<Prog<T>, Error> { pub fn detect<T: Field>(p: &Prog<T>) -> Result<(), Error> {
let mut instance = Self::default(); let mut instance = Self::default();
instance.visit_module(&p); instance.visit_module(&p);
if instance.variables.is_empty() { if instance.variables.is_empty() {
Ok(p) Ok(())
} else { } else {
Err(Error(instance.variables.len())) Err(Error(instance.variables.len()))
} }
@ -61,12 +61,12 @@ mod tests {
use zokrates_field::Bn128Field; use zokrates_field::Bn128Field;
#[test] #[test]
fn should_detect_unconstrained_private_input() { fn unconstrained_private_input() {
// def main(_0) -> (1): // def main(_0) -> (1):
// (1 * ~one) * (42 * ~one) == 1 * ~out_0 // (1 * ~one) * (42 * ~one) == 1 * ~out_0
// return ~out_0 // return ~out_0
let _0 = FlatParameter::private(FlatVariable::new(0)); // unused var let _0 = FlatParameter::private(FlatVariable::new(0)); // unused private parameter
let one = FlatVariable::one(); let one = FlatVariable::one();
let out_0 = FlatVariable::public(0); let out_0 = FlatVariable::public(0);
@ -83,12 +83,15 @@ mod tests {
returns: vec![out_0], returns: vec![out_0],
}; };
let p = UnconstrainedVariableDetector::detect(p); let result = UnconstrainedVariableDetector::detect(&p);
assert!(p.is_err()); assert_eq!(
result.expect_err("expected an error").to_string(),
"Found unconstrained variables during IR analysis (found 1 occurrence)"
);
} }
#[test] #[test]
fn should_pass_with_constrained_private_input() { fn constrained_private_input() {
// def main(_0) -> (1): // def main(_0) -> (1):
// (1 * ~one) * (1 * _0) == 1 * ~out_0 // (1 * ~one) * (1 * _0) == 1 * ~out_0
// return ~out_0 // return ~out_0
@ -102,12 +105,12 @@ mod tests {
returns: vec![out_0], returns: vec![out_0],
}; };
let p = UnconstrainedVariableDetector::detect(p); let result = UnconstrainedVariableDetector::detect(&p);
assert!(p.is_ok()); assert_eq!(result, Ok(()));
} }
#[test] #[test]
fn should_pass_with_directive() { fn constrained_directive() {
// def main(_0) -> (1): // def main(_0) -> (1):
// # _1, _2 = ConditionEq((-42) * ~one + 1 * _0) // # _1, _2 = ConditionEq((-42) * ~one + 1 * _0)
// ((-42) * ~one + 1 * _0) * (1 * _2) == 1 * _1 // ((-42) * ~one + 1 * _0) * (1 * _2) == 1 * _1
@ -155,7 +158,7 @@ mod tests {
returns: vec![out_0], returns: vec![out_0],
}; };
let p = UnconstrainedVariableDetector::detect(p); let result = UnconstrainedVariableDetector::detect(&p);
assert!(p.is_ok()); assert_eq!(result, Ok(()));
} }
} }