1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

pass prog as reference, fix tests

This commit is contained in:
dark64 2021-09-06 12:53:08 +02:00
parent 34b631b644
commit 7c248fd77d
2 changed files with 18 additions and 14 deletions

View file

@ -150,6 +150,7 @@ impl<T: Field> Analyse for Prog<T> {
fn analyse(self) -> Result<Self, Self::Error> {
log::debug!("Static analyser: Detect unconstrained zir");
UnconstrainedVariableDetector::detect(self).map_err(Error::from)
UnconstrainedVariableDetector::detect(&self).map_err(Error::from)?;
Ok(self)
}
}

View file

@ -11,7 +11,7 @@ pub struct UnconstrainedVariableDetector {
pub(self) variables: HashSet<FlatVariable>,
}
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct Error(usize);
impl fmt::Display for Error {
@ -26,12 +26,12 @@ impl fmt::Display for Error {
}
impl UnconstrainedVariableDetector {
pub fn detect<T: Field>(p: Prog<T>) -> Result<Prog<T>, Error> {
pub fn detect<T: Field>(p: &Prog<T>) -> Result<(), Error> {
let mut instance = Self::default();
instance.visit_module(&p);
if instance.variables.is_empty() {
Ok(p)
Ok(())
} else {
Err(Error(instance.variables.len()))
}
@ -61,12 +61,12 @@ mod tests {
use zokrates_field::Bn128Field;
#[test]
fn should_detect_unconstrained_private_input() {
fn unconstrained_private_input() {
// def main(_0) -> (1):
// (1 * ~one) * (42 * ~one) == 1 * ~out_0
// return ~out_0
let _0 = FlatParameter::private(FlatVariable::new(0)); // unused var
let _0 = FlatParameter::private(FlatVariable::new(0)); // unused private parameter
let one = FlatVariable::one();
let out_0 = FlatVariable::public(0);
@ -83,12 +83,15 @@ mod tests {
returns: vec![out_0],
};
let p = UnconstrainedVariableDetector::detect(p);
assert!(p.is_err());
let result = UnconstrainedVariableDetector::detect(&p);
assert_eq!(
result.expect_err("expected an error").to_string(),
"Found unconstrained variables during IR analysis (found 1 occurrence)"
);
}
#[test]
fn should_pass_with_constrained_private_input() {
fn constrained_private_input() {
// def main(_0) -> (1):
// (1 * ~one) * (1 * _0) == 1 * ~out_0
// return ~out_0
@ -102,12 +105,12 @@ mod tests {
returns: vec![out_0],
};
let p = UnconstrainedVariableDetector::detect(p);
assert!(p.is_ok());
let result = UnconstrainedVariableDetector::detect(&p);
assert_eq!(result, Ok(()));
}
#[test]
fn should_pass_with_directive() {
fn constrained_directive() {
// def main(_0) -> (1):
// # _1, _2 = ConditionEq((-42) * ~one + 1 * _0)
// ((-42) * ~one + 1 * _0) * (1 * _2) == 1 * _1
@ -155,7 +158,7 @@ mod tests {
returns: vec![out_0],
};
let p = UnconstrainedVariableDetector::detect(p);
assert!(p.is_ok());
let result = UnconstrainedVariableDetector::detect(&p);
assert_eq!(result, Ok(()));
}
}