wip
This commit is contained in:
parent
3d05d7386b
commit
46cc73d735
13 changed files with 313 additions and 192 deletions
|
@ -4,7 +4,8 @@ use crate::flat_absy::{
|
|||
};
|
||||
use crate::solvers::Solver;
|
||||
use crate::typed_absy::types::{
|
||||
ConcreteGenericsAssignment, Constant, DeclarationSignature, DeclarationType, GenericIdentifier,
|
||||
ConcreteGenericsAssignment, DeclarationConstant, DeclarationSignature, DeclarationType,
|
||||
GenericIdentifier,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use zokrates_field::{Bn128Field, Field};
|
||||
|
@ -43,10 +44,12 @@ impl FlatEmbed {
|
|||
.inputs(vec![DeclarationType::uint(32)])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
FlatEmbed::Unpack => DeclarationSignature::new()
|
||||
.generics(vec![Some(Constant::Generic(GenericIdentifier {
|
||||
.generics(vec![Some(DeclarationConstant::Generic(
|
||||
GenericIdentifier {
|
||||
name: "N",
|
||||
index: 0,
|
||||
}))])
|
||||
},
|
||||
))])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::Boolean,
|
||||
|
@ -122,7 +125,7 @@ impl FlatEmbed {
|
|||
.generics
|
||||
.into_iter()
|
||||
.map(|c| match c.unwrap() {
|
||||
Constant::Generic(g) => g,
|
||||
DeclarationConstant::Generic(g) => g,
|
||||
_ => unreachable!(),
|
||||
});
|
||||
|
||||
|
|
|
@ -19,9 +19,9 @@ use crate::parser::Position;
|
|||
use crate::absy::types::{UnresolvedSignature, UnresolvedType, UserTypeId};
|
||||
|
||||
use crate::typed_absy::types::{
|
||||
ArrayType, Constant, DeclarationArrayType, DeclarationFunctionKey, DeclarationSignature,
|
||||
DeclarationStructMember, DeclarationStructType, DeclarationType, GenericIdentifier,
|
||||
StructLocation,
|
||||
ArrayType, DeclarationArrayType, DeclarationConstant, DeclarationFunctionKey,
|
||||
DeclarationSignature, DeclarationStructMember, DeclarationStructType, DeclarationType,
|
||||
GenericIdentifier, StructLocation,
|
||||
};
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
|
@ -55,8 +55,7 @@ impl ErrorInner {
|
|||
}
|
||||
|
||||
type TypeMap<'ast> = HashMap<OwnedModuleId, HashMap<UserTypeId, DeclarationType<'ast>>>;
|
||||
type ConstantMap<'ast, T> =
|
||||
HashMap<OwnedModuleId, HashMap<ConstantIdentifier<'ast>, Type<'ast, T>>>;
|
||||
type ConstantMap<'ast, T> = HashMap<OwnedModuleId, HashMap<&'ast str, Type<'ast, T>>>;
|
||||
|
||||
/// The global state of the program during semantic checks
|
||||
#[derive(Debug)]
|
||||
|
@ -506,8 +505,10 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.in_file(module_id),
|
||||
),
|
||||
true => {
|
||||
constants
|
||||
.insert(declaration.id, TypedConstantSymbol::Here(c.clone()));
|
||||
constants.insert(
|
||||
ConstantIdentifier::new(declaration.id, module_id.into()),
|
||||
TypedConstantSymbol::Here(c.clone()),
|
||||
);
|
||||
self.insert_into_scope(Variable::with_id_and_type(
|
||||
declaration.id,
|
||||
c.get_type(),
|
||||
|
@ -600,7 +601,9 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
.constants
|
||||
.entry(import.module_id.to_path_buf())
|
||||
.or_default()
|
||||
.get(import.symbol_id)
|
||||
.iter()
|
||||
.find(|(i, _)| *i == &import.symbol_id)
|
||||
.map(|(_, c)| c)
|
||||
.cloned();
|
||||
|
||||
match (function_candidates.len(), type_candidate, const_candidate) {
|
||||
|
@ -653,8 +656,11 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
}});
|
||||
}
|
||||
true => {
|
||||
constants.insert(declaration.id, TypedConstantSymbol::There(import.module_id.to_path_buf(), import.symbol_id));
|
||||
self.insert_into_scope(Variable::with_id_and_type(declaration.id, ty.clone()));
|
||||
let imported_id = ConstantIdentifier::new(import.symbol_id.into(), import.module_id);
|
||||
let id = ConstantIdentifier::new(declaration.id.clone(), module_id.into());
|
||||
|
||||
constants.insert(id.clone(), TypedConstantSymbol::There(imported_id));
|
||||
self.insert_into_scope(Variable::with_id_and_type(declaration.id.clone(), ty.clone()));
|
||||
|
||||
state
|
||||
.constants
|
||||
|
@ -856,7 +862,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
|
||||
let v = Variable::with_id_and_type(
|
||||
match generic {
|
||||
Constant::Generic(g) => g.name,
|
||||
DeclarationConstant::Generic(g) => g.name,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
Type::Uint(UBitwidth::B32),
|
||||
|
@ -996,7 +1002,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
} else {
|
||||
match generics_map.insert(g.value, index).is_none() {
|
||||
true => {
|
||||
generics.push(Some(Constant::Generic(GenericIdentifier {
|
||||
generics.push(Some(DeclarationConstant::Generic(GenericIdentifier {
|
||||
name: g.value,
|
||||
index,
|
||||
})));
|
||||
|
@ -1112,16 +1118,16 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
fn check_generic_expression(
|
||||
&mut self,
|
||||
expr: ExpressionNode<'ast>,
|
||||
constants_map: &HashMap<ConstantIdentifier<'ast>, Type<'ast, T>>,
|
||||
constants_map: &HashMap<&'ast str, Type<'ast, T>>,
|
||||
generics_map: &HashMap<Identifier<'ast>, usize>,
|
||||
) -> Result<Constant<'ast>, ErrorInner> {
|
||||
) -> Result<DeclarationConstant<'ast>, ErrorInner> {
|
||||
let pos = expr.pos();
|
||||
|
||||
match expr.value {
|
||||
Expression::U32Constant(c) => Ok(Constant::Concrete(c)),
|
||||
Expression::U32Constant(c) => Ok(DeclarationConstant::Concrete(c)),
|
||||
Expression::IntConstant(c) => {
|
||||
if c <= BigUint::from(2u128.pow(32) - 1) {
|
||||
Ok(Constant::Concrete(
|
||||
Ok(DeclarationConstant::Concrete(
|
||||
u32::from_str_radix(&c.to_str_radix(16), 16).unwrap(),
|
||||
))
|
||||
} else {
|
||||
|
@ -1138,7 +1144,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
match (constants_map.get(name), generics_map.get(&name)) {
|
||||
(Some(ty), None) => {
|
||||
match ty {
|
||||
Type::Uint(UBitwidth::B32) => Ok(Constant::Identifier(name, 32usize)),
|
||||
Type::Uint(UBitwidth::B32) => Ok(DeclarationConstant::Identifier(name)),
|
||||
_ => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
|
@ -1148,7 +1154,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
})
|
||||
}
|
||||
}
|
||||
(None, Some(index)) => Ok(Constant::Generic(GenericIdentifier { name, index: *index })),
|
||||
(None, Some(index)) => Ok(DeclarationConstant::Generic(GenericIdentifier { name, index: *index })),
|
||||
_ => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
message: format!("Undeclared symbol `{}` in function definition", name)
|
||||
|
|
|
@ -1,43 +1,41 @@
|
|||
use crate::static_analysis::propagation::Propagator;
|
||||
use crate::typed_absy::folder::*;
|
||||
use crate::typed_absy::result_folder::ResultFolder;
|
||||
use crate::typed_absy::types::{Constant, DeclarationStructType, GStructMember};
|
||||
use crate::typed_absy::types::DeclarationConstant;
|
||||
use crate::typed_absy::*;
|
||||
use core::str;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryInto;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub struct ConstantInliner<'ast, 'a, T: Field> {
|
||||
type ModuleConstants<'ast, T> =
|
||||
HashMap<OwnedTypedModuleId, HashMap<&'ast str, TypedConstant<'ast, T>>>;
|
||||
|
||||
pub struct ConstantInliner<'ast, T> {
|
||||
modules: TypedModules<'ast, T>,
|
||||
location: OwnedTypedModuleId,
|
||||
propagator: Propagator<'ast, 'a, T>,
|
||||
constants: ModuleConstants<'ast, T>,
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> ConstantInliner<'ast, 'a, T> {
|
||||
impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
|
||||
pub fn new(
|
||||
modules: TypedModules<'ast, T>,
|
||||
location: OwnedTypedModuleId,
|
||||
propagator: Propagator<'ast, 'a, T>,
|
||||
constants: ModuleConstants<'ast, T>,
|
||||
) -> Self {
|
||||
ConstantInliner {
|
||||
modules,
|
||||
location,
|
||||
propagator,
|
||||
constants,
|
||||
}
|
||||
}
|
||||
pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
let mut constants = HashMap::new();
|
||||
let mut inliner = ConstantInliner::new(
|
||||
p.modules.clone(),
|
||||
p.main.clone(),
|
||||
Propagator::with_constants(&mut constants),
|
||||
);
|
||||
let constants = HashMap::new();
|
||||
let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone(), constants);
|
||||
inliner.fold_program(p)
|
||||
}
|
||||
|
||||
fn module(&self) -> &TypedModule<'ast, T> {
|
||||
self.modules.get(&self.location).unwrap()
|
||||
}
|
||||
// fn module(&self) -> &TypedModule<'ast, T> {
|
||||
// self.modules.get(&self.location).unwrap()
|
||||
// }
|
||||
|
||||
fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
|
||||
let prev = self.location.clone();
|
||||
|
@ -46,116 +44,119 @@ impl<'ast, 'a, T: Field> ConstantInliner<'ast, 'a, T> {
|
|||
}
|
||||
|
||||
fn get_constant(&mut self, id: &Identifier) -> Option<TypedConstant<'ast, T>> {
|
||||
self.modules
|
||||
.get(&self.location)
|
||||
.unwrap()
|
||||
assert_eq!(id.version, 0);
|
||||
match id.id {
|
||||
CoreIdentifier::Call(..) => {
|
||||
unreachable!("calls indentifiers are only available after call inlining")
|
||||
}
|
||||
CoreIdentifier::Source(id) => self
|
||||
.constants
|
||||
.get(id.clone().try_into().unwrap())
|
||||
.cloned()
|
||||
.map(|symbol| self.get_canonical_constant(symbol))
|
||||
}
|
||||
|
||||
fn get_canonical_constant(
|
||||
&mut self,
|
||||
symbol: TypedConstantSymbol<'ast, T>,
|
||||
) -> TypedConstant<'ast, T> {
|
||||
match symbol {
|
||||
TypedConstantSymbol::There(module_id, id) => {
|
||||
let location = self.change_location(module_id);
|
||||
let symbol = self.module().constants.get(id).cloned().unwrap();
|
||||
|
||||
let symbol = self.get_canonical_constant(symbol);
|
||||
let _ = self.change_location(location);
|
||||
symbol
|
||||
}
|
||||
TypedConstantSymbol::Here(tc) => {
|
||||
let tc: TypedConstant<T> = self.fold_constant(tc);
|
||||
TypedConstant {
|
||||
expression: self.propagator.fold_expression(tc.expression).unwrap(),
|
||||
..tc
|
||||
}
|
||||
}
|
||||
.get(&self.location)
|
||||
.and_then(|constants| constants.get(id))
|
||||
.cloned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> {
|
||||
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
|
||||
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
TypedProgram {
|
||||
modules: p
|
||||
.modules
|
||||
.into_iter()
|
||||
.map(|(module_id, module)| {
|
||||
self.change_location(module_id.clone());
|
||||
(module_id, self.fold_module(module))
|
||||
.map(|(m_id, m)| {
|
||||
self.change_location(m_id.clone());
|
||||
(m_id, self.fold_module(m))
|
||||
})
|
||||
.collect(),
|
||||
main: p.main,
|
||||
..p
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_declaration_type(&mut self, t: DeclarationType<'ast>) -> DeclarationType<'ast> {
|
||||
match t {
|
||||
DeclarationType::Array(ref array_ty) => match array_ty.size {
|
||||
Constant::Identifier(name, _) => {
|
||||
let tc = self.get_constant(&name.into()).unwrap();
|
||||
let expression: UExpression<'ast, T> = tc.expression.try_into().unwrap();
|
||||
match expression.inner {
|
||||
UExpressionInner::Value(v) => DeclarationType::array((
|
||||
self.fold_declaration_type(*array_ty.ty.clone()),
|
||||
Constant::Concrete(v as u32),
|
||||
)),
|
||||
_ => unreachable!("expected u32 value"),
|
||||
}
|
||||
}
|
||||
_ => t,
|
||||
},
|
||||
DeclarationType::Struct(struct_ty) => DeclarationType::struc(DeclarationStructType {
|
||||
members: struct_ty
|
||||
.members
|
||||
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
|
||||
// only treat this module if its constants are not in the map yet
|
||||
if !self.constants.contains_key(&self.location) {
|
||||
self.constants.entry(self.location.clone()).or_default();
|
||||
TypedModule {
|
||||
constants: m
|
||||
.constants
|
||||
.into_iter()
|
||||
.map(|m| GStructMember::new(m.id, self.fold_declaration_type(*m.ty)))
|
||||
.collect(),
|
||||
..struct_ty
|
||||
}),
|
||||
_ => t,
|
||||
.map(|(id, tc)| {
|
||||
let constant = match tc {
|
||||
TypedConstantSymbol::There(imported_id) => {
|
||||
if !self.constants.contains_key(&imported_id.module) {
|
||||
let current_m_id = self.change_location(id.module.clone());
|
||||
let _ = self
|
||||
.fold_module(self.modules.get(&id.module).unwrap().clone());
|
||||
self.change_location(current_m_id);
|
||||
}
|
||||
self.constants
|
||||
.get(&imported_id.module)
|
||||
.unwrap()
|
||||
.get(&imported_id.id)
|
||||
.cloned()
|
||||
.unwrap()
|
||||
}
|
||||
TypedConstantSymbol::Here(c) => fold_constant(self, c),
|
||||
};
|
||||
|
||||
fn fold_type(&mut self, t: Type<'ast, T>) -> Type<'ast, T> {
|
||||
use self::GType::*;
|
||||
match t {
|
||||
Array(ref array_type) => match &array_type.size.inner {
|
||||
UExpressionInner::Identifier(v) => match self.get_constant(v) {
|
||||
Some(tc) => {
|
||||
let expression: UExpression<'ast, T> = tc.expression.try_into().unwrap();
|
||||
Type::array(GArrayType::new(
|
||||
self.fold_type(*array_type.ty.clone()),
|
||||
expression,
|
||||
))
|
||||
}
|
||||
None => t,
|
||||
},
|
||||
_ => t,
|
||||
},
|
||||
Struct(struct_type) => Type::struc(GStructType {
|
||||
members: struct_type
|
||||
.members
|
||||
assert!(self
|
||||
.constants
|
||||
.entry(self.location.clone())
|
||||
.or_default()
|
||||
.insert(id.id, constant.clone())
|
||||
.is_none());
|
||||
|
||||
(id, TypedConstantSymbol::Here(constant))
|
||||
})
|
||||
.collect(),
|
||||
functions: m
|
||||
.functions
|
||||
.into_iter()
|
||||
.map(|m| GStructMember::new(m.id, self.fold_type(*m.ty)))
|
||||
.map(|(key, fun)| {
|
||||
(
|
||||
self.fold_declaration_function_key(key),
|
||||
self.fold_function_symbol(fun),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
..struct_type
|
||||
}),
|
||||
_ => t,
|
||||
}
|
||||
} else {
|
||||
m
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_constant_symbol(
|
||||
fn fold_declaration_constant(
|
||||
&mut self,
|
||||
s: TypedConstantSymbol<'ast, T>,
|
||||
) -> TypedConstantSymbol<'ast, T> {
|
||||
let tc = self.get_canonical_constant(s);
|
||||
TypedConstantSymbol::Here(tc)
|
||||
c: DeclarationConstant<'ast>,
|
||||
) -> DeclarationConstant<'ast> {
|
||||
println!("id {}", c);
|
||||
println!("constants {:#?}", self.constants);
|
||||
println!("location {}", self.location.display());
|
||||
|
||||
match c {
|
||||
DeclarationConstant::Identifier(id) => DeclarationConstant::Concrete(
|
||||
match self
|
||||
.constants
|
||||
.get(&self.location)
|
||||
.unwrap()
|
||||
.get(&id)
|
||||
.cloned()
|
||||
.unwrap()
|
||||
{
|
||||
TypedConstant {
|
||||
ty: Type::Uint(UBitwidth::B32),
|
||||
expression:
|
||||
TypedExpression::Uint(UExpression {
|
||||
inner: UExpressionInner::Value(v),
|
||||
..
|
||||
}),
|
||||
} => v as u32,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
),
|
||||
c => c,
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_field_expression(
|
||||
|
|
|
@ -78,6 +78,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
|||
pub fn analyse(self, config: &CompileConfig) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
|
||||
// inline user-defined constants
|
||||
let r = ConstantInliner::inline(self);
|
||||
println!("{}", r);
|
||||
// isolate branches
|
||||
let r = if config.isolate_branches {
|
||||
Isolator::isolate(r)
|
||||
|
|
|
@ -614,7 +614,7 @@ fn compute_hash<T: Field>(f: &TypedFunction<T>) -> u64 {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::typed_absy::types::Constant;
|
||||
use crate::typed_absy::types::DeclarationConstant;
|
||||
use crate::typed_absy::types::DeclarationSignature;
|
||||
use crate::typed_absy::{
|
||||
ArrayExpression, ArrayExpressionInner, DeclarationFunctionKey, DeclarationType,
|
||||
|
@ -834,11 +834,11 @@ mod tests {
|
|||
)])
|
||||
.inputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
))])
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
))]);
|
||||
|
||||
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
||||
|
@ -1053,11 +1053,11 @@ mod tests {
|
|||
)])
|
||||
.inputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
))])
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
))]);
|
||||
|
||||
let foo: TypedFunction<Bn128Field> = TypedFunction {
|
||||
|
@ -1285,11 +1285,11 @@ mod tests {
|
|||
let foo_signature = DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
))])
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
))])
|
||||
.generics(vec![Some(
|
||||
GenericIdentifier::with_name("K").index(0).into(),
|
||||
|
@ -1299,7 +1299,7 @@ mod tests {
|
|||
arguments: vec![DeclarationVariable::array(
|
||||
"a",
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
)
|
||||
.into()],
|
||||
statements: vec![
|
||||
|
@ -1363,7 +1363,7 @@ mod tests {
|
|||
arguments: vec![DeclarationVariable::array(
|
||||
"a",
|
||||
DeclarationType::FieldElement,
|
||||
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
|
||||
)
|
||||
.into()],
|
||||
statements: vec![TypedStatement::Return(vec![
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Generic walk through a typed AST. Not mutating in place
|
||||
|
||||
use crate::typed_absy::types::{ArrayType, StructMember, StructType};
|
||||
use crate::typed_absy::types::*;
|
||||
use crate::typed_absy::*;
|
||||
use zokrates_field::Field;
|
||||
|
||||
|
@ -80,6 +80,13 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
fold_signature(self, s)
|
||||
}
|
||||
|
||||
fn fold_declaration_constant(
|
||||
&mut self,
|
||||
c: DeclarationConstant<'ast>,
|
||||
) -> DeclarationConstant<'ast> {
|
||||
fold_declaration_constant(self, c)
|
||||
}
|
||||
|
||||
fn fold_parameter(&mut self, p: DeclarationParameter<'ast>) -> DeclarationParameter<'ast> {
|
||||
DeclarationParameter {
|
||||
id: self.fold_declaration_variable(p.id),
|
||||
|
@ -144,7 +151,40 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
}
|
||||
|
||||
fn fold_declaration_type(&mut self, t: DeclarationType<'ast>) -> DeclarationType<'ast> {
|
||||
t
|
||||
use self::GType::*;
|
||||
|
||||
match t {
|
||||
Array(array_type) => Array(self.fold_declaration_array_type(array_type)),
|
||||
Struct(struct_type) => Struct(self.fold_declaration_struct_type(struct_type)),
|
||||
t => t,
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_declaration_array_type(
|
||||
&mut self,
|
||||
t: DeclarationArrayType<'ast>,
|
||||
) -> DeclarationArrayType<'ast> {
|
||||
DeclarationArrayType {
|
||||
ty: box self.fold_declaration_type(*t.ty),
|
||||
size: self.fold_declaration_constant(t.size),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_declaration_struct_type(
|
||||
&mut self,
|
||||
t: DeclarationStructType<'ast>,
|
||||
) -> DeclarationStructType<'ast> {
|
||||
DeclarationStructType {
|
||||
members: t
|
||||
.members
|
||||
.into_iter()
|
||||
.map(|m| DeclarationStructMember {
|
||||
ty: box self.fold_declaration_type(*m.ty),
|
||||
..m
|
||||
})
|
||||
.collect(),
|
||||
..t
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
|
||||
|
@ -880,6 +920,13 @@ fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
}
|
||||
}
|
||||
|
||||
fn fold_declaration_constant<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
_: &mut F,
|
||||
c: DeclarationConstant<'ast>,
|
||||
) -> DeclarationConstant<'ast> {
|
||||
c
|
||||
}
|
||||
|
||||
pub fn fold_array_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: ArrayExpression<'ast, T>,
|
||||
|
|
|
@ -19,9 +19,9 @@ mod variable;
|
|||
pub use self::identifier::CoreIdentifier;
|
||||
pub use self::parameter::{DeclarationParameter, GParameter};
|
||||
pub use self::types::{
|
||||
ConcreteFunctionKey, ConcreteSignature, ConcreteType, DeclarationFunctionKey,
|
||||
DeclarationSignature, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier,
|
||||
IntoTypes, Signature, StructType, Type, Types, UBitwidth,
|
||||
ConcreteFunctionKey, ConcreteSignature, ConcreteType, ConstantIdentifier,
|
||||
DeclarationFunctionKey, DeclarationSignature, DeclarationType, GArrayType, GStructType, GType,
|
||||
GenericIdentifier, IntoTypes, Signature, StructType, Type, Types, UBitwidth,
|
||||
};
|
||||
use crate::typed_absy::types::ConcreteGenericsAssignment;
|
||||
|
||||
|
@ -62,12 +62,10 @@ pub type TypedModules<'ast, T> = HashMap<OwnedTypedModuleId, TypedModule<'ast, T
|
|||
pub type TypedFunctionSymbols<'ast, T> =
|
||||
HashMap<DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>>;
|
||||
|
||||
pub type ConstantIdentifier<'ast> = &'ast str;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum TypedConstantSymbol<'ast, T> {
|
||||
Here(TypedConstant<'ast, T>),
|
||||
There(OwnedTypedModuleId, ConstantIdentifier<'ast>),
|
||||
There(ConstantIdentifier<'ast>),
|
||||
}
|
||||
|
||||
/// A collection of `TypedConstantSymbol`s
|
||||
|
@ -188,12 +186,17 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> {
|
|||
let res = self
|
||||
.constants
|
||||
.iter()
|
||||
.map(|(key, symbol)| match symbol {
|
||||
.map(|(id, symbol)| match symbol {
|
||||
TypedConstantSymbol::Here(ref tc) => {
|
||||
format!("const {} {} = {}", tc.ty, key, tc.expression)
|
||||
format!("const {} {} = {}", tc.ty, id.id, tc.expression)
|
||||
}
|
||||
TypedConstantSymbol::There(ref module_id, ref id) => {
|
||||
format!("from \"{}\" import {} as {}", module_id.display(), id, key)
|
||||
TypedConstantSymbol::There(ref imported_id) => {
|
||||
format!(
|
||||
"from \"{}\" import {} as {}",
|
||||
imported_id.module.display(),
|
||||
imported_id.id,
|
||||
id.id
|
||||
)
|
||||
}
|
||||
})
|
||||
.chain(self.functions.iter().map(|(key, symbol)| match symbol {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::typed_absy::types::Constant;
|
||||
use crate::typed_absy::types::DeclarationConstant;
|
||||
use crate::typed_absy::GVariable;
|
||||
use std::fmt;
|
||||
|
||||
|
@ -18,7 +18,7 @@ impl<'ast, S> From<GVariable<'ast, S>> for GParameter<'ast, S> {
|
|||
}
|
||||
}
|
||||
|
||||
pub type DeclarationParameter<'ast> = GParameter<'ast, Constant<'ast>>;
|
||||
pub type DeclarationParameter<'ast> = GParameter<'ast, DeclarationConstant<'ast>>;
|
||||
|
||||
impl<'ast, S: fmt::Display + Clone> fmt::Display for GParameter<'ast, S> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Generic walk through a typed AST. Not mutating in place
|
||||
|
||||
use crate::typed_absy::types::{ArrayType, StructMember, StructType};
|
||||
use crate::typed_absy::types::*;
|
||||
use crate::typed_absy::*;
|
||||
use zokrates_field::Field;
|
||||
|
||||
|
@ -97,6 +97,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
fold_signature(self, s)
|
||||
}
|
||||
|
||||
fn fold_declaration_constant(
|
||||
&mut self,
|
||||
c: DeclarationConstant<'ast>,
|
||||
) -> Result<DeclarationConstant<'ast>, Self::Error> {
|
||||
fold_declaration_constant(self, c)
|
||||
}
|
||||
|
||||
fn fold_parameter(
|
||||
&mut self,
|
||||
p: DeclarationParameter<'ast>,
|
||||
|
@ -214,6 +221,34 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
Ok(t)
|
||||
}
|
||||
|
||||
fn fold_declaration_array_type(
|
||||
&mut self,
|
||||
t: DeclarationArrayType<'ast>,
|
||||
) -> Result<DeclarationArrayType<'ast>, Self::Error> {
|
||||
Ok(DeclarationArrayType {
|
||||
ty: box self.fold_declaration_type(*t.ty)?,
|
||||
size: self.fold_declaration_constant(t.size)?,
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_declaration_struct_type(
|
||||
&mut self,
|
||||
t: DeclarationStructType<'ast>,
|
||||
) -> Result<DeclarationStructType<'ast>, Self::Error> {
|
||||
Ok(DeclarationStructType {
|
||||
members: t
|
||||
.members
|
||||
.into_iter()
|
||||
.map(|m| {
|
||||
let id = m.id;
|
||||
self.fold_declaration_type(*m.ty)
|
||||
.map(|ty| DeclarationStructMember { ty: box ty, id })
|
||||
})
|
||||
.collect::<Result<_, _>>()?,
|
||||
..t
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_assignee(
|
||||
&mut self,
|
||||
a: TypedAssignee<'ast, T>,
|
||||
|
@ -934,6 +969,13 @@ fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
})
|
||||
}
|
||||
|
||||
fn fold_declaration_constant<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
_: &mut F,
|
||||
c: DeclarationConstant<'ast>,
|
||||
) -> Result<DeclarationConstant<'ast>, F::Error> {
|
||||
Ok(c)
|
||||
}
|
||||
|
||||
pub fn fold_array_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: ArrayExpression<'ast, T>,
|
||||
|
|
|
@ -101,37 +101,49 @@ impl<'ast> fmt::Display for GenericIdentifier<'ast> {
|
|||
#[derive(Debug)]
|
||||
pub struct SpecializationError;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord)]
|
||||
pub struct ConstantIdentifier<'ast> {
|
||||
pub module: OwnedTypedModuleId,
|
||||
pub id: &'ast str,
|
||||
}
|
||||
|
||||
impl<'ast> ConstantIdentifier<'ast> {
|
||||
pub fn new(id: &'ast str, module: OwnedTypedModuleId) -> Self {
|
||||
ConstantIdentifier { id, module }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub enum Constant<'ast> {
|
||||
pub enum DeclarationConstant<'ast> {
|
||||
Generic(GenericIdentifier<'ast>),
|
||||
Concrete(u32),
|
||||
Identifier(&'ast str, usize),
|
||||
Identifier(&'ast str),
|
||||
}
|
||||
|
||||
impl<'ast> From<u32> for Constant<'ast> {
|
||||
impl<'ast> From<u32> for DeclarationConstant<'ast> {
|
||||
fn from(e: u32) -> Self {
|
||||
Constant::Concrete(e)
|
||||
DeclarationConstant::Concrete(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<usize> for Constant<'ast> {
|
||||
impl<'ast> From<usize> for DeclarationConstant<'ast> {
|
||||
fn from(e: usize) -> Self {
|
||||
Constant::Concrete(e as u32)
|
||||
DeclarationConstant::Concrete(e as u32)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<GenericIdentifier<'ast>> for Constant<'ast> {
|
||||
impl<'ast> From<GenericIdentifier<'ast>> for DeclarationConstant<'ast> {
|
||||
fn from(e: GenericIdentifier<'ast>) -> Self {
|
||||
Constant::Generic(e)
|
||||
DeclarationConstant::Generic(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for Constant<'ast> {
|
||||
impl<'ast> fmt::Display for DeclarationConstant<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Constant::Generic(i) => write!(f, "{}", i),
|
||||
Constant::Concrete(v) => write!(f, "{}", v),
|
||||
Constant::Identifier(v, _) => write!(f, "{}", v),
|
||||
DeclarationConstant::Generic(i) => write!(f, "{}", i),
|
||||
DeclarationConstant::Concrete(v) => write!(f, "{}", v),
|
||||
DeclarationConstant::Identifier(v) => write!(f, "{}", v),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -142,16 +154,17 @@ impl<'ast, T> From<usize> for UExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T> From<Constant<'ast>> for UExpression<'ast, T> {
|
||||
fn from(c: Constant<'ast>) -> Self {
|
||||
impl<'ast, T> From<DeclarationConstant<'ast>> for UExpression<'ast, T> {
|
||||
fn from(c: DeclarationConstant<'ast>) -> Self {
|
||||
match c {
|
||||
Constant::Generic(i) => {
|
||||
DeclarationConstant::Generic(i) => {
|
||||
UExpressionInner::Identifier(i.name.into()).annotate(UBitwidth::B32)
|
||||
}
|
||||
Constant::Concrete(v) => UExpressionInner::Value(v as u128).annotate(UBitwidth::B32),
|
||||
Constant::Identifier(v, size) => {
|
||||
UExpressionInner::Identifier(Identifier::from(v)).annotate(UBitwidth::from(size))
|
||||
DeclarationConstant::Concrete(v) => {
|
||||
UExpressionInner::Value(v as u128).annotate(UBitwidth::B32)
|
||||
}
|
||||
DeclarationConstant::Identifier(v) => UExpressionInner::Identifier(Identifier::from(v))
|
||||
.annotate(UBitwidth::from(UBitwidth::B32)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -169,12 +182,12 @@ impl<'ast, T> TryInto<usize> for UExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> TryInto<usize> for Constant<'ast> {
|
||||
impl<'ast> TryInto<usize> for DeclarationConstant<'ast> {
|
||||
type Error = SpecializationError;
|
||||
|
||||
fn try_into(self) -> Result<usize, Self::Error> {
|
||||
match self {
|
||||
Constant::Concrete(v) => Ok(v as usize),
|
||||
DeclarationConstant::Concrete(v) => Ok(v as usize),
|
||||
_ => Err(SpecializationError),
|
||||
}
|
||||
}
|
||||
|
@ -190,7 +203,7 @@ pub struct GStructMember<S> {
|
|||
pub ty: Box<GType<S>>,
|
||||
}
|
||||
|
||||
pub type DeclarationStructMember<'ast> = GStructMember<Constant<'ast>>;
|
||||
pub type DeclarationStructMember<'ast> = GStructMember<DeclarationConstant<'ast>>;
|
||||
pub type ConcreteStructMember = GStructMember<usize>;
|
||||
pub type StructMember<'ast, T> = GStructMember<UExpression<'ast, T>>;
|
||||
|
||||
|
@ -242,7 +255,7 @@ pub struct GArrayType<S> {
|
|||
pub ty: Box<GType<S>>,
|
||||
}
|
||||
|
||||
pub type DeclarationArrayType<'ast> = GArrayType<Constant<'ast>>;
|
||||
pub type DeclarationArrayType<'ast> = GArrayType<DeclarationConstant<'ast>>;
|
||||
pub type ConcreteArrayType = GArrayType<usize>;
|
||||
pub type ArrayType<'ast, T> = GArrayType<UExpression<'ast, T>>;
|
||||
|
||||
|
@ -250,7 +263,7 @@ impl<'ast, T: PartialEq> PartialEq<DeclarationArrayType<'ast>> for ArrayType<'as
|
|||
fn eq(&self, other: &DeclarationArrayType<'ast>) -> bool {
|
||||
*self.ty == *other.ty
|
||||
&& match (self.size.as_inner(), &other.size) {
|
||||
(UExpressionInner::Value(l), Constant::Concrete(r)) => *l as u32 == *r,
|
||||
(UExpressionInner::Value(l), DeclarationConstant::Concrete(r)) => *l as u32 == *r,
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
@ -349,7 +362,7 @@ pub struct GStructType<S> {
|
|||
pub members: Vec<GStructMember<S>>,
|
||||
}
|
||||
|
||||
pub type DeclarationStructType<'ast> = GStructType<Constant<'ast>>;
|
||||
pub type DeclarationStructType<'ast> = GStructType<DeclarationConstant<'ast>>;
|
||||
pub type ConcreteStructType = GStructType<usize>;
|
||||
pub type StructType<'ast, T> = GStructType<UExpression<'ast, T>>;
|
||||
|
||||
|
@ -588,7 +601,7 @@ impl<'de, S: Deserialize<'de>> Deserialize<'de> for GType<S> {
|
|||
}
|
||||
}
|
||||
|
||||
pub type DeclarationType<'ast> = GType<Constant<'ast>>;
|
||||
pub type DeclarationType<'ast> = GType<DeclarationConstant<'ast>>;
|
||||
pub type ConcreteType = GType<usize>;
|
||||
pub type Type<'ast, T> = GType<UExpression<'ast, T>>;
|
||||
|
||||
|
@ -711,7 +724,7 @@ impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> {
|
|||
// check the size if types match
|
||||
match (&l.size.as_inner(), &r.size) {
|
||||
// compare the sizes for concrete ones
|
||||
(UExpressionInner::Value(v), Constant::Concrete(c)) => {
|
||||
(UExpressionInner::Value(v), DeclarationConstant::Concrete(c)) => {
|
||||
(*v as u32) == *c
|
||||
}
|
||||
_ => true,
|
||||
|
@ -772,7 +785,7 @@ pub struct GFunctionKey<'ast, S> {
|
|||
pub signature: GSignature<S>,
|
||||
}
|
||||
|
||||
pub type DeclarationFunctionKey<'ast> = GFunctionKey<'ast, Constant<'ast>>;
|
||||
pub type DeclarationFunctionKey<'ast> = GFunctionKey<'ast, DeclarationConstant<'ast>>;
|
||||
pub type ConcreteFunctionKey<'ast> = GFunctionKey<'ast, usize>;
|
||||
pub type FunctionKey<'ast, T> = GFunctionKey<'ast, UExpression<'ast, T>>;
|
||||
|
||||
|
@ -948,7 +961,7 @@ pub mod signature {
|
|||
}
|
||||
}
|
||||
|
||||
pub type DeclarationSignature<'ast> = GSignature<Constant<'ast>>;
|
||||
pub type DeclarationSignature<'ast> = GSignature<DeclarationConstant<'ast>>;
|
||||
pub type ConcreteSignature = GSignature<usize>;
|
||||
pub type Signature<'ast, T> = GSignature<UExpression<'ast, T>>;
|
||||
|
||||
|
@ -968,15 +981,17 @@ pub mod signature {
|
|||
&& match &t0.size {
|
||||
// if the declared size is an identifier, we insert into the map, or check if the concrete size
|
||||
// matches if this identifier is already in the map
|
||||
Constant::Generic(id) => match constants.0.entry(id.clone()) {
|
||||
DeclarationConstant::Generic(id) => match constants.0.entry(id.clone()) {
|
||||
Entry::Occupied(e) => *e.get() == s1,
|
||||
Entry::Vacant(e) => {
|
||||
e.insert(s1);
|
||||
true
|
||||
}
|
||||
},
|
||||
Constant::Concrete(s0) => s1 == *s0 as usize,
|
||||
Constant::Identifier(_, s0) => s1 == *s0,
|
||||
DeclarationConstant::Concrete(s0) => s1 == *s0 as usize,
|
||||
// in the case of a constant, we do not know the value yet, so we optimistically assume it's correct
|
||||
// if it does not match, it will be caught during inlining
|
||||
DeclarationConstant::Identifier(..) => true,
|
||||
}
|
||||
}
|
||||
(DeclarationType::FieldElement, GType::FieldElement)
|
||||
|
@ -1000,9 +1015,11 @@ pub mod signature {
|
|||
|
||||
let ty = box specialize_type(*t0.ty, &constants)?;
|
||||
let size = match t0.size {
|
||||
Constant::Generic(s) => constants.0.get(&s).cloned().ok_or(s),
|
||||
Constant::Concrete(s) => Ok(s.into()),
|
||||
Constant::Identifier(_, s) => Ok((s as u32).into()),
|
||||
DeclarationConstant::Generic(s) => constants.0.get(&s).cloned().ok_or(s),
|
||||
DeclarationConstant::Concrete(s) => Ok(s.into()),
|
||||
DeclarationConstant::Identifier(..) => {
|
||||
unreachable!("identifiers should have been removed in constant inlining")
|
||||
}
|
||||
}?;
|
||||
|
||||
GType::Array(GArrayType { size, ty })
|
||||
|
@ -1053,7 +1070,7 @@ pub mod signature {
|
|||
assert_eq!(self.generics.len(), values.len());
|
||||
|
||||
let decl_generics = self.generics.iter().map(|g| match g.clone().unwrap() {
|
||||
Constant::Generic(g) => g,
|
||||
DeclarationConstant::Generic(g) => g,
|
||||
_ => unreachable!(),
|
||||
});
|
||||
|
||||
|
@ -1096,7 +1113,7 @@ pub mod signature {
|
|||
v.map(|v| {
|
||||
(
|
||||
match g.clone().unwrap() {
|
||||
Constant::Generic(g) => g,
|
||||
DeclarationConstant::Generic(g) => g,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
v,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::typed_absy::types::{Constant, GStructType, UBitwidth};
|
||||
use crate::typed_absy::types::{DeclarationConstant, GStructType, UBitwidth};
|
||||
use crate::typed_absy::types::{GType, SpecializationError};
|
||||
use crate::typed_absy::Identifier;
|
||||
use crate::typed_absy::UExpression;
|
||||
|
@ -11,7 +11,7 @@ pub struct GVariable<'ast, S> {
|
|||
pub _type: GType<S>,
|
||||
}
|
||||
|
||||
pub type DeclarationVariable<'ast> = GVariable<'ast, Constant<'ast>>;
|
||||
pub type DeclarationVariable<'ast> = GVariable<'ast, DeclarationConstant<'ast>>;
|
||||
pub type ConcreteVariable<'ast> = GVariable<'ast, usize>;
|
||||
pub type Variable<'ast, T> = GVariable<'ast, UExpression<'ast, T>>;
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from "./origin.zok" import foo
|
||||
def main():
|
||||
assert(foo([1, 1]))
|
||||
return
|
|
@ -1,3 +1,3 @@
|
|||
const u32 N = 42
|
||||
const u32 N = 1
|
||||
def foo(field[N] a) -> bool:
|
||||
return true
|
Loading…
Reference in a new issue