1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

make tests pass

This commit is contained in:
schaeff 2021-06-08 16:56:29 +02:00
parent 212d06ec76
commit 5d6f29cb4e
6 changed files with 65 additions and 50 deletions

View file

@ -1122,7 +1122,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
fn check_generic_expression(
&mut self,
expr: ExpressionNode<'ast>,
constants_map: &HashMap<&'ast str, Type<'ast, T>>,
module_id: &ModuleId,
constants_map: &HashMap<ConstantIdentifier<'ast>, Type<'ast, T>>,
generics_map: &HashMap<Identifier<'ast>, usize>,
) -> Result<DeclarationConstant<'ast>, ErrorInner> {
let pos = expr.pos();
@ -1148,7 +1149,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
match (constants_map.get(name), generics_map.get(&name)) {
(Some(ty), None) => {
match ty {
Type::Uint(UBitwidth::B32) => Ok(DeclarationConstant::Constant(name)),
Type::Uint(UBitwidth::B32) => Ok(DeclarationConstant::Constant(CanonicalConstantIdentifier::new(name, module_id.into()))),
_ => Err(ErrorInner {
pos: Some(pos),
message: format!(
@ -1192,6 +1193,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
UnresolvedType::Array(t, size) => {
let checked_size = self.check_generic_expression(
size.clone(),
module_id,
state.constants.get(module_id).unwrap_or(&HashMap::new()),
generics_map,
)?;

View file

@ -1,13 +1,14 @@
use crate::static_analysis::Propagator;
use crate::typed_absy::folder::*;
use crate::typed_absy::result_folder::ResultFolder;
use crate::typed_absy::types::DeclarationConstant;
use crate::typed_absy::*;
use core::str;
use std::collections::HashMap;
use std::convert::TryInto;
use zokrates_field::Field;
type ModuleConstants<'ast, T> =
HashMap<OwnedTypedModuleId, HashMap<&'ast str, TypedConstant<'ast, T>>>;
HashMap<OwnedTypedModuleId, HashMap<Identifier<'ast>, TypedExpression<'ast, T>>>;
pub struct ConstantInliner<'ast, T> {
modules: TypedModules<'ast, T>,
@ -43,7 +44,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
prev
}
fn get_constant(&mut self, id: &Identifier) -> Option<TypedConstant<'ast, T>> {
fn get_constant(&mut self, id: &Identifier) -> Option<TypedExpression<'ast, T>> {
assert_eq!(id.version, 0);
match id.id {
CoreIdentifier::Call(..) => {
@ -52,7 +53,7 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
CoreIdentifier::Source(id) => self
.constants
.get(&self.location)
.and_then(|constants| constants.get(id))
.and_then(|constants| constants.get(&id.into()))
.cloned(),
}
}
@ -85,29 +86,43 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
let constant = match tc {
TypedConstantSymbol::There(imported_id) => {
if !self.constants.contains_key(&imported_id.module) {
let current_m_id = self.change_location(id.module.clone());
let _ = self
.fold_module(self.modules.get(&id.module).unwrap().clone());
let current_m_id =
self.change_location(imported_id.module.clone());
let _ = self.fold_module(
self.modules.get(&imported_id.module).unwrap().clone(),
);
self.change_location(current_m_id);
}
self.constants
.get(&imported_id.module)
.unwrap()
.get(&imported_id.id)
.get(&imported_id.id.into())
.cloned()
.unwrap()
}
TypedConstantSymbol::Here(c) => fold_constant(self, c),
TypedConstantSymbol::Here(c) => fold_constant(self, c).expression,
};
let constant = Propagator::with_constants(
self.constants.get_mut(&self.location).unwrap(),
)
.fold_expression(constant)
.unwrap();
assert!(self
.constants
.entry(self.location.clone())
.or_default()
.insert(id.id, constant.clone())
.insert(id.id.into(), constant.clone())
.is_none());
(id, TypedConstantSymbol::Here(constant))
(
id,
TypedConstantSymbol::Here(TypedConstant {
ty: constant.get_type().clone(),
expression: constant,
}),
)
})
.collect(),
functions: m
@ -130,28 +145,20 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
&mut self,
c: DeclarationConstant<'ast>,
) -> DeclarationConstant<'ast> {
println!("id {}", c);
println!("constants {:#?}", self.constants);
println!("location {}", self.location.display());
match c {
DeclarationConstant::Constant(id) => DeclarationConstant::Concrete(
match self
.constants
.get(&self.location)
.get(&id.module)
.unwrap()
.get(&id)
.get(&id.id.into())
.cloned()
.unwrap()
{
TypedConstant {
ty: Type::Uint(UBitwidth::B32),
expression:
TypedExpression::Uint(UExpression {
inner: UExpressionInner::Value(v),
..
}),
} => v as u32,
TypedExpression::Uint(UExpression {
inner: UExpressionInner::Value(v),
..
}) => v as u32,
_ => unreachable!(),
},
),
@ -165,7 +172,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Identifier(ref id) => match self.get_constant(id) {
Some(c) => self.fold_constant(c).try_into().unwrap(),
Some(c) => self.fold_expression(c).try_into().unwrap(),
None => fold_field_expression(self, e),
},
e => fold_field_expression(self, e),
@ -178,7 +185,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::Identifier(ref id) => match self.get_constant(id) {
Some(c) => self.fold_constant(c).try_into().unwrap(),
Some(c) => self.fold_expression(c).try_into().unwrap(),
None => fold_boolean_expression(self, e),
},
e => fold_boolean_expression(self, e),
@ -193,7 +200,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
match e {
UExpressionInner::Identifier(ref id) => match self.get_constant(id) {
Some(c) => {
let e: UExpression<'ast, T> = self.fold_constant(c).try_into().unwrap();
let e: UExpression<'ast, T> = self.fold_expression(c).try_into().unwrap();
e.into_inner()
}
None => fold_uint_expression_inner(self, size, e),
@ -210,7 +217,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
match e {
ArrayExpressionInner::Identifier(ref id) => match self.get_constant(id) {
Some(c) => {
let e: ArrayExpression<'ast, T> = self.fold_constant(c).try_into().unwrap();
let e: ArrayExpression<'ast, T> = self.fold_expression(c).try_into().unwrap();
e.into_inner()
}
None => fold_array_expression_inner(self, ty, e),
@ -227,7 +234,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
match e {
StructExpressionInner::Identifier(ref id) => match self.get_constant(id) {
Some(c) => {
let e: StructExpression<'ast, T> = self.fold_constant(c).try_into().unwrap();
let e: StructExpression<'ast, T> = self.fold_expression(c).try_into().unwrap();
e.into_inner()
}
None => fold_struct_expression_inner(self, ty, e),
@ -266,7 +273,7 @@ mod tests {
};
let constants: TypedConstantSymbols<_> = vec![(
const_id,
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(Bn128Field::from(1))),
@ -354,7 +361,7 @@ mod tests {
};
let constants: TypedConstantSymbols<_> = vec![(
const_id,
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::Boolean,
TypedExpression::Boolean(BooleanExpression::Value(true)),
@ -443,7 +450,7 @@ mod tests {
};
let constants: TypedConstantSymbols<_> = vec![(
const_id,
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::Uint(UBitwidth::B32),
UExpressionInner::Value(1u128)
@ -544,7 +551,7 @@ mod tests {
};
let constants: TypedConstantSymbols<_> = vec![(
const_id,
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::Array(
@ -683,7 +690,7 @@ mod tests {
.collect(),
constants: vec![
(
const_a_id,
CanonicalConstantIdentifier::new(const_a_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
@ -692,7 +699,7 @@ mod tests {
)),
),
(
const_b_id,
CanonicalConstantIdentifier::new(const_b_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Add(
@ -741,7 +748,7 @@ mod tests {
.collect(),
constants: vec![
(
const_a_id,
CanonicalConstantIdentifier::new(const_a_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
@ -750,7 +757,7 @@ mod tests {
)),
),
(
const_b_id,
CanonicalConstantIdentifier::new(const_b_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
@ -802,7 +809,7 @@ mod tests {
.into_iter()
.collect(),
constants: vec![(
foo_const_id,
CanonicalConstantIdentifier::new(foo_const_id, "foo".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
@ -834,8 +841,11 @@ mod tests {
.into_iter()
.collect(),
constants: vec![(
foo_const_id,
TypedConstantSymbol::There(OwnedTypedModuleId::from("foo"), foo_const_id),
CanonicalConstantIdentifier::new(foo_const_id, "main".into()),
TypedConstantSymbol::There(CanonicalConstantIdentifier::new(
foo_const_id,
"foo".into(),
)),
)]
.into_iter()
.collect(),
@ -872,7 +882,7 @@ mod tests {
.into_iter()
.collect(),
constants: vec![(
foo_const_id,
CanonicalConstantIdentifier::new(foo_const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(

View file

@ -78,13 +78,13 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
pub fn analyse(self, config: &CompileConfig) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
// inline user-defined constants
let r = ConstantInliner::inline(self);
println!("{}", r);
// isolate branches
let r = if config.isolate_branches {
Isolator::isolate(r)
} else {
r
};
// reduce the program to a single function
let r = reduce_program(r).map_err(Error::from)?;
// generate abi

View file

@ -1024,7 +1024,10 @@ pub fn fold_function_symbol<'ast, T: Field, F: Folder<'ast, T>>(
) -> TypedFunctionSymbol<'ast, T> {
match s {
TypedFunctionSymbol::Here(fun) => TypedFunctionSymbol::Here(f.fold_function(fun)),
there => there, // by default, do not fold modules recursively
TypedFunctionSymbol::There(key) => {
TypedFunctionSymbol::There(f.fold_declaration_function_key(key))
}
s => s,
}
}

View file

@ -119,7 +119,7 @@ impl<'ast> CanonicalConstantIdentifier<'ast> {
pub enum DeclarationConstant<'ast> {
Generic(GenericIdentifier<'ast>),
Concrete(u32),
Constant(ConstantIdentifier<'ast>),
Constant(CanonicalConstantIdentifier<'ast>),
}
impl<'ast> From<u32> for DeclarationConstant<'ast> {
@ -145,7 +145,7 @@ impl<'ast> fmt::Display for DeclarationConstant<'ast> {
match self {
DeclarationConstant::Generic(i) => write!(f, "{}", i),
DeclarationConstant::Concrete(v) => write!(f, "{}", v),
DeclarationConstant::Constant(v) => write!(f, "{}", v),
DeclarationConstant::Constant(v) => write!(f, "{}/{}", v.module.display(), v.id),
}
}
}
@ -166,7 +166,7 @@ impl<'ast, T> From<DeclarationConstant<'ast>> for UExpression<'ast, T> {
UExpressionInner::Value(v as u128).annotate(UBitwidth::B32)
}
DeclarationConstant::Constant(v) => {
UExpressionInner::Identifier(Identifier::from(v)).annotate(UBitwidth::B32)
UExpressionInner::Identifier(Identifier::from(v.id)).annotate(UBitwidth::B32)
}
}
}

View file

@ -1,3 +1,3 @@
const u32 N = 1
const u32 N = 1 + 1
def foo(field[N] a) -> bool:
return true