make tests pass
This commit is contained in:
parent
212d06ec76
commit
5d6f29cb4e
6 changed files with 65 additions and 50 deletions
|
@ -1122,7 +1122,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
fn check_generic_expression(
|
||||
&mut self,
|
||||
expr: ExpressionNode<'ast>,
|
||||
constants_map: &HashMap<&'ast str, Type<'ast, T>>,
|
||||
module_id: &ModuleId,
|
||||
constants_map: &HashMap<ConstantIdentifier<'ast>, Type<'ast, T>>,
|
||||
generics_map: &HashMap<Identifier<'ast>, usize>,
|
||||
) -> Result<DeclarationConstant<'ast>, ErrorInner> {
|
||||
let pos = expr.pos();
|
||||
|
@ -1148,7 +1149,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
match (constants_map.get(name), generics_map.get(&name)) {
|
||||
(Some(ty), None) => {
|
||||
match ty {
|
||||
Type::Uint(UBitwidth::B32) => Ok(DeclarationConstant::Constant(name)),
|
||||
Type::Uint(UBitwidth::B32) => Ok(DeclarationConstant::Constant(CanonicalConstantIdentifier::new(name, module_id.into()))),
|
||||
_ => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
|
@ -1192,6 +1193,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
UnresolvedType::Array(t, size) => {
|
||||
let checked_size = self.check_generic_expression(
|
||||
size.clone(),
|
||||
module_id,
|
||||
state.constants.get(module_id).unwrap_or(&HashMap::new()),
|
||||
generics_map,
|
||||
)?;
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
use crate::static_analysis::Propagator;
|
||||
use crate::typed_absy::folder::*;
|
||||
use crate::typed_absy::result_folder::ResultFolder;
|
||||
use crate::typed_absy::types::DeclarationConstant;
|
||||
use crate::typed_absy::*;
|
||||
use core::str;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryInto;
|
||||
use zokrates_field::Field;
|
||||
|
||||
type ModuleConstants<'ast, T> =
|
||||
HashMap<OwnedTypedModuleId, HashMap<&'ast str, TypedConstant<'ast, T>>>;
|
||||
HashMap<OwnedTypedModuleId, HashMap<Identifier<'ast>, TypedExpression<'ast, T>>>;
|
||||
|
||||
pub struct ConstantInliner<'ast, T> {
|
||||
modules: TypedModules<'ast, T>,
|
||||
|
@ -43,7 +44,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
|
|||
prev
|
||||
}
|
||||
|
||||
fn get_constant(&mut self, id: &Identifier) -> Option<TypedConstant<'ast, T>> {
|
||||
fn get_constant(&mut self, id: &Identifier) -> Option<TypedExpression<'ast, T>> {
|
||||
assert_eq!(id.version, 0);
|
||||
match id.id {
|
||||
CoreIdentifier::Call(..) => {
|
||||
|
@ -52,7 +53,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
|
|||
CoreIdentifier::Source(id) => self
|
||||
.constants
|
||||
.get(&self.location)
|
||||
.and_then(|constants| constants.get(id))
|
||||
.and_then(|constants| constants.get(&id.into()))
|
||||
.cloned(),
|
||||
}
|
||||
}
|
||||
|
@ -85,29 +86,43 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
let constant = match tc {
|
||||
TypedConstantSymbol::There(imported_id) => {
|
||||
if !self.constants.contains_key(&imported_id.module) {
|
||||
let current_m_id = self.change_location(id.module.clone());
|
||||
let _ = self
|
||||
.fold_module(self.modules.get(&id.module).unwrap().clone());
|
||||
let current_m_id =
|
||||
self.change_location(imported_id.module.clone());
|
||||
let _ = self.fold_module(
|
||||
self.modules.get(&imported_id.module).unwrap().clone(),
|
||||
);
|
||||
self.change_location(current_m_id);
|
||||
}
|
||||
self.constants
|
||||
.get(&imported_id.module)
|
||||
.unwrap()
|
||||
.get(&imported_id.id)
|
||||
.get(&imported_id.id.into())
|
||||
.cloned()
|
||||
.unwrap()
|
||||
}
|
||||
TypedConstantSymbol::Here(c) => fold_constant(self, c),
|
||||
TypedConstantSymbol::Here(c) => fold_constant(self, c).expression,
|
||||
};
|
||||
|
||||
let constant = Propagator::with_constants(
|
||||
self.constants.get_mut(&self.location).unwrap(),
|
||||
)
|
||||
.fold_expression(constant)
|
||||
.unwrap();
|
||||
|
||||
assert!(self
|
||||
.constants
|
||||
.entry(self.location.clone())
|
||||
.or_default()
|
||||
.insert(id.id, constant.clone())
|
||||
.insert(id.id.into(), constant.clone())
|
||||
.is_none());
|
||||
|
||||
(id, TypedConstantSymbol::Here(constant))
|
||||
(
|
||||
id,
|
||||
TypedConstantSymbol::Here(TypedConstant {
|
||||
ty: constant.get_type().clone(),
|
||||
expression: constant,
|
||||
}),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
functions: m
|
||||
|
@ -130,28 +145,20 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
&mut self,
|
||||
c: DeclarationConstant<'ast>,
|
||||
) -> DeclarationConstant<'ast> {
|
||||
println!("id {}", c);
|
||||
println!("constants {:#?}", self.constants);
|
||||
println!("location {}", self.location.display());
|
||||
|
||||
match c {
|
||||
DeclarationConstant::Constant(id) => DeclarationConstant::Concrete(
|
||||
match self
|
||||
.constants
|
||||
.get(&self.location)
|
||||
.get(&id.module)
|
||||
.unwrap()
|
||||
.get(&id)
|
||||
.get(&id.id.into())
|
||||
.cloned()
|
||||
.unwrap()
|
||||
{
|
||||
TypedConstant {
|
||||
ty: Type::Uint(UBitwidth::B32),
|
||||
expression:
|
||||
TypedExpression::Uint(UExpression {
|
||||
inner: UExpressionInner::Value(v),
|
||||
..
|
||||
}),
|
||||
} => v as u32,
|
||||
TypedExpression::Uint(UExpression {
|
||||
inner: UExpressionInner::Value(v),
|
||||
..
|
||||
}) => v as u32,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
),
|
||||
|
@ -165,7 +172,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
) -> FieldElementExpression<'ast, T> {
|
||||
match e {
|
||||
FieldElementExpression::Identifier(ref id) => match self.get_constant(id) {
|
||||
Some(c) => self.fold_constant(c).try_into().unwrap(),
|
||||
Some(c) => self.fold_expression(c).try_into().unwrap(),
|
||||
None => fold_field_expression(self, e),
|
||||
},
|
||||
e => fold_field_expression(self, e),
|
||||
|
@ -178,7 +185,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
) -> BooleanExpression<'ast, T> {
|
||||
match e {
|
||||
BooleanExpression::Identifier(ref id) => match self.get_constant(id) {
|
||||
Some(c) => self.fold_constant(c).try_into().unwrap(),
|
||||
Some(c) => self.fold_expression(c).try_into().unwrap(),
|
||||
None => fold_boolean_expression(self, e),
|
||||
},
|
||||
e => fold_boolean_expression(self, e),
|
||||
|
@ -193,7 +200,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
match e {
|
||||
UExpressionInner::Identifier(ref id) => match self.get_constant(id) {
|
||||
Some(c) => {
|
||||
let e: UExpression<'ast, T> = self.fold_constant(c).try_into().unwrap();
|
||||
let e: UExpression<'ast, T> = self.fold_expression(c).try_into().unwrap();
|
||||
e.into_inner()
|
||||
}
|
||||
None => fold_uint_expression_inner(self, size, e),
|
||||
|
@ -210,7 +217,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
match e {
|
||||
ArrayExpressionInner::Identifier(ref id) => match self.get_constant(id) {
|
||||
Some(c) => {
|
||||
let e: ArrayExpression<'ast, T> = self.fold_constant(c).try_into().unwrap();
|
||||
let e: ArrayExpression<'ast, T> = self.fold_expression(c).try_into().unwrap();
|
||||
e.into_inner()
|
||||
}
|
||||
None => fold_array_expression_inner(self, ty, e),
|
||||
|
@ -227,7 +234,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
|||
match e {
|
||||
StructExpressionInner::Identifier(ref id) => match self.get_constant(id) {
|
||||
Some(c) => {
|
||||
let e: StructExpression<'ast, T> = self.fold_constant(c).try_into().unwrap();
|
||||
let e: StructExpression<'ast, T> = self.fold_expression(c).try_into().unwrap();
|
||||
e.into_inner()
|
||||
}
|
||||
None => fold_struct_expression_inner(self, ty, e),
|
||||
|
@ -266,7 +273,7 @@ mod tests {
|
|||
};
|
||||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
const_id,
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(Bn128Field::from(1))),
|
||||
|
@ -354,7 +361,7 @@ mod tests {
|
|||
};
|
||||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
const_id,
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::Boolean,
|
||||
TypedExpression::Boolean(BooleanExpression::Value(true)),
|
||||
|
@ -443,7 +450,7 @@ mod tests {
|
|||
};
|
||||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
const_id,
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::Uint(UBitwidth::B32),
|
||||
UExpressionInner::Value(1u128)
|
||||
|
@ -544,7 +551,7 @@ mod tests {
|
|||
};
|
||||
|
||||
let constants: TypedConstantSymbols<_> = vec![(
|
||||
const_id,
|
||||
CanonicalConstantIdentifier::new(const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::Array(
|
||||
|
@ -683,7 +690,7 @@ mod tests {
|
|||
.collect(),
|
||||
constants: vec![
|
||||
(
|
||||
const_a_id,
|
||||
CanonicalConstantIdentifier::new(const_a_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
|
@ -692,7 +699,7 @@ mod tests {
|
|||
)),
|
||||
),
|
||||
(
|
||||
const_b_id,
|
||||
CanonicalConstantIdentifier::new(const_b_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Add(
|
||||
|
@ -741,7 +748,7 @@ mod tests {
|
|||
.collect(),
|
||||
constants: vec![
|
||||
(
|
||||
const_a_id,
|
||||
CanonicalConstantIdentifier::new(const_a_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
|
@ -750,7 +757,7 @@ mod tests {
|
|||
)),
|
||||
),
|
||||
(
|
||||
const_b_id,
|
||||
CanonicalConstantIdentifier::new(const_b_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
|
@ -802,7 +809,7 @@ mod tests {
|
|||
.into_iter()
|
||||
.collect(),
|
||||
constants: vec![(
|
||||
foo_const_id,
|
||||
CanonicalConstantIdentifier::new(foo_const_id, "foo".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
|
@ -834,8 +841,11 @@ mod tests {
|
|||
.into_iter()
|
||||
.collect(),
|
||||
constants: vec![(
|
||||
foo_const_id,
|
||||
TypedConstantSymbol::There(OwnedTypedModuleId::from("foo"), foo_const_id),
|
||||
CanonicalConstantIdentifier::new(foo_const_id, "main".into()),
|
||||
TypedConstantSymbol::There(CanonicalConstantIdentifier::new(
|
||||
foo_const_id,
|
||||
"foo".into(),
|
||||
)),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
|
@ -872,7 +882,7 @@ mod tests {
|
|||
.into_iter()
|
||||
.collect(),
|
||||
constants: vec![(
|
||||
foo_const_id,
|
||||
CanonicalConstantIdentifier::new(foo_const_id, "main".into()),
|
||||
TypedConstantSymbol::Here(TypedConstant::new(
|
||||
GType::FieldElement,
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(
|
||||
|
|
|
@ -78,13 +78,13 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
|||
pub fn analyse(self, config: &CompileConfig) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
|
||||
// inline user-defined constants
|
||||
let r = ConstantInliner::inline(self);
|
||||
println!("{}", r);
|
||||
// isolate branches
|
||||
let r = if config.isolate_branches {
|
||||
Isolator::isolate(r)
|
||||
} else {
|
||||
r
|
||||
};
|
||||
|
||||
// reduce the program to a single function
|
||||
let r = reduce_program(r).map_err(Error::from)?;
|
||||
// generate abi
|
||||
|
|
|
@ -1024,7 +1024,10 @@ pub fn fold_function_symbol<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
) -> TypedFunctionSymbol<'ast, T> {
|
||||
match s {
|
||||
TypedFunctionSymbol::Here(fun) => TypedFunctionSymbol::Here(f.fold_function(fun)),
|
||||
there => there, // by default, do not fold modules recursively
|
||||
TypedFunctionSymbol::There(key) => {
|
||||
TypedFunctionSymbol::There(f.fold_declaration_function_key(key))
|
||||
}
|
||||
s => s,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -119,7 +119,7 @@ impl<'ast> CanonicalConstantIdentifier<'ast> {
|
|||
pub enum DeclarationConstant<'ast> {
|
||||
Generic(GenericIdentifier<'ast>),
|
||||
Concrete(u32),
|
||||
Constant(ConstantIdentifier<'ast>),
|
||||
Constant(CanonicalConstantIdentifier<'ast>),
|
||||
}
|
||||
|
||||
impl<'ast> From<u32> for DeclarationConstant<'ast> {
|
||||
|
@ -145,7 +145,7 @@ impl<'ast> fmt::Display for DeclarationConstant<'ast> {
|
|||
match self {
|
||||
DeclarationConstant::Generic(i) => write!(f, "{}", i),
|
||||
DeclarationConstant::Concrete(v) => write!(f, "{}", v),
|
||||
DeclarationConstant::Constant(v) => write!(f, "{}", v),
|
||||
DeclarationConstant::Constant(v) => write!(f, "{}/{}", v.module.display(), v.id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -166,7 +166,7 @@ impl<'ast, T> From<DeclarationConstant<'ast>> for UExpression<'ast, T> {
|
|||
UExpressionInner::Value(v as u128).annotate(UBitwidth::B32)
|
||||
}
|
||||
DeclarationConstant::Constant(v) => {
|
||||
UExpressionInner::Identifier(Identifier::from(v)).annotate(UBitwidth::B32)
|
||||
UExpressionInner::Identifier(Identifier::from(v.id)).annotate(UBitwidth::B32)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
const u32 N = 1
|
||||
const u32 N = 1 + 1
|
||||
def foo(field[N] a) -> bool:
|
||||
return true
|
Loading…
Reference in a new issue