diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index ab97df45..a56c06d3 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -213,10 +213,8 @@ fn check_with_arena<'ast, T: Field, E: Into>( CompileErrors(errors.into_iter().map(|e| CompileError::from(e)).collect()) })?; - let abi = typed_ast.abi(); - // analyse (unroll and constant propagation) - let typed_ast = typed_ast.analyse(); + let (typed_ast, abi) = typed_ast.analyse(); Ok((typed_ast, abi)) } diff --git a/zokrates_core/src/embed.rs b/zokrates_core/src/embed.rs index 48284e3d..91f68d33 100644 --- a/zokrates_core/src/embed.rs +++ b/zokrates_core/src/embed.rs @@ -4,7 +4,7 @@ use flat_absy::{ FlatVariable, }; use std::collections::HashMap; -use typed_absy::types::{FunctionKey, Signature, Type}; +use typed_absy::types::{ConcreteFunctionKey, ConcreteSignature, ConcreteType}; use zokrates_field::Field; /// A low level function that contains non-deterministic introduction of variables. It is carried out as is until @@ -21,34 +21,34 @@ pub enum FlatEmbed { } impl FlatEmbed { - pub fn signature(&self) -> Signature { + pub fn signature(&self) -> ConcreteSignature { match self { - FlatEmbed::Unpack(bitwidth) => Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::array(Type::Boolean, *bitwidth)]), - FlatEmbed::U8ToBits => Signature::new() - .inputs(vec![Type::uint(8)]) - .outputs(vec![Type::array(Type::Boolean, 8)]), - FlatEmbed::U16ToBits => Signature::new() - .inputs(vec![Type::uint(16)]) - .outputs(vec![Type::array(Type::Boolean, 16)]), - FlatEmbed::U32ToBits => Signature::new() - .inputs(vec![Type::uint(32)]) - .outputs(vec![Type::array(Type::Boolean, 32)]), - FlatEmbed::U8FromBits => Signature::new() - .outputs(vec![Type::uint(8)]) - .inputs(vec![Type::array(Type::Boolean, 8)]), - FlatEmbed::U16FromBits => Signature::new() - .outputs(vec![Type::uint(16)]) - .inputs(vec![Type::array(Type::Boolean, 16)]), - FlatEmbed::U32FromBits => Signature::new() - .outputs(vec![Type::uint(32)]) - .inputs(vec![Type::array(Type::Boolean, 32)]), + FlatEmbed::Unpack(bitwidth) => ConcreteSignature::new() + .inputs(vec![ConcreteType::FieldElement]) + .outputs(vec![ConcreteType::array(ConcreteType::Boolean, *bitwidth)]), + FlatEmbed::U8ToBits => ConcreteSignature::new() + .inputs(vec![ConcreteType::uint(8)]) + .outputs(vec![ConcreteType::array(ConcreteType::Boolean, 8usize)]), + FlatEmbed::U16ToBits => ConcreteSignature::new() + .inputs(vec![ConcreteType::uint(16)]) + .outputs(vec![ConcreteType::array(ConcreteType::Boolean, 16usize)]), + FlatEmbed::U32ToBits => ConcreteSignature::new() + .inputs(vec![ConcreteType::uint(32)]) + .outputs(vec![ConcreteType::array(ConcreteType::Boolean, 32usize)]), + FlatEmbed::U8FromBits => ConcreteSignature::new() + .outputs(vec![ConcreteType::uint(8)]) + .inputs(vec![ConcreteType::array(ConcreteType::Boolean, 8usize)]), + FlatEmbed::U16FromBits => ConcreteSignature::new() + .outputs(vec![ConcreteType::uint(16)]) + .inputs(vec![ConcreteType::array(ConcreteType::Boolean, 16usize)]), + FlatEmbed::U32FromBits => ConcreteSignature::new() + .outputs(vec![ConcreteType::uint(32)]) + .inputs(vec![ConcreteType::array(ConcreteType::Boolean, 32usize)]), } } - pub fn key(&self) -> FunctionKey<'static> { - FunctionKey::with_id(self.id()).signature(self.signature()) + pub fn key(&self) -> ConcreteFunctionKey<'static> { + ConcreteFunctionKey::with_id(self.id()).signature(self.signature()) } pub fn id(&self) -> &'static str { diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 3e090bb9..8a037420 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -7,7 +7,7 @@ use crate::absy::Identifier; use crate::absy::*; use crate::typed_absy::*; -use crate::typed_absy::{Parameter, Variable}; +use crate::typed_absy::{DeclarationParameter, DeclarationVariable, Parameter, Variable}; use std::collections::{hash_map::Entry, BTreeSet, HashMap, HashSet}; use std::fmt; use std::path::PathBuf; @@ -19,7 +19,12 @@ use crate::absy::types::{UnresolvedSignature, UnresolvedType, UserTypeId}; use crate::typed_absy::types::{FunctionKey, Signature, Type}; use std::hash::{Hash, Hasher}; -use typed_absy::types::{ArrayType, StructMember}; +use typed_absy::types::{ + ArrayType, Constant, DeclarationArrayType, DeclarationFunctionKey, DeclarationSignature, + DeclarationStructMember, DeclarationStructType, DeclarationType, StructMember, +}; + +use std::convert::{TryFrom, TryInto}; #[derive(PartialEq, Debug)] pub struct ErrorInner { @@ -42,33 +47,33 @@ impl ErrorInner { } } -type TypeMap = HashMap>; +type TypeMap<'ast> = HashMap>>; /// The global state of the program during semantic checks #[derive(Debug)] -struct State<'ast, T: Field> { +struct State<'ast, T> { /// The modules yet to be checked, which we consume as we explore the dependency tree modules: Modules<'ast, T>, /// The already checked modules, which we're returning at the end typed_modules: TypedModules<'ast, T>, /// The user-defined types, which we keep track at this phase only. In later phases, we rely only on basic types and combinations thereof - types: TypeMap, + types: TypeMap<'ast>, } /// A symbol for a given name: either a type or a group of functions. Not both! #[derive(PartialEq, Hash, Eq, Debug)] -enum SymbolType { +enum SymbolType<'ast> { Type, - Functions(BTreeSet), + Functions(BTreeSet>), } /// A data structure to keep track of all symbols in a module #[derive(Default)] -struct SymbolUnifier { - symbols: HashMap, +struct SymbolUnifier<'ast> { + symbols: HashMap>, } -impl SymbolUnifier { +impl<'ast> SymbolUnifier<'ast> { fn insert_type>(&mut self, id: S) -> bool { let s_type = self.symbols.entry(id.into()); match s_type { @@ -82,7 +87,11 @@ impl SymbolUnifier { } } - fn insert_function>(&mut self, id: S, signature: Signature) -> bool { + fn insert_function>( + &mut self, + id: S, + signature: DeclarationSignature<'ast>, + ) -> bool { let s_type = self.symbols.entry(id.into()); match s_type { // if anything is already called `id`, it depends what it is @@ -124,14 +133,14 @@ impl fmt::Display for ErrorInner { } /// A function query in the current module. -struct FunctionQuery<'ast> { +struct FunctionQuery<'ast, T> { id: Identifier<'ast>, - inputs: Vec, + inputs: Vec>, /// Output types are optional as we try to infer them - outputs: Vec>, + outputs: Vec>>, } -impl<'ast> fmt::Display for FunctionQuery<'ast> { +impl<'ast, T> fmt::Display for FunctionQuery<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "(")?; for (i, t) in self.inputs.iter().enumerate() { @@ -169,13 +178,13 @@ impl<'ast> fmt::Display for FunctionQuery<'ast> { } } -impl<'ast> FunctionQuery<'ast> { +impl<'ast, T: Field> FunctionQuery<'ast, T> { /// Create a new query. fn new( id: Identifier<'ast>, - inputs: &Vec, - outputs: &Vec>, - ) -> FunctionQuery<'ast> { + inputs: &Vec>, + outputs: &Vec>>, + ) -> Self { FunctionQuery { id, inputs: inputs.clone(), @@ -184,52 +193,57 @@ impl<'ast> FunctionQuery<'ast> { } /// match a `FunctionKey` against this `FunctionQuery` - fn match_func(&self, func: &FunctionKey) -> bool { - self.id == func.id - && self.inputs == func.signature.inputs - && self.outputs.len() == func.signature.outputs.len() - && self.outputs.iter().enumerate().all(|(index, t)| match t { - Some(ref t) => t == &func.signature.outputs[index], - _ => true, - }) + fn match_func(&self, func: &FunctionKey<'ast, T>) -> bool { + // self.id == func.id + // && self.inputs == func.signature.inputs + // && self.outputs.len() == func.signature.outputs.len() + // && self.outputs.iter().enumerate().all(|(index, t)| match t { + // Some(ref t) => t == &func.signature.outputs[index], + // _ => true, + // }) + unimplemented!() } - fn match_funcs(&self, funcs: &HashSet>) -> Option> { - funcs.iter().find(|func| self.match_func(func)).cloned() + fn match_funcs( + &self, + funcs: &HashSet>, + ) -> Option> { + // funcs.iter().find(|func| self.match_func(func)).cloned() + unimplemented!() } } /// A scoped variable, so that we can delete all variables of a given scope when exiting it #[derive(Clone, Debug)] -pub struct ScopedVariable<'ast> { - id: Variable<'ast>, +pub struct ScopedVariable<'ast, T> { + id: Variable<'ast, T>, level: usize, } /// Identifiers of different `ScopedVariable`s should not conflict, so we define them as equivalent -impl<'ast> PartialEq for ScopedVariable<'ast> { - fn eq(&self, other: &ScopedVariable) -> bool { +impl<'ast, T> PartialEq for ScopedVariable<'ast, T> { + fn eq(&self, other: &Self) -> bool { self.id.id == other.id.id } } -impl<'ast> Hash for ScopedVariable<'ast> { +impl<'ast, T> Hash for ScopedVariable<'ast, T> { fn hash(&self, state: &mut H) { self.id.id.hash(state); } } -impl<'ast> Eq for ScopedVariable<'ast> {} +impl<'ast, T> Eq for ScopedVariable<'ast, T> {} /// Checker checks the semantics of a program, keeping track of functions and variables in scope -pub struct Checker<'ast> { - scope: HashSet>, - functions: HashSet>, +pub struct Checker<'ast, T> { + scope: HashSet>, + functions: HashSet>, level: usize, } -impl<'ast> Checker<'ast> { - fn new() -> Checker<'ast> { +impl<'ast, T: Field> Checker<'ast, T> { + fn new() -> Self { Checker { scope: HashSet::new(), functions: HashSet::new(), @@ -242,11 +256,11 @@ impl<'ast> Checker<'ast> { /// # Arguments /// /// * `prog` - The `Program` to be checked - pub fn check(prog: Program<'ast, T>) -> Result, Vec> { + pub fn check(prog: Program<'ast, T>) -> Result, Vec> { Checker::new().check_program(prog) } - fn check_program( + fn check_program( &mut self, program: Program<'ast, T>, ) -> Result, Vec> { @@ -281,13 +295,13 @@ impl<'ast> Checker<'ast> { }) } - fn check_struct_type_declaration( + fn check_struct_type_declaration( &mut self, id: String, s: StructDefinitionNode<'ast, T>, module_id: &ModuleId, - types: &TypeMap, - ) -> Result> { + types: &TypeMap<'ast>, + ) -> Result, Vec> { let pos = s.pos(); let s = s.value; @@ -298,7 +312,7 @@ impl<'ast> Checker<'ast> { for field in s.fields { let member_id = field.value.id.to_string(); match self - .check_type(field.value.ty, module_id, &types) + .check_declaration_type(field.value.ty, module_id, &types) .map(|t| (member_id, t)) { Ok(f) => match fields_set.insert(f.0.clone()) { @@ -318,23 +332,23 @@ impl<'ast> Checker<'ast> { return Err(errors); } - Ok(Type::Struct(StructType::new( + Ok(DeclarationType::Struct(DeclarationStructType::new( module_id.into(), id, fields .iter() - .map(|f| StructMember::new(f.0.clone(), f.1.clone())) + .map(|f| DeclarationStructMember::new(f.0.clone(), f.1.clone())) .collect(), ))) } - fn check_symbol_declaration( + fn check_symbol_declaration( &mut self, declaration: SymbolDeclarationNode<'ast, T>, module_id: &ModuleId, state: &mut State<'ast, T>, - functions: &mut HashMap, TypedFunctionSymbol<'ast, T>>, - symbol_unifier: &mut SymbolUnifier, + functions: &mut HashMap, TypedFunctionSymbol<'ast, T>>, + symbol_unifier: &mut SymbolUnifier<'ast>, ) -> Result<(), Vec> { let mut errors: Vec = vec![]; @@ -392,11 +406,11 @@ impl<'ast> Checker<'ast> { }; self.functions.insert( - FunctionKey::with_id(declaration.id.clone()) + DeclarationFunctionKey::with_id(declaration.id.clone()) .signature(funct.signature.clone()), ); functions.insert( - FunctionKey::with_id(declaration.id.clone()) + DeclarationFunctionKey::with_id(declaration.id.clone()) .signature(funct.signature.clone()), TypedFunctionSymbol::Here(funct), ); @@ -419,7 +433,7 @@ impl<'ast> Checker<'ast> { .functions .iter() .filter(|(k, _)| k.id == import.symbol_id) - .map(|(_, v)| FunctionKey { + .map(|(_, v)| DeclarationFunctionKey { id: import.symbol_id.clone(), signature: v.signature(&state.typed_modules).clone(), }) @@ -438,7 +452,7 @@ impl<'ast> Checker<'ast> { // rename the type to the declared symbol let t = match t { - Type::Struct(t) => Type::Struct(StructType { + DeclarationType::Struct(t) => DeclarationType::Struct(DeclarationStructType { module: module_id.clone(), name: declaration.id.into(), ..t @@ -511,7 +525,9 @@ impl<'ast> Checker<'ast> { }; } Symbol::Flat(funct) => { - match symbol_unifier.insert_function(declaration.id, funct.signature()) { + match symbol_unifier + .insert_function(declaration.id, funct.signature().try_into().unwrap()) + { false => { errors.push( ErrorInner { @@ -528,12 +544,12 @@ impl<'ast> Checker<'ast> { }; self.functions.insert( - FunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature().clone()), + DeclarationFunctionKey::with_id(declaration.id.clone()) + .signature(funct.signature().clone().try_into().unwrap()), ); functions.insert( - FunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature().clone()), + DeclarationFunctionKey::with_id(declaration.id.clone()) + .signature(funct.signature().clone().try_into().unwrap()), TypedFunctionSymbol::Flat(funct), ); } @@ -547,7 +563,7 @@ impl<'ast> Checker<'ast> { Ok(()) } - fn check_module( + fn check_module( &mut self, module_id: &ModuleId, state: &mut State<'ast, T>, @@ -611,7 +627,7 @@ impl<'ast> Checker<'ast> { Ok(()) } - fn check_single_main(module: &TypedModule) -> Result<(), ErrorInner> { + fn check_single_main(module: &TypedModule) -> Result<(), ErrorInner> { match module .functions .iter() @@ -630,7 +646,7 @@ impl<'ast> Checker<'ast> { } } - fn check_for_var(&self, var: &VariableNode<'ast, T>) -> Result<(), ErrorInner> { + fn check_for_var(&self, var: &VariableNode<'ast, T>) -> Result<(), ErrorInner> { match var.value.get_type() { UnresolvedType::Uint(32) => Ok(()), t => Err(ErrorInner { @@ -640,11 +656,11 @@ impl<'ast> Checker<'ast> { } } - fn check_function( + fn check_function( &mut self, funct_node: FunctionNode<'ast, T>, module_id: &ModuleId, - types: &TypeMap, + types: &TypeMap<'ast>, ) -> Result, Vec> { self.enter_scope(); @@ -704,6 +720,10 @@ impl<'ast> Checker<'ast> { TypedStatement::Return(e) => { match e.iter().map(|e| e.get_type()).collect::>() == s.outputs + .clone() + .into_iter() + .map(|o| o.into()) + .collect::>() { true => {} false => errors.push(ErrorInner { @@ -753,32 +773,34 @@ impl<'ast> Checker<'ast> { }) } - fn check_parameter( + fn check_parameter( &mut self, p: ParameterNode<'ast, T>, module_id: &ModuleId, - types: &TypeMap, - ) -> Result, Vec> { - let var = self.check_variable(p.value.id, module_id, types)?; + types: &TypeMap<'ast>, + ) -> Result, Vec> { + let var = self.check_declaration_variable(p.value.id, module_id, types)?; - Ok(Parameter { + Ok(DeclarationParameter { id: var, private: p.value.private, }) } - fn check_signature( + fn check_signature( &mut self, signature: UnresolvedSignature<'ast, T>, module_id: &ModuleId, - types: &TypeMap, - ) -> Result> { + types: &TypeMap<'ast>, + ) -> Result, Vec> { let mut errors = vec![]; let mut inputs = vec![]; let mut outputs = vec![]; + // TODO: check that generics are declared + for t in signature.inputs { - match self.check_type(t, module_id, types) { + match self.check_declaration_type(t, module_id, types) { Ok(t) => { inputs.push(t); } @@ -789,7 +811,7 @@ impl<'ast> Checker<'ast> { } for t in signature.outputs { - match self.check_type(t, module_id, types) { + match self.check_declaration_type(t, module_id, types) { Ok(t) => { outputs.push(t); } @@ -803,15 +825,15 @@ impl<'ast> Checker<'ast> { return Err(errors); } - Ok(Signature { inputs, outputs }) + Ok(DeclarationSignature { inputs, outputs }) } - fn check_type( + fn check_type( &mut self, ty: UnresolvedTypeNode<'ast, T>, module_id: &ModuleId, - types: &TypeMap, - ) -> Result { + types: &TypeMap<'ast>, + ) -> Result, ErrorInner> { let pos = ty.pos(); let ty = ty.value; @@ -826,7 +848,7 @@ impl<'ast> Checker<'ast> { let size = match size { TypedExpression::Uint(e) => match e.bitwidth() { - UBitwidth::B32 => Ok(e.inner), + UBitwidth::B32 => Ok(e), bitwidth => Err(ErrorInner { pos: Some(pos), message: format!( @@ -844,16 +866,73 @@ impl<'ast> Checker<'ast> { }), }?; - let size = match size { - UExpressionInner::Value(v) => v, - _ => unimplemented!(), - } as usize; - Ok(Type::Array(ArrayType::new( self.check_type(*t, module_id, types)?, size, ))) } + UnresolvedType::User(id) => types + .get(module_id) + .unwrap() + .get(&id) + .cloned() + .ok_or_else(|| ErrorInner { + pos: Some(pos), + message: format!("Undefined type {}", id), + }) + .map(|t| t.into()), + } + } + + fn check_declaration_type( + &mut self, + ty: UnresolvedTypeNode<'ast, T>, + module_id: &ModuleId, + types: &TypeMap<'ast>, + ) -> Result, ErrorInner> { + let pos = ty.pos(); + let ty = ty.value; + + match ty { + UnresolvedType::FieldElement => Ok(DeclarationType::FieldElement), + UnresolvedType::Boolean => Ok(DeclarationType::Boolean), + UnresolvedType::Uint(bitwidth) => Ok(DeclarationType::uint(bitwidth)), + UnresolvedType::Array(t, size) => { + let size = self.check_expression(size, module_id, types)?; + + let ty = size.get_type(); + + unimplemented!("make parser stricter and only accept id and literal here?"); + + let size = match size { + TypedExpression::Uint(e) => match e.bitwidth() { + UBitwidth::B32 => match e.into_inner() { + UExpressionInner::Identifier(id) => unimplemented!(), + UExpressionInner::Value(v) => Ok(Constant::Concrete(v as u32)), + _ => unimplemented!(), + }, + bitwidth => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected array dimension to be a u32 constant, found {} of type {}", + e, ty + ), + }), + }, + _ => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected array dimension to be a u32 constant, found {} of type {}", + size, ty + ), + }), + }?; + + Ok(DeclarationType::Array(DeclarationArrayType::new( + self.check_declaration_type(*t, module_id, types)?, + size, + ))) + } UnresolvedType::User(id) => { types .get(module_id) @@ -868,12 +947,12 @@ impl<'ast> Checker<'ast> { } } - fn check_variable( + fn check_variable( &mut self, v: crate::absy::VariableNode<'ast, T>, module_id: &ModuleId, - types: &TypeMap, - ) -> Result, Vec> { + types: &TypeMap<'ast>, + ) -> Result, Vec> { Ok(Variable::with_id_and_type( v.value.id, self.check_type(v.value._type, module_id, types) @@ -881,11 +960,24 @@ impl<'ast> Checker<'ast> { )) } - fn check_statement( + fn check_declaration_variable( + &mut self, + v: crate::absy::VariableNode<'ast, T>, + module_id: &ModuleId, + types: &TypeMap<'ast>, + ) -> Result, Vec> { + Ok(DeclarationVariable::with_id_and_type( + v.value.id, + self.check_declaration_type(v.value._type, module_id, types) + .map_err(|e| vec![e])?, + )) + } + + fn check_statement( &mut self, stat: StatementNode<'ast, T>, module_id: &ModuleId, - types: &TypeMap, + types: &TypeMap<'ast>, ) -> Result, Vec> { let pos = stat.pos(); @@ -979,8 +1071,6 @@ impl<'ast> Checker<'ast> { .check_expression(to, module_id, &types) .map_err(|e| vec![e])?; - use std::convert::TryInto; - let from = match from.get_type() { Type::Uint(UBitwidth::B32) => Ok(from.try_into().unwrap()), t => Err(ErrorInner { @@ -1081,11 +1171,11 @@ impl<'ast> Checker<'ast> { } } - fn check_assignee( + fn check_assignee( &mut self, assignee: AssigneeNode<'ast, T>, module_id: &ModuleId, - types: &TypeMap, + types: &TypeMap<'ast>, ) -> Result, ErrorInner> { let pos = assignee.pos(); // check that the assignee is declared @@ -1117,7 +1207,7 @@ impl<'ast> Checker<'ast> { }; let checked_typed_index = match checked_index { - TypedExpression::FieldElement(e) => Ok(e), + TypedExpression::Uint(e) => Ok(e), e => Err(ErrorInner { pos: Some(pos), @@ -1169,11 +1259,11 @@ impl<'ast> Checker<'ast> { } } - fn check_spread_or_expression( + fn check_spread_or_expression( &mut self, spread_or_expression: SpreadOrExpression<'ast, T>, module_id: &ModuleId, - types: &TypeMap, + types: &TypeMap<'ast>, ) -> Result>, ErrorInner> { match spread_or_expression { SpreadOrExpression::Spread(s) => { @@ -1187,44 +1277,72 @@ impl<'ast> Checker<'ast> { let ty = e.inner_type().clone(); let size = e.size(); - match e.into_inner() { - // if we're doing a spread over an inline array, we return the inside of the array: ...[x, y, z] == x, y, z - // this is not strictly needed, but it makes spreads memory linear rather than quadratic - ArrayExpressionInner::Value(v) => Ok(v), - // otherwise we return a[0], ..., a[a.size() -1 ] - e => Ok((0..size) - .map(|i| match &ty { - Type::FieldElement => FieldElementExpression::select( - e.clone().annotate(Type::FieldElement, size), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - Type::Uint(bitwidth) => UExpression::select( - e.clone().annotate(Type::Uint(*bitwidth), size), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - Type::Boolean => BooleanExpression::select( - e.clone().annotate(Type::Boolean, size), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - Type::Array(array_type) => ArrayExpressionInner::Select( - box e - .clone() - .annotate(Type::Array(array_type.clone()), size), - box FieldElementExpression::Number(T::from(i)), - ) - .annotate(*array_type.ty.clone(), array_type.size) - .into(), - Type::Struct(members) => StructExpressionInner::Select( - box e.clone().annotate(Type::Struct(members.clone()), size), - box FieldElementExpression::Number(T::from(i)), - ) - .annotate(members.clone()) - .into(), - }) - .collect()), + + match size.into_inner() { + UExpressionInner::Value(size) => match e.into_inner() { + // if we're doing a spread over an inline array, we return the inside of the array: ...[x, y, z] == x, y, z + // this is not strictly needed, but it makes spreads memory linear rather than quadratic + ArrayExpressionInner::Value(v) => Ok(v), + // otherwise we return a[0], ..., a[a.size() -1 ] + e => Ok((0..size) + .map(|i| match &ty { + Type::FieldElement => FieldElementExpression::select( + e.clone().annotate( + Type::FieldElement, + UExpressionInner::Value(size) + .annotate(UBitwidth::B32), + ), + UExpressionInner::Value(i).annotate(UBitwidth::B32), + ) + .into(), + Type::Uint(bitwidth) => UExpression::select( + e.clone().annotate( + Type::Uint(*bitwidth), + UExpressionInner::Value(size) + .annotate(UBitwidth::B32), + ), + UExpressionInner::Value(i).annotate(UBitwidth::B32), + ) + .into(), + Type::Boolean => BooleanExpression::select( + e.clone().annotate( + Type::Boolean, + UExpressionInner::Value(size) + .annotate(UBitwidth::B32), + ), + UExpressionInner::Value(i).annotate(UBitwidth::B32), + ) + .into(), + Type::Array(array_type) => ArrayExpressionInner::Select( + box e.clone().annotate( + Type::Array(array_type.clone()), + UExpressionInner::Value(size) + .annotate(UBitwidth::B32), + ), + box UExpressionInner::Value(i).annotate(UBitwidth::B32), + ) + .annotate(*array_type.ty.clone(), array_type.size) + .into(), + Type::Struct(members) => StructExpressionInner::Select( + box e.clone().annotate( + Type::Struct(members.clone()), + UExpressionInner::Value(size) + .annotate(UBitwidth::B32), + ), + box UExpressionInner::Value(i).annotate(UBitwidth::B32), + ) + .annotate(members.clone()) + .into(), + }) + .collect()), + }, + size => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Array in spread should have a constant value, found {}", + size.annotate(UBitwidth::B32) + ), + }), } } e => Err(ErrorInner { @@ -1247,11 +1365,11 @@ impl<'ast> Checker<'ast> { } } - fn check_expression( + fn check_expression( &mut self, expr: ExpressionNode<'ast, T>, module_id: &ModuleId, - types: &TypeMap, + types: &TypeMap<'ast>, ) -> Result, ErrorInner> { let pos = expr.pos(); @@ -1713,130 +1831,159 @@ impl<'ast> Checker<'ast> { let array = self.check_expression(array, module_id, &types)?; match index { - RangeOrExpression::Range(r) => match array { - TypedExpression::Array(array) => { - let array_size = array.size(); - let inner_type = array.inner_type().clone(); + RangeOrExpression::Range(r) => { + unimplemented!() + // match array { + // TypedExpression::Array(array) => { + // let array_size = match array.size().into_inner() { + // UExpressionInner::Value(array_size) => { - // check that the bounds are valid expressions - let from = r - .value - .from - .map(|e| self.check_expression(e, module_id, &types)) - .unwrap_or(Ok(FieldElementExpression::Number(T::from(0)).into()))?; + // let inner_type = array.inner_type().clone(); - let to = r - .value - .to - .map(|e| self.check_expression(e, module_id, &types)) - .unwrap_or(Ok(FieldElementExpression::Number(T::from( - array_size, - )) - .into()))?; + // // check that the bounds are valid expressions + // let from = r + // .value + // .from + // .map(|e| self.check_expression(e, module_id, &types)) + // .unwrap_or(Ok(UExpressionInner::Value(0).annotate(UBitwidth::B32).into()))?; - // check the bounds are field constants - // Note: it would be nice to allow any field expression, and check it's a constant after constant propagation, - // but it's tricky from a type perspective: the size of the slice changes the type of the resulting array, - // which doesn't work well with our static array approach. Enabling arrays to have unknown size introduces a lot - // of complexity in the compiler, as function selection in inlining requires knowledge of the array size, but - // determining array size potentially requires inlining and propagating. This suggests we would need semantic checking - // to happen iteratively with inlining and propagation, which we can't do now as we go from absy to typed_absy - let from = match from { - TypedExpression::FieldElement(FieldElementExpression::Number(n)) => Ok(n.to_dec_string().parse::().unwrap()), - e => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Expected the lower bound of the range to be a constant field, found {}", - e - ), - }) - }?; + // let to = r + // .value + // .to + // .map(|e| self.check_expression(e, module_id, &types)) + // .unwrap_or(Ok(UExpressionInner::Value( + // array_size, + // ).annotate(UBitwidth::B32) + // .into()))?; - let to = match to { - TypedExpression::FieldElement(FieldElementExpression::Number(n)) => Ok(n.to_dec_string().parse::().unwrap()), - e => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Expected the higher bound of the range to be a constant field, found {}", - e - ), - }) - }?; + // // check the bounds are field constants + // // Note: it would be nice to allow any field expression, and check it's a constant after constant propagation, + // // but it's tricky from a type perspective: the size of the slice changes the type of the resulting array, + // // which doesn't work well with our static array approach. Enabling arrays to have unknown size introduces a lot + // // of complexity in the compiler, as function selection in inlining requires knowledge of the array size, but + // // determining array size potentially requires inlining and propagating. This suggests we would need semantic checking + // // to happen iteratively with inlining and propagation, which we can't do now as we go from absy to typed_absy + // let from = match from.get_type() { + // Type::Uint(UBitwidth::B32) => {} + // TypedExpression::Uint(e) => match e { + // match e.bitwidth() { + // UBitwidth::B32 => match e.into_inner() { + // UExpressionInner::Value(v) => Ok(v), + // e => Err(ErrorInner { + // pos: Some(pos), + // message: format!( + // "Expected the lower bound of the range to be a constant u32, found {}", + // e + // ), + // }) + // }, - match (from, to, array_size) { - (f, _, s) if f > s => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Lower range bound {} is out of array bounds [0, {}]", - f, s, - ), - }), - (_, t, s) if t > s => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Higher range bound {} is out of array bounds [0, {}]", - t, s, - ), - }), - (f, t, _) if f > t => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Lower range bound {} is larger than higher range bound {}", - f, t, - ), - }), - (f, t, _) => Ok(ArrayExpressionInner::Value( - (f..t) - .map(|i| match inner_type.clone() { - Type::FieldElement => FieldElementExpression::Select( - box array.clone(), - box FieldElementExpression::Number(T::from(i)), - ) - .into(), - Type::Boolean => BooleanExpression::Select( - box array.clone(), - box FieldElementExpression::Number(T::from(i)), - ) - .into(), - Type::Uint(bitwidth) => UExpressionInner::Select( - box array.clone(), - box FieldElementExpression::Number(T::from(i)), - ) - .annotate(bitwidth) - .into(), - Type::Struct(struct_ty) => { - StructExpressionInner::Select( - box array.clone(), - box FieldElementExpression::Number(T::from(i)), - ) - .annotate(struct_ty) - .into() - } - Type::Array(array_ty) => ArrayExpressionInner::Select( - box array.clone(), - box FieldElementExpression::Number(T::from(i)), - ) - .annotate(*array_ty.ty, array_ty.size) - .into(), - }) - .collect(), - ) - .annotate(inner_type, t - f) - .into()), - } - } - e => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Cannot access slice of expression {} of type {}", - e, - e.get_type(), - ), - }), - }, + // } + // }Ok(n), + // e => Err(ErrorInner { + // pos: Some(pos), + // message: format!( + // "Expected the lower bound of the range to be a constant u32, found {}", + // e + // ), + // }) + // }? as u32; + + // let to = match to { + // TypedExpression::Uint(Expression::Number(n)) => Ok(n), + // e => Err(ErrorInner { + // pos: Some(pos), + // message: format!( + // "Expected the higher bound of the range to be a constant u32, found {}", + // e + // ), + // }) + // }? as u32; + + // match (from, to, array_size) { + // (f, _, s) if f > s => Err(ErrorInner { + // pos: Some(pos), + // message: format!( + // "Lower range bound {} is out of array bounds [0, {}]", + // f, s, + // ), + // }), + // (_, t, s) if t > s => Err(ErrorInner { + // pos: Some(pos), + // message: format!( + // "Higher range bound {} is out of array bounds [0, {}]", + // t, s, + // ), + // }), + // (f, t, _) if f > t => Err(ErrorInner { + // pos: Some(pos), + // message: format!( + // "Lower range bound {} is larger than higher range bound {}", + // f, t, + // ), + // }), + // (f, t, _) => Ok(ArrayExpressionInner::Value( + // (f..t) + // .map(|i| match inner_type.clone() { + // Type::FieldElement => FieldElementExpression::Select( + // box array.clone(), + // box FieldElementExpression::Number(T::from(i)), + // ) + // .into(), + // Type::Boolean => BooleanExpression::Select( + // box array.clone(), + // box FieldElementExpression::Number(T::from(i)), + // ) + // .into(), + // Type::Uint(bitwidth) => UExpressionInner::Select( + // box array.clone(), + // box FieldElementExpression::Number(T::from(i)), + // ) + // .annotate(bitwidth) + // .into(), + // Type::Struct(struct_ty) => { + // StructExpressionInner::Select( + // box array.clone(), + // box FieldElementExpression::Number(T::from(i)), + // ) + // .annotate(struct_ty) + // .into() + // } + // Type::Array(array_ty) => ArrayExpressionInner::Select( + // box array.clone(), + // box FieldElementExpression::Number(T::from(i)), + // ) + // .annotate(*array_ty.ty, array_ty.size) + // .into(), + // }) + // .collect(), + // ) + // .annotate(inner_type, t - f) + // .into()), + // } + // }, + // l => Err(ErrorInner { + // pos: Some(pos), + // message: format!( + // "Range are not available for arrays of non-constant length, found {}", + // l, + // ), + // }) + // }; + // } + // e => Err(ErrorInner { + // pos: Some(pos), + // message: format!( + // "Cannot access slice of expression {} of type {}", + // e, + // e.get_type(), + // ), + // }), + // } + } RangeOrExpression::Expression(e) => { match (array, self.check_expression(e, module_id, &types)?) { - (TypedExpression::Array(a), TypedExpression::FieldElement(i)) => { + (TypedExpression::Array(a), TypedExpression::Uint(i)) => { match a.inner_type().clone() { Type::FieldElement => { Ok(FieldElementExpression::select(a, i).into()) @@ -1934,10 +2081,10 @@ impl<'ast> Checker<'ast> { unwrapped_expressions.push(unwrapped_e.into()); } - let size = unwrapped_expressions.len(); + let size = unwrapped_expressions.len() as u32; Ok(ArrayExpressionInner::Value(unwrapped_expressions) - .annotate(Type::FieldElement, size) + .annotate(Type::FieldElement, size as u32) .into()) } Type::Boolean => { @@ -1961,7 +2108,7 @@ impl<'ast> Checker<'ast> { unwrapped_expressions.push(unwrapped_e.into()); } - let size = unwrapped_expressions.len(); + let size = unwrapped_expressions.len() as u32; Ok(ArrayExpressionInner::Value(unwrapped_expressions) .annotate(Type::Boolean, size) @@ -2003,7 +2150,7 @@ impl<'ast> Checker<'ast> { unwrapped_expressions.push(unwrapped_e.into()); } - let size = unwrapped_expressions.len(); + let size = unwrapped_expressions.len() as u32; Ok(ArrayExpressionInner::Value(unwrapped_expressions) .annotate(ty, size) @@ -2045,7 +2192,7 @@ impl<'ast> Checker<'ast> { unwrapped_expressions.push(unwrapped_e.into()); } - let size = unwrapped_expressions.len(); + let size = unwrapped_expressions.len() as u32; Ok(ArrayExpressionInner::Value(unwrapped_expressions) .annotate(ty, size) @@ -2087,7 +2234,7 @@ impl<'ast> Checker<'ast> { unwrapped_expressions.push(unwrapped_e.into()); } - let size = unwrapped_expressions.len(); + let size = unwrapped_expressions.len() as u32; Ok(ArrayExpressionInner::Value(unwrapped_expressions) .annotate(ty, size) @@ -2096,7 +2243,7 @@ impl<'ast> Checker<'ast> { } } Expression::InlineStruct(id, inline_members) => { - let ty = self.check_type::( + let ty = self.check_type( UnresolvedType::User(id.clone()).at(42, 42, 42), module_id, &types, @@ -2342,7 +2489,7 @@ impl<'ast> Checker<'ast> { } } - fn get_scope(&self, variable_name: &'ast str) -> Option<&'ast ScopedVariable> { + fn get_scope<'a>(&'a self, variable_name: &'ast str) -> Option<&'a ScopedVariable<'ast, T>> { self.scope.get(&ScopedVariable { id: Variable::with_id_and_type( crate::typed_absy::Identifier::from(variable_name), @@ -2352,15 +2499,16 @@ impl<'ast> Checker<'ast> { }) } - fn insert_into_scope(&mut self, v: Variable<'ast>) -> bool { + fn insert_into_scope>>(&mut self, v: U) -> bool { self.scope.insert(ScopedVariable { - id: v, + id: v.into(), level: self.level, }) } - fn find_function(&self, query: &FunctionQuery<'ast>) -> Option> { - query.match_funcs(&self.functions) + fn find_function(&self, query: &FunctionQuery<'ast, T>) -> Option> { + // query.match_funcs(&self.functions) + unimplemented!() } fn enter_scope(&mut self) { @@ -2850,11 +2998,11 @@ mod tests { } } - pub fn new_with_args<'ast>( - scope: HashSet>, + pub fn new_with_args<'ast, T: Field>( + scope: HashSet>, level: usize, - functions: HashSet>, - ) -> Checker<'ast> { + functions: HashSet>, + ) -> Checker<'ast, T> { Checker { scope: scope, functions: functions, diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index eb2c7b81..e93f9a9f 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -4,28 +4,30 @@ use typed_absy::types::{StructType, UBitwidth}; use zir; use zokrates_field::Field; +use std::convert::{TryFrom, TryInto}; + pub struct Flattener { phantom: PhantomData, } -fn flatten_identifier_rec<'a>( - id: zir::SourceIdentifier<'a>, - ty: typed_absy::Type, -) -> Vec { +fn flatten_identifier_rec<'ast>( + id: zir::SourceIdentifier<'ast>, + ty: typed_absy::types::ConcreteType, +) -> Vec> { match ty { - typed_absy::Type::FieldElement => vec![zir::Variable { + typed_absy::types::ConcreteType::FieldElement => vec![zir::Variable { id: zir::Identifier::Source(id), _type: zir::Type::FieldElement, }], - typed_absy::Type::Boolean => vec![zir::Variable { + typed_absy::types::ConcreteType::Boolean => vec![zir::Variable { id: zir::Identifier::Source(id), _type: zir::Type::Boolean, }], - typed_absy::Type::Uint(bitwidth) => vec![zir::Variable { + typed_absy::types::ConcreteType::Uint(bitwidth) => vec![zir::Variable { id: zir::Identifier::Source(id), _type: zir::Type::uint(bitwidth.to_usize()), }], - typed_absy::Type::Array(array_type) => (0..array_type.size) + typed_absy::types::ConcreteType::Array(array_type) => (0..array_type.size) .flat_map(|i| { flatten_identifier_rec( zir::SourceIdentifier::Select(box id.clone(), i), @@ -33,7 +35,7 @@ fn flatten_identifier_rec<'a>( ) }) .collect(), - typed_absy::Type::Struct(members) => members + typed_absy::types::ConcreteType::Struct(members) => members .into_iter() .flat_map(|struct_member| { flatten_identifier_rec( @@ -75,9 +77,12 @@ impl<'ast, T: Field> Flattener { fold_function(self, f) } - fn fold_parameter(&mut self, p: typed_absy::Parameter<'ast>) -> Vec> { + fn fold_declaration_parameter( + &mut self, + p: typed_absy::DeclarationParameter<'ast>, + ) -> Vec> { let private = p.private; - self.fold_variable(p.id) + self.fold_variable(p.id.try_into().unwrap()) .into_iter() .map(|v| zir::Parameter { id: v, private }) .collect() @@ -87,10 +92,12 @@ impl<'ast, T: Field> Flattener { zir::SourceIdentifier::Basic(n) } - fn fold_variable(&mut self, v: typed_absy::Variable<'ast>) -> Vec> { + fn fold_variable(&mut self, v: typed_absy::Variable<'ast, T>) -> Vec> { let id = self.fold_name(v.id.clone()); let ty = v.get_type(); + let ty = typed_absy::types::ConcreteType::try_from(ty).unwrap(); + flatten_identifier_rec(id, ty) } @@ -146,6 +153,8 @@ impl<'ast, T: Field> Flattener { ) -> zir::ZirExpressionList<'ast, T> { match es { typed_absy::TypedExpressionList::FunctionCall(id, arguments, _) => { + let id = typed_absy::types::ConcreteFunctionKey::try_from(id).unwrap(); + zir::ZirExpressionList::FunctionCall( self.fold_function_key(id), arguments @@ -160,7 +169,7 @@ impl<'ast, T: Field> Flattener { fn fold_function_key( &mut self, - k: typed_absy::types::FunctionKey<'ast>, + k: typed_absy::types::ConcreteFunctionKey<'ast>, ) -> zir::types::FunctionKey<'ast> { k.into() } @@ -194,7 +203,7 @@ impl<'ast, T: Field> Flattener { fn fold_array_expression_inner( &mut self, - ty: &typed_absy::Type, + ty: &typed_absy::types::ConcreteType, size: usize, e: typed_absy::ArrayExpressionInner<'ast, T>, ) -> Vec> { @@ -202,7 +211,7 @@ impl<'ast, T: Field> Flattener { } fn fold_struct_expression_inner( &mut self, - ty: &StructType, + ty: &typed_absy::types::ConcreteStructType, e: typed_absy::StructExpressionInner<'ast, T>, ) -> Vec> { fold_struct_expression_inner(self, ty, e) @@ -217,7 +226,12 @@ pub fn fold_module<'ast, T: Field>( functions: p .functions .into_iter() - .map(|(key, fun)| (f.fold_function_key(key), f.fold_function_symbol(fun))) + .map(|(key, fun)| { + ( + f.fold_function_key(key.try_into().unwrap()), + f.fold_function_symbol(fun), + ) + }) .collect(), } } @@ -267,14 +281,16 @@ pub fn fold_statement<'ast, T: Field>( pub fn fold_array_expression_inner<'ast, T: Field>( f: &mut Flattener, - t: &typed_absy::Type, + t: &typed_absy::types::ConcreteType, size: usize, e: typed_absy::ArrayExpressionInner<'ast, T>, ) -> Vec> { match e { typed_absy::ArrayExpressionInner::Identifier(id) => { - let variables = - flatten_identifier_rec(f.fold_name(id), typed_absy::Type::array(t.clone(), size)); + let variables = flatten_identifier_rec( + f.fold_name(id), + typed_absy::types::ConcreteType::array(t.clone(), size), + ); variables .into_iter() .map(|v| match v._type { @@ -329,7 +345,11 @@ pub fn fold_array_expression_inner<'ast, T: Field>( let offset: usize = members .iter() .take_while(|member| member.id != id) - .map(|member| member.ty.get_primitive_count()) + .map(|member| { + typed_absy::types::ConcreteType::try_from(*member.ty) + .unwrap() + .get_primitive_count() + }) .sum(); // we also need the size of this member @@ -339,12 +359,15 @@ pub fn fold_array_expression_inner<'ast, T: Field>( } typed_absy::ArrayExpressionInner::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); - match index { - zir::FieldElementExpression::Number(i) => { - let size = t.get_primitive_count() * size; - let start = i.to_dec_string().parse::().unwrap() * size; + match index.into_inner() { + zir::UExpressionInner::Value(i) => { + let size = typed_absy::types::ConcreteType::try_from(*t) + .unwrap() + .get_primitive_count() + * size; + let start = i as usize * size; let end = start + size; array[start..end].to_vec() } @@ -356,13 +379,15 @@ pub fn fold_array_expression_inner<'ast, T: Field>( pub fn fold_struct_expression_inner<'ast, T: Field>( f: &mut Flattener, - t: &StructType, + t: &typed_absy::types::ConcreteStructType, e: typed_absy::StructExpressionInner<'ast, T>, ) -> Vec> { match e { typed_absy::StructExpressionInner::Identifier(id) => { - let variables = - flatten_identifier_rec(f.fold_name(id), typed_absy::Type::struc(t.clone())); + let variables = flatten_identifier_rec( + f.fold_name(id), + typed_absy::types::ConcreteType::struc(t.clone()), + ); variables .into_iter() .map(|v| match v._type { @@ -417,30 +442,33 @@ pub fn fold_struct_expression_inner<'ast, T: Field>( let offset: usize = members .iter() .take_while(|member| member.id != id) - .map(|member| member.ty.get_primitive_count()) + .map(|member| { + typed_absy::types::ConcreteType::try_from(*member.ty) + .unwrap() + .get_primitive_count() + }) .sum(); // we also need the size of this member - let size = t - .iter() - .find(|member| member.id == id) - .unwrap() - .ty - .get_primitive_count(); + let size = typed_absy::types::ConcreteType::try_from( + *t.iter().find(|member| member.id == id).unwrap().ty, + ) + .unwrap() + .get_primitive_count(); s[offset..offset + size].to_vec() } typed_absy::StructExpressionInner::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); - match index { - zir::FieldElementExpression::Number(i) => { + match index.into_inner() { + zir::UExpressionInner::Value(i) => { let size = t .iter() .map(|m| m.ty.get_primitive_count()) .fold(0, |acc, current| acc + current); - let start = i.to_dec_string().parse::().unwrap() * size; + let start = i as usize * size; let end = start + size; array[start..end].to_vec() } @@ -458,9 +486,12 @@ pub fn fold_field_expression<'ast, T: Field>( typed_absy::FieldElementExpression::Number(n) => zir::FieldElementExpression::Number(n), typed_absy::FieldElementExpression::Identifier(id) => { zir::FieldElementExpression::Identifier( - flatten_identifier_rec(f.fold_name(id), typed_absy::Type::FieldElement)[0] - .id - .clone(), + flatten_identifier_rec( + f.fold_name(id), + typed_absy::types::ConcreteType::FieldElement, + )[0] + .id + .clone(), ) } typed_absy::FieldElementExpression::Add(box e1, box e2) => { @@ -503,26 +534,22 @@ pub fn fold_field_expression<'ast, T: Field>( let offset: usize = members .iter() .take_while(|member| member.id != id) - .map(|member| member.ty.get_primitive_count()) + .map(|member| { + typed_absy::types::ConcreteType::try_from(*member.ty) + .unwrap() + .get_primitive_count() + }) .sum(); - use std::convert::TryInto; - s[offset].clone().try_into().unwrap() } typed_absy::FieldElementExpression::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); - use std::convert::TryInto; - - match index { - zir::FieldElementExpression::Number(i) => array - [i.to_dec_string().parse::().unwrap()] - .clone() - .try_into() - .unwrap(), + match index.into_inner() { + zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(), _ => unreachable!(""), } } @@ -536,7 +563,7 @@ pub fn fold_boolean_expression<'ast, T: Field>( match e { typed_absy::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v), typed_absy::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier( - flatten_identifier_rec(f.fold_name(id), typed_absy::Type::Boolean)[0] + flatten_identifier_rec(f.fold_name(id), typed_absy::types::ConcreteType::Boolean)[0] .id .clone(), ), @@ -661,25 +688,21 @@ pub fn fold_boolean_expression<'ast, T: Field>( let offset: usize = members .iter() .take_while(|member| member.id != id) - .map(|member| member.ty.get_primitive_count()) + .map(|member| { + typed_absy::types::ConcreteType::try_from(*member.ty) + .unwrap() + .get_primitive_count() + }) .sum(); - use std::convert::TryInto; - s[offset].clone().try_into().unwrap() } typed_absy::BooleanExpression::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); - use std::convert::TryInto; - - match index { - zir::FieldElementExpression::Number(i) => array - [i.to_dec_string().parse::().unwrap()] - .clone() - .try_into() - .unwrap(), + match index.into_inner() { + zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(), _ => unreachable!(), } } @@ -702,9 +725,12 @@ pub fn fold_uint_expression_inner<'ast, T: Field>( match e { typed_absy::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v), typed_absy::UExpressionInner::Identifier(id) => zir::UExpressionInner::Identifier( - flatten_identifier_rec(f.fold_name(id), typed_absy::Type::Uint(bitwidth))[0] - .id - .clone(), + flatten_identifier_rec( + f.fold_name(id), + typed_absy::types::ConcreteType::Uint(bitwidth), + )[0] + .id + .clone(), ), typed_absy::UExpressionInner::Add(box left, box right) => { let left = f.fold_uint_expression(left); @@ -764,16 +790,11 @@ pub fn fold_uint_expression_inner<'ast, T: Field>( } typed_absy::UExpressionInner::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); - use std::convert::TryInto; - - match index { - zir::FieldElementExpression::Number(i) => { - let e: zir::UExpression<_> = array[i.to_dec_string().parse::().unwrap()] - .clone() - .try_into() - .unwrap(); + match index.into_inner() { + zir::UExpressionInner::Value(i) => { + let e: zir::UExpression<_> = array[i as usize].clone().try_into().unwrap(); e.into_inner() } _ => unreachable!(), @@ -787,11 +808,13 @@ pub fn fold_uint_expression_inner<'ast, T: Field>( let offset: usize = members .iter() .take_while(|member| member.id != id) - .map(|member| member.ty.get_primitive_count()) + .map(|member| { + typed_absy::types::ConcreteType::try_from(*member.ty) + .unwrap() + .get_primitive_count() + }) .sum(); - use std::convert::TryInto; - let res: zir::UExpression<'ast, T> = s[offset].clone().try_into().unwrap(); res.into_inner() @@ -813,14 +836,16 @@ pub fn fold_function<'ast, T: Field>( arguments: fun .arguments .into_iter() - .flat_map(|a| f.fold_parameter(a)) + .flat_map(|a| f.fold_declaration_parameter(a)) .collect(), statements: fun .statements .into_iter() .flat_map(|s| f.fold_statement(s)) .collect(), - signature: fun.signature.into(), + signature: typed_absy::types::ConcreteSignature::try_from(fun.signature) + .unwrap() + .into(), } } @@ -828,14 +853,25 @@ pub fn fold_array_expression<'ast, T: Field>( f: &mut Flattener, e: typed_absy::ArrayExpression<'ast, T>, ) -> Vec> { - f.fold_array_expression_inner(&e.inner_type().clone(), e.size(), e.into_inner()) + let size = match e.size().into_inner() { + typed_absy::UExpressionInner::Value(v) => v, + _ => unreachable!(), + } as usize; + f.fold_array_expression_inner( + &typed_absy::types::ConcreteType::try_from(e.inner_type().clone()).unwrap(), + size, + e.into_inner(), + ) } pub fn fold_struct_expression<'ast, T: Field>( f: &mut Flattener, e: typed_absy::StructExpression<'ast, T>, ) -> Vec> { - f.fold_struct_expression_inner(&e.ty().clone(), e.into_inner()) + f.fold_struct_expression_inner( + &typed_absy::types::ConcreteStructType::try_from(e.ty().clone()).unwrap(), + e.into_inner(), + ) } pub fn fold_function_symbol<'ast, T: Field>( @@ -846,9 +882,10 @@ pub fn fold_function_symbol<'ast, T: Field>( typed_absy::TypedFunctionSymbol::Here(fun) => { zir::ZirFunctionSymbol::Here(f.fold_function(fun)) } - typed_absy::TypedFunctionSymbol::There(key, module) => { - zir::ZirFunctionSymbol::There(f.fold_function_key(key), module) - } // by default, do not fold modules recursively + typed_absy::TypedFunctionSymbol::There(key, module) => zir::ZirFunctionSymbol::There( + f.fold_function_key(typed_absy::types::ConcreteFunctionKey::try_from(key).unwrap()), + module, + ), // by default, do not fold modules recursively typed_absy::TypedFunctionSymbol::Flat(flat) => zir::ZirFunctionSymbol::Flat(flat), } } diff --git a/zokrates_core/src/static_analysis/inline.rs b/zokrates_core/src/static_analysis/inline.rs index 94a13fe9..ba5afba4 100644 --- a/zokrates_core/src/static_analysis/inline.rs +++ b/zokrates_core/src/static_analysis/inline.rs @@ -18,6 +18,8 @@ use static_analysis::propagate_unroll::{Blocked, Output}; use std::collections::HashMap; +use std::convert::{TryFrom, TryInto}; +use typed_absy::types::ConcreteFunctionKey; use typed_absy::types::{FunctionKey, Type, UBitwidth}; use typed_absy::{folder::*, *}; use zokrates_field::Field; @@ -25,7 +27,7 @@ use zokrates_field::Field; #[derive(Debug, PartialEq, Eq, Hash, Clone)] struct Location<'ast> { module: TypedModuleId, - key: FunctionKey<'ast>, + key: ConcreteFunctionKey<'ast>, } impl<'ast> Location<'ast> { @@ -37,11 +39,16 @@ impl<'ast> Location<'ast> { type CallCache<'ast, T> = HashMap< Location<'ast>, HashMap< - FunctionKey<'ast>, + ConcreteFunctionKey<'ast>, HashMap>, Vec>>, >, >; +enum InlineError<'ast, T> { + Flat(FunctionKey<'ast, T>, Vec>), + NonConstant(FunctionKey<'ast, T>, Vec>), +} + /// An inliner #[derive(Debug)] pub struct Inliner<'ast, T: Field> { @@ -52,9 +59,9 @@ pub struct Inliner<'ast, T: Field> { /// a buffer of statements to be added to the inlined statements statement_buffer: Vec>, /// the current call stack - stack: Vec<(TypedModuleId, FunctionKey<'ast>, usize)>, + stack: Vec<(TypedModuleId, ConcreteFunctionKey<'ast>, usize)>, /// the call count for each function - call_count: HashMap<(TypedModuleId, FunctionKey<'ast>), usize>, + call_count: HashMap<(TypedModuleId, ConcreteFunctionKey<'ast>), usize>, /// the cache for memoization: for each function body, tracks function calls call_cache: CallCache<'ast, T>, /// whether the inliner is blocked, and why @@ -65,7 +72,7 @@ impl<'ast, T: Field> Inliner<'ast, T> { fn with_modules_and_module_id_and_key>( modules: TypedModules<'ast, T>, module_id: S, - key: FunctionKey<'ast>, + key: ConcreteFunctionKey<'ast>, ) -> Self { Inliner { modules, @@ -95,7 +102,11 @@ impl<'ast, T: Field> Inliner<'ast, T> { .0; // initialize an inliner over all modules, starting from the main module - Inliner::with_modules_and_module_id_and_key(p.modules, main_module_id, main_key.clone()) + Inliner::with_modules_and_module_id_and_key( + p.modules, + main_module_id, + main_key.clone().try_into().unwrap(), + ) } pub fn inline(&mut self, p: TypedProgram<'ast, T>) -> Output<'ast, T> { @@ -162,22 +173,47 @@ impl<'ast, T: Field> Inliner<'ast, T> { } } + fn get_concrete_function( + &self, + key: ConcreteFunctionKey<'ast>, + ) -> TypedFunctionSymbol<'ast, T> { + unimplemented!() + } + + fn try_inline_call( + &mut self, + key: FunctionKey<'ast, T>, + expressions: Vec>, + ) -> Result>, InlineError<'ast, T>> { + match ConcreteFunctionKey::try_from(key) { + Ok(key) => self + .try_inline_concrete_call(key, expressions) + .map_err(|e| InlineError::Flat(e.0.into(), e.1)), + Err(()) => { + self.blocked = Some(Blocked::Inline); + Err(InlineError::NonConstant(key, expressions)) + } + } + } + /// try to inline a call to function with key `key` in the stack of `self` /// if inlining succeeds, return the expressions returned by the function call /// if inlining fails (as in the case of flat function symbols), return the arguments to the function call for further processing - fn try_inline_call( + fn try_inline_concrete_call( &mut self, - key: &FunctionKey<'ast>, + key: ConcreteFunctionKey<'ast>, expressions: Vec>, - ) -> Result>, (FunctionKey<'ast>, Vec>)> - { - match self.call_cache().get(key).map(|m| m.get(&expressions)) { + ) -> Result< + Vec>, + (ConcreteFunctionKey<'ast>, Vec>), + > { + match self.call_cache().get(&key).map(|m| m.get(&expressions)) { Some(Some(exprs)) => return Ok(exprs.clone()), _ => {} }; // here we clone a function symbol, which is cheap except when it contains the function body, in which case we'd clone anyways - let res = match self.module().functions.get(&key).unwrap().clone() { + let res = match self.get_concrete_function(key) { // if the function called is in the same module, we can go ahead and inline in this module TypedFunctionSymbol::Here(function) => { let (current_module, current_key) = @@ -200,7 +236,7 @@ impl<'ast, T: Field> Inliner<'ast, T> { .zip(expressions.clone()) .map(|(a, e)| { TypedStatement::Definition( - self.fold_assignee(TypedAssignee::Identifier(a.id.clone())), + self.fold_assignee(TypedAssignee::Identifier(a.id.clone().into())), e, ) }) @@ -233,24 +269,31 @@ impl<'ast, T: Field> Inliner<'ast, T> { } // if the function called is in some other module, we switch focus to that module and call the function locally there TypedFunctionSymbol::There(function_key, module_id) => { - // switch focus to `module_id` - let (current_module, current_key) = - self.change_context(module_id, function_key.clone()); - // inline the call there - let res = self.try_inline_call(&function_key, expressions.clone())?; - // switch back focus - self.change_context(current_module, current_key); - Ok(res) + unimplemented!() + + // let function_key = function_key.try_into().unwrap(); + + // // switch focus to `module_id` + // let (current_module, current_key) = + // self.change_context(module_id, function_key.clone()); + // // inline the call there + // let res = self.try_inline_call(&function_key, expressions.clone())?; + // // switch back focus + // self.change_context(current_module, current_key); + // Ok(res) } // if the function is a flat symbol, replace the call with a call to the local function we provide so it can be inlined in flattening TypedFunctionSymbol::Flat(embed) => { // increase the number of calls for this function by one let _ = self .call_count - .entry((self.module_id().clone(), embed.key::().clone())) + .entry(( + self.module_id().clone(), + embed.key::().clone().try_into().unwrap(), + )) .and_modify(|i| *i += 1) .or_insert(1); - Err((embed.key::(), expressions.clone())) + Err((embed.key::().try_into().unwrap(), expressions.clone())) } }; @@ -267,8 +310,8 @@ impl<'ast, T: Field> Inliner<'ast, T> { fn change_context( &mut self, module_id: TypedModuleId, - function_key: FunctionKey<'ast>, - ) -> (TypedModuleId, FunctionKey<'ast>) { + function_key: ConcreteFunctionKey<'ast>, + ) -> (TypedModuleId, ConcreteFunctionKey<'ast>) { let current_module = std::mem::replace(&mut self.location.module, module_id); let current_key = std::mem::replace(&mut self.location.key, function_key); (current_module, current_key) @@ -281,7 +324,7 @@ impl<'ast, T: Field> Inliner<'ast, T> { fn call_cache( &mut self, ) -> &HashMap< - FunctionKey<'ast>, + ConcreteFunctionKey<'ast>, HashMap>, Vec>>, > { self.call_cache @@ -292,7 +335,7 @@ impl<'ast, T: Field> Inliner<'ast, T> { fn call_cache_mut( &mut self, ) -> &mut HashMap< - FunctionKey<'ast>, + ConcreteFunctionKey<'ast>, HashMap>, Vec>>, > { self.call_cache.get_mut(&self.location).unwrap() @@ -305,7 +348,10 @@ impl<'ast, T: Field> Inliner<'ast, T> { impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { // add extra statements before the modified statement - fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { + fn fold_statement<'a>( + &'a mut self, + s: TypedStatement<'ast, T>, + ) -> Vec> { let folded = match s { TypedStatement::For(v, from, to, statements) => { self.blocked = Some(Blocked::Unroll); @@ -319,7 +365,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { .collect(); let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); - match self.try_inline_call(&key, exps) { + match self.try_inline_call(key, exps) { Ok(ret) => variables .into_iter() .zip(ret.into_iter()) @@ -327,16 +373,22 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { TypedStatement::Definition(TypedAssignee::Identifier(v), e) }) .collect(), - Err((key, expressions)) => vec![TypedStatement::MultipleDefinition( - variables, - TypedExpressionList::FunctionCall(key, expressions, types), - )], + Err(e) => match e { + InlineError::Flat(key, expressions) + | InlineError::NonConstant(key, expressions) => { + vec![TypedStatement::MultipleDefinition( + variables, + TypedExpressionList::FunctionCall(key, expressions, types), + )] + } + }, } } }, s => fold_statement(self, s), }; - self.statement_buffer.drain(..).chain(folded).collect() + unimplemented!() + //self.statement_buffer.drain(..).chain(folded).collect() } // prefix all names with the stack @@ -356,28 +408,50 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { FieldElementExpression::FunctionCall(key, exps) => { let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); - match self.try_inline_call(&key, exps) { + match self.try_inline_call(key, exps) { Ok(mut ret) => match ret.pop().unwrap() { TypedExpression::FieldElement(e) => e, _ => unreachable!(), }, - Err((key, expressions)) => { - let tys = key.signature.outputs.clone(); - let id = Identifier { - id: CoreIdentifier::Call(key.clone()), - version: *self - .call_count - .get(&(self.module_id().clone(), key.clone())) - .unwrap(), - stack: self.stack.clone(), - }; - self.statement_buffer - .push(TypedStatement::MultipleDefinition( - vec![Variable::with_id_and_type(id.clone(), tys[0].clone())], - TypedExpressionList::FunctionCall(key, expressions, tys), - )); - FieldElementExpression::Identifier(id) - } + Err(e) => match e { + InlineError::Flat(key, expressions) => { + let key = ConcreteFunctionKey::try_from(key).unwrap(); + let tys = key.signature.outputs.clone(); + let id = Identifier { + id: CoreIdentifier::Call(key.clone()), + version: *self + .call_count + .get(&(self.module_id().clone(), key.clone())) + .unwrap(), + stack: self.stack.clone(), + }; + self.statement_buffer + .push(TypedStatement::MultipleDefinition( + vec![Variable::with_id_and_type( + id.clone(), + tys[0].clone().into(), + )], + TypedExpressionList::FunctionCall( + key.into(), + expressions, + tys.into_iter().map(|t| t.into()).collect(), + ), + )); + + self.call_cache_mut() + .entry(key.clone()) + .or_insert_with(|| HashMap::new()) + .insert( + expressions, + vec![FieldElementExpression::Identifier(id.clone()).into()], + ); + + FieldElementExpression::Identifier(id) + } + InlineError::NonConstant(key, expressions) => { + FieldElementExpression::FunctionCall(key, expressions) + } + }, } } e => fold_field_expression(self, e), @@ -393,41 +467,51 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { BooleanExpression::FunctionCall(key, exps) => { let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); - match self.try_inline_call(&key, exps) { + match self.try_inline_call(key, exps) { Ok(mut ret) => match ret.pop().unwrap() { TypedExpression::Boolean(e) => e, _ => unreachable!(), }, - Err((key, expressions)) => { - let tys = key.signature.outputs.clone(); - let id = Identifier { - id: CoreIdentifier::Call(key.clone()), - version: *self - .call_count - .get(&(self.module_id().clone(), key.clone())) - .unwrap(), - stack: self.stack.clone(), - }; - self.statement_buffer - .push(TypedStatement::MultipleDefinition( - vec![Variable::with_id_and_type(id.clone(), tys[0].clone())], - TypedExpressionList::FunctionCall( - key.clone(), - expressions.clone(), - tys, - ), - )); + Err(e) => match e { + InlineError::Flat(key, expressions) => { + let key = ConcreteFunctionKey::try_from(key).unwrap(); - self.call_cache_mut() - .entry(key.clone()) - .or_insert_with(|| HashMap::new()) - .insert( - expressions, - vec![BooleanExpression::Identifier(id.clone()).into()], - ); + let tys = key.signature.outputs.clone(); + let id = Identifier { + id: CoreIdentifier::Call(key.clone()), + version: *self + .call_count + .get(&(self.module_id().clone(), key.clone())) + .unwrap(), + stack: self.stack.clone(), + }; + self.statement_buffer + .push(TypedStatement::MultipleDefinition( + vec![Variable::with_id_and_type( + id.clone(), + tys[0].clone().into(), + )], + TypedExpressionList::FunctionCall( + key.into(), + expressions, + tys.into_iter().map(|t| t.into()).collect(), + ), + )); - BooleanExpression::Identifier(id) - } + self.call_cache_mut() + .entry(key.clone()) + .or_insert_with(|| HashMap::new()) + .insert( + expressions, + vec![BooleanExpression::Identifier(id.clone()).into()], + ); + + BooleanExpression::Identifier(id) + } + InlineError::NonConstant(key, expressions) => { + BooleanExpression::FunctionCall(key, expressions) + } + }, } } e => fold_boolean_expression(self, e), @@ -437,51 +521,61 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { // inline calls which return an array fn fold_array_expression_inner( &mut self, - ty: &Type, - size: usize, + ty: &Type<'ast, T>, + size: UExpression<'ast, T>, e: ArrayExpressionInner<'ast, T>, ) -> ArrayExpressionInner<'ast, T> { match e { ArrayExpressionInner::FunctionCall(key, exps) => { let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); - match self.try_inline_call(&key, exps) { + match self.try_inline_call(key, exps) { Ok(mut ret) => match ret.pop().unwrap() { TypedExpression::Array(e) => e.into_inner(), _ => unreachable!(), }, - Err((embed_key, expressions)) => { - let tys = key.signature.outputs.clone(); - let id = Identifier { - id: CoreIdentifier::Call(key.clone()), - version: *self - .call_count - .get(&(self.module_id().clone(), embed_key.clone())) - .unwrap(), - stack: self.stack.clone(), - }; - self.statement_buffer - .push(TypedStatement::MultipleDefinition( - vec![Variable::with_id_and_type(id.clone(), tys[0].clone())], - TypedExpressionList::FunctionCall( - embed_key.clone(), - expressions.clone(), - tys, - ), - )); + Err(e) => match e { + InlineError::Flat(key, expressions) => { + let key = ConcreteFunctionKey::try_from(key).unwrap(); - let out = ArrayExpressionInner::Identifier(id); + let tys = key.signature.outputs.clone(); + let id = Identifier { + id: CoreIdentifier::Call(key.clone()), + version: *self + .call_count + .get(&(self.module_id().clone(), key.clone())) + .unwrap(), + stack: self.stack.clone(), + }; + self.statement_buffer + .push(TypedStatement::MultipleDefinition( + vec![Variable::with_id_and_type( + id.clone(), + tys[0].clone().into(), + )], + TypedExpressionList::FunctionCall( + key.into(), + expressions.clone(), + tys.into_iter().map(|t| t.into()).collect(), + ), + )); - self.call_cache_mut() - .entry(key.clone()) - .or_insert_with(|| HashMap::new()) - .insert( - expressions, - vec![out.clone().annotate(ty.clone(), size).into()], - ); + let out = ArrayExpressionInner::Identifier(id); - out - } + self.call_cache_mut() + .entry(key.clone()) + .or_insert_with(|| HashMap::new()) + .insert( + expressions, + vec![out.clone().annotate(ty.clone().into(), size).into()], + ); + + out + } + InlineError::NonConstant(key, expressions) => { + ArrayExpressionInner::FunctionCall(key, expressions) + } + }, } } // default @@ -491,35 +585,57 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { fn fold_struct_expression_inner( &mut self, - ty: &StructType, + ty: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { match e { StructExpressionInner::FunctionCall(key, exps) => { let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); - match self.try_inline_call(&key, exps) { + match self.try_inline_call(key, exps) { Ok(mut ret) => match ret.pop().unwrap() { TypedExpression::Struct(e) => e.into_inner(), _ => unreachable!(), }, - Err((key, expressions)) => { - let tys = key.signature.outputs.clone(); - let id = Identifier { - id: CoreIdentifier::Call(key.clone()), - version: *self - .call_count - .get(&(self.module_id().clone(), key.clone())) - .unwrap(), - stack: self.stack.clone(), - }; - self.statement_buffer - .push(TypedStatement::MultipleDefinition( - vec![Variable::with_id_and_type(id.clone(), tys[0].clone())], - TypedExpressionList::FunctionCall(key, expressions, tys), - )); - StructExpressionInner::Identifier(id) - } + Err(e) => match e { + InlineError::Flat(key, expressions) => { + let key = ConcreteFunctionKey::try_from(key).unwrap(); + + let tys = key.signature.outputs.clone(); + let id = Identifier { + id: CoreIdentifier::Call(key.clone()), + version: *self + .call_count + .get(&(self.module_id().clone(), key.clone())) + .unwrap(), + stack: self.stack.clone(), + }; + self.statement_buffer + .push(TypedStatement::MultipleDefinition( + vec![Variable::with_id_and_type( + id.clone(), + tys[0].clone().into(), + )], + TypedExpressionList::FunctionCall( + key.into(), + expressions, + tys.into_iter().map(|t| t.into()).collect(), + ), + )); + + let out = StructExpressionInner::Identifier(id); + + self.call_cache_mut() + .entry(key.clone()) + .or_insert_with(|| HashMap::new()) + .insert(expressions, vec![out.clone().annotate(ty.clone()).into()]); + + out + } + InlineError::NonConstant(key, expressions) => { + StructExpressionInner::FunctionCall(key, expressions) + } + }, } } // default @@ -536,40 +652,50 @@ impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { UExpressionInner::FunctionCall(key, exps) => { let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); - match self.try_inline_call(&key, exps) { + match self.try_inline_call(key, exps) { Ok(mut ret) => match ret.pop().unwrap() { TypedExpression::Uint(e) => e.into_inner(), _ => unreachable!(), }, - Err((embed_key, expressions)) => { - let tys = key.signature.outputs.clone(); - let id = Identifier { - id: CoreIdentifier::Call(key.clone()), - version: *self - .call_count - .get(&(self.module_id().clone(), embed_key.clone())) - .unwrap(), - stack: self.stack.clone(), - }; - self.statement_buffer - .push(TypedStatement::MultipleDefinition( - vec![Variable::with_id_and_type(id.clone(), tys[0].clone())], - TypedExpressionList::FunctionCall( - embed_key.clone(), - expressions.clone(), - tys, - ), - )); + Err(e) => match e { + InlineError::Flat(embed_key, expressions) => { + let key = ConcreteFunctionKey::try_from(key).unwrap(); - let out = UExpressionInner::Identifier(id); + let tys = key.signature.outputs.clone(); + let id = Identifier { + id: CoreIdentifier::Call(key.clone()), + version: *self + .call_count + .get(&(self.module_id().clone(), key.clone().into())) + .unwrap(), + stack: self.stack.clone(), + }; + self.statement_buffer + .push(TypedStatement::MultipleDefinition( + vec![Variable::with_id_and_type( + id.clone(), + tys[0].clone().into(), + )], + TypedExpressionList::FunctionCall( + embed_key.clone().into(), + expressions.clone(), + tys.into_iter().map(|t| t.into()).collect(), + ), + )); - self.call_cache_mut() - .entry(key.clone()) - .or_insert_with(|| HashMap::new()) - .insert(expressions, vec![out.clone().annotate(size).into()]); + let out = UExpressionInner::Identifier(id); - out - } + self.call_cache_mut() + .entry(key.clone()) + .or_insert_with(|| HashMap::new()) + .insert(expressions, vec![out.clone().annotate(size).into()]); + + out + } + InlineError::NonConstant(key, expressions) => { + UExpressionInner::FunctionCall(key, expressions) + } + }, } } // default diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index ce769280..6925b334 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -17,9 +17,7 @@ mod unroll; mod variable_access_remover; use self::flatten_complex_types::Flattener; -use self::inline::Inliner; use self::propagate_unroll::PropagatedUnroller; -use self::propagation::Propagator; use self::redefinition::RedefinitionOptimizer; use self::return_binder::ReturnBinder; use self::uint_optimizer::UintOptimizer; @@ -27,7 +25,7 @@ use self::unconstrained_vars::UnconstrainedVariableDetector; use self::variable_access_remover::VariableAccessRemover; use crate::flat_absy::FlatProg; use crate::ir::Prog; -use crate::typed_absy::TypedProgram; +use crate::typed_absy::{abi::Abi, TypedProgram}; use zir::ZirProgram; use zokrates_field::Field; @@ -36,9 +34,12 @@ pub trait Analyse { } impl<'ast, T: Field> TypedProgram<'ast, T> { - pub fn analyse(self) -> ZirProgram<'ast, T> { + pub fn analyse(self) -> (ZirProgram<'ast, T>, Abi) { // propagated unrolling let r = PropagatedUnroller::unroll(self).unwrap_or_else(|e| panic!(e)); + + let abi = r.abi().unwrap(); + // return binding let r = ReturnBinder::bind(r); @@ -60,7 +61,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { // optimize uint expressions let zir = UintOptimizer::optimize(zir); - zir + (zir, abi) } } diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 6d6aa75b..2f18e8e1 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -139,7 +139,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { .collect(); fn process_u_from_bits<'ast, T: Field>( - variables: Vec>, + variables: Vec>, arguments: Vec>, bitwidth: UBitwidth, ) -> TypedExpression<'ast, T> { @@ -183,7 +183,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } fn process_u_to_bits<'ast, T: Field>( - variables: Vec>, + variables: Vec>, arguments: Vec>, bitwidth: UBitwidth, ) -> TypedExpression<'ast, T> { @@ -213,7 +213,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { .map(|v| BooleanExpression::Value(v).into()) .collect(), ) - .annotate(Type::Boolean, bitwidth.to_usize()) + .annotate(Type::Boolean, bitwidth.to_usize() as u32) .into() } _ => unreachable!("should be a uint value"), @@ -282,7 +282,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { }) .collect(), ) - .annotate(Type::Boolean, T::get_required_bits()) + .annotate( + Type::Boolean, + T::get_required_bits() as u32, + ) .into(), ) } @@ -505,45 +508,59 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } UExpressionInner::Select(box array, box index) => { let array = self.fold_array_expression(array); - let index = self.fold_field_expression(index); + let index = self.fold_uint_expression(index); let inner_type = array.inner_type().clone(); let size = array.size(); - match (array.into_inner(), index) { - (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => { - let n_as_usize = n.to_dec_string().parse::().unwrap(); - if n_as_usize < size { - UExpression::try_from(v[n_as_usize].clone()) - .unwrap() - .into_inner() - } else { - unreachable!( + match size.into_inner() { + UExpressionInner::Value(size) => { + match (array.into_inner(), index.into_inner()) { + (ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { + if n < size { + UExpression::try_from(v[n as usize].clone()) + .unwrap() + .into_inner() + } else { + unreachable!( "out of bounds index ({} >= {}) found during static analysis", - n_as_usize, size + n, size ); - } - } - (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => { - match self.constants.get(&TypedAssignee::Select( - box TypedAssignee::Identifier(Variable::array( - id.clone(), - inner_type.clone(), - size, - )), - box FieldElementExpression::Number(n.clone()).into(), - )) { - Some(e) => match e { - TypedExpression::Uint(e) => e.clone().into_inner(), - _ => unreachable!(""), - }, - None => UExpressionInner::Select( - box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), - box FieldElementExpression::Number(n), + } + } + (ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { + match self.constants.get(&TypedAssignee::Select( + box TypedAssignee::Identifier(Variable::array( + id.clone(), + inner_type.clone(), + (size as u32).into(), + )), + box UExpressionInner::Value(n.clone()) + .annotate(UBitwidth::B32) + .into(), + )) { + Some(e) => match e { + TypedExpression::Uint(e) => e.clone().into_inner(), + _ => unreachable!(""), + }, + None => UExpressionInner::Select( + box ArrayExpressionInner::Identifier(id) + .annotate(inner_type, size as u32), + box UExpressionInner::Value(n).annotate(UBitwidth::B32), + ), + } + } + (a, i) => UExpressionInner::Select( + box a.annotate(inner_type, size as u32), + box i.annotate(UBitwidth::B32), ), } } - (a, i) => UExpressionInner::Select(box a.annotate(inner_type, size), box i), + size => fold_uint_expression_inner( + self, + bitwidth, + UExpressionInner::Select(box array, box index), + ), } } UExpressionInner::FunctionCall(key, arguments) => { @@ -647,45 +664,54 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } FieldElementExpression::Select(box array, box index) => { let array = self.fold_array_expression(array); - let index = self.fold_field_expression(index); + let index = self.fold_uint_expression(index); let inner_type = array.inner_type().clone(); let size = array.size(); - match (array.into_inner(), index) { - (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => { - let n_as_usize = n.to_dec_string().parse::().unwrap(); - if n_as_usize < size { - FieldElementExpression::try_from(v[n_as_usize].clone()).unwrap() - } else { - unreachable!( + match size.into_inner() { + UExpressionInner::Value(size) => { + match (array.into_inner(), index.into_inner()) { + (ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { + if n < size { + FieldElementExpression::try_from(v[n as usize].clone()).unwrap() + } else { + unreachable!( "out of bounds index ({} >= {}) found during static analysis", - n_as_usize, size + n, size ); - } - } - (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => { - match self.constants.get(&TypedAssignee::Select( - box TypedAssignee::Identifier(Variable::array( - id.clone(), - inner_type.clone(), - size, - )), - box FieldElementExpression::Number(n.clone()).into(), - )) { - Some(e) => match e { - TypedExpression::FieldElement(e) => e.clone(), - _ => unreachable!("??"), - }, - None => FieldElementExpression::Select( - box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), - box FieldElementExpression::Number(n), + } + } + (ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { + match self.constants.get(&TypedAssignee::Select( + box TypedAssignee::Identifier(Variable::array( + id.clone(), + inner_type.clone(), + (size as u32).into(), + )), + box UExpressionInner::Value(n.clone()).annotate(UBitwidth::B32), + )) { + Some(e) => match e { + TypedExpression::FieldElement(e) => e.clone(), + _ => unreachable!("??"), + }, + None => FieldElementExpression::Select( + box ArrayExpressionInner::Identifier(id) + .annotate(inner_type, size as u32), + box UExpressionInner::Value(n).annotate(UBitwidth::B32), + ), + } + } + (a, i) => FieldElementExpression::Select( + box a.annotate(inner_type, size as u32), + box i.annotate(UBitwidth::B32), ), } } - (a, i) => { - FieldElementExpression::Select(box a.annotate(inner_type, size), box i) - } + size => fold_field_expression( + self, + FieldElementExpression::Select(box array, box index), + ), } } FieldElementExpression::Member(box s, m) => { @@ -725,8 +751,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { fn fold_array_expression_inner( &mut self, - ty: &Type, - size: usize, + ty: &Type<'ast, T>, + size: UExpression<'ast, T>, e: ArrayExpressionInner<'ast, T>, ) -> ArrayExpressionInner<'ast, T> { match e { @@ -747,45 +773,52 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } ArrayExpressionInner::Select(box array, box index) => { let array = self.fold_array_expression(array); - let index = self.fold_field_expression(index); + let index = self.fold_uint_expression(index); let inner_type = array.inner_type().clone(); let size = array.size(); - match (array.into_inner(), index) { - (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => { - let n_as_usize = n.to_dec_string().parse::().unwrap(); - if n_as_usize < size { - ArrayExpression::try_from(v[n_as_usize].clone()) - .unwrap() - .into_inner() - } else { - unreachable!( - "out of bounds index ({} >= {}) found during static analysis", - n_as_usize, size - ); + match size.into_inner() { + UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner()) + { + (ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { + if n < size { + ArrayExpression::try_from(v[n as usize].clone()) + .unwrap() + .into_inner() + } else { + unreachable!( + "out of bounds index ({} >= {}) found during static analysis", + n, size + ); + } } - } - (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => { - match self.constants.get(&TypedAssignee::Select( - box TypedAssignee::Identifier(Variable::array( - id.clone(), - inner_type.clone(), - size, - )), - box FieldElementExpression::Number(n.clone()).into(), - )) { - Some(e) => match e { - TypedExpression::Array(e) => e.clone().into_inner(), - _ => unreachable!("should be an array"), - }, - None => ArrayExpressionInner::Select( - box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), - box FieldElementExpression::Number(n), - ), + (ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { + match self.constants.get(&TypedAssignee::Select( + box TypedAssignee::Identifier(Variable::array( + id.clone(), + inner_type.clone(), + (size as u32).into(), + )), + box UExpressionInner::Value(n).annotate(UBitwidth::B32).into(), + )) { + Some(e) => match e { + TypedExpression::Array(e) => e.clone().into_inner(), + _ => unreachable!("should be an array"), + }, + None => ArrayExpressionInner::Select( + box ArrayExpressionInner::Identifier(id) + .annotate(inner_type, size as u32), + box (n as u32).into(), + ), + } } - } - (a, i) => ArrayExpressionInner::Select(box a.annotate(inner_type, size), box i), + (a, i) => ArrayExpressionInner::Select( + box a.annotate(inner_type, size as u32), + box i.annotate(UBitwidth::B32), + ), + }, + size => fold_array_expression_inner(self, ty, size.annotate(UBitwidth::B32), e), } } ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => { @@ -839,7 +872,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { fn fold_struct_expression_inner( &mut self, - ty: &StructType, + ty: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { match e { @@ -859,47 +892,58 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } StructExpressionInner::Select(box array, box index) => { let array = self.fold_array_expression(array); - let index = self.fold_field_expression(index); + let index = self.fold_uint_expression(index); let inner_type = array.inner_type().clone(); let size = array.size(); - match (array.into_inner(), index) { - (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => { - let n_as_usize = n.to_dec_string().parse::().unwrap(); - if n_as_usize < size { - StructExpression::try_from(v[n_as_usize].clone()) - .unwrap() - .into_inner() - } else { - unreachable!( - "out of bounds index ({} >= {}) found during static analysis", - n_as_usize, size - ); + match size.into_inner() { + UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner()) + { + (ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { + if n < size { + StructExpression::try_from(v[n as usize].clone()) + .unwrap() + .into_inner() + } else { + unreachable!( + "out of bounds index ({} >= {}) found during static analysis", + n, size + ); + } } - } - (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => { - match self.constants.get(&TypedAssignee::Select( - box TypedAssignee::Identifier(Variable::array( - id.clone(), - inner_type.clone(), - size, - )), - box FieldElementExpression::Number(n.clone()).into(), - )) { - Some(e) => match e { - TypedExpression::Struct(e) => e.clone().into_inner(), - _ => unreachable!("should be a struct"), - }, - None => StructExpressionInner::Select( - box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), - box FieldElementExpression::Number(n), - ), + (ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { + match self.constants.get(&TypedAssignee::Select( + box TypedAssignee::Identifier(Variable::array( + id.clone(), + inner_type.clone(), + (size as u32).into(), + )), + box UExpressionInner::Value(n.clone()) + .annotate(UBitwidth::B32) + .into(), + )) { + Some(e) => match e { + TypedExpression::Struct(e) => e.clone().into_inner(), + _ => unreachable!("should be a struct"), + }, + None => StructExpressionInner::Select( + box ArrayExpressionInner::Identifier(id) + .annotate(inner_type, size as u32), + box UExpressionInner::Value(n).annotate(UBitwidth::B32), + ), + } } - } - (a, i) => { - StructExpressionInner::Select(box a.annotate(inner_type, size), box i) - } + (a, i) => StructExpressionInner::Select( + box a.annotate(inner_type, size as u32), + box i.annotate(UBitwidth::B32), + ), + }, + size => fold_struct_expression_inner( + self, + ty, + StructExpressionInner::Select(box array, box index), + ), } } StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { @@ -1092,43 +1136,55 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } BooleanExpression::Select(box array, box index) => { let array = self.fold_array_expression(array); - let index = self.fold_field_expression(index); + let index = self.fold_uint_expression(index); let inner_type = array.inner_type().clone(); let size = array.size(); - match (array.into_inner(), index) { - (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => { - let n_as_usize = n.to_dec_string().parse::().unwrap(); - if n_as_usize < size { - BooleanExpression::try_from(v[n_as_usize].clone()).unwrap() - } else { - unreachable!( - "out of bounds index ({} >= {}) found during static analysis", - n_as_usize, size - ); + match size.into_inner() { + UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner()) + { + (ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { + if n < size { + BooleanExpression::try_from(v[n as usize].clone()).unwrap() + } else { + unreachable!( + "out of bounds index ({} >= {}) found during static analysis", + n, size + ); + } } - } - (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => { - match self.constants.get(&TypedAssignee::Select( - box TypedAssignee::Identifier(Variable::array( - id.clone(), - inner_type.clone(), - size, - )), - box FieldElementExpression::Number(n.clone()).into(), - )) { - Some(e) => match e { - TypedExpression::Boolean(e) => e.clone(), - _ => unreachable!("Should be a boolean"), - }, - None => BooleanExpression::Select( - box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), - box FieldElementExpression::Number(n), - ), + (ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { + match self.constants.get(&TypedAssignee::Select( + box TypedAssignee::Identifier(Variable::array( + id.clone(), + inner_type.clone(), + (size as u32).into(), + )), + box UExpressionInner::Value(n.clone()) + .annotate(UBitwidth::B32) + .into(), + )) { + Some(e) => match e { + TypedExpression::Boolean(e) => e.clone(), + _ => unreachable!("Should be a boolean"), + }, + None => BooleanExpression::Select( + box ArrayExpressionInner::Identifier(id) + .annotate(inner_type, size as u32), + box UExpressionInner::Value(n.clone()).annotate(UBitwidth::B32), + ), + } } - } - (a, i) => BooleanExpression::Select(box a.annotate(inner_type, size), box i), + (a, i) => BooleanExpression::Select( + box a.annotate(inner_type, size as u32), + box i.annotate(UBitwidth::B32), + ), + }, + size => fold_boolean_expression( + self, + BooleanExpression::Select(box array, box index), + ), } } BooleanExpression::Member(box s, m) => { diff --git a/zokrates_core/src/static_analysis/unroll.rs b/zokrates_core/src/static_analysis/unroll.rs index 54c3ad8b..17852dbb 100644 --- a/zokrates_core/src/static_analysis/unroll.rs +++ b/zokrates_core/src/static_analysis/unroll.rs @@ -31,7 +31,7 @@ impl<'ast> Unroller<'ast> { } } - fn issue_next_ssa_variable(&mut self, v: Variable<'ast>) -> Variable<'ast> { + fn issue_next_ssa_variable(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> { let res = match self.substitution.get(&v.id.id) { Some(i) => Variable { id: Identifier { @@ -79,162 +79,164 @@ impl<'ast> Unroller<'ast> { let head = indices.remove(0); let tail = indices; - match head { - Access::Select(head) => { - statements.insert(TypedStatement::Assertion( - BooleanExpression::Lt( - box head.clone(), - box FieldElementExpression::Number(T::from(size)), - ) - .into(), - )); + unimplemented!() - ArrayExpressionInner::Value( - (0..size) - .map(|i| match inner_ty { - Type::Array(..) => ArrayExpression::if_else( - BooleanExpression::FieldEq( - box FieldElementExpression::Number(T::from(i)), - box head.clone(), - ), - match Self::choose_many( - ArrayExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - tail.clone(), - new_expression.clone(), - statements, - ) { - TypedExpression::Array(e) => e, - e => unreachable!( - "the interior was expected to be an array, was {}", - e.get_type() - ), - }, - ArrayExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ), - ) - .into(), - Type::Struct(..) => StructExpression::if_else( - BooleanExpression::FieldEq( - box FieldElementExpression::Number(T::from(i)), - box head.clone(), - ), - match Self::choose_many( - StructExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - tail.clone(), - new_expression.clone(), - statements, - ) { - TypedExpression::Struct(e) => e, - e => unreachable!( - "the interior was expected to be a struct, was {}", - e.get_type() - ), - }, - StructExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ), - ) - .into(), - Type::FieldElement => FieldElementExpression::if_else( - BooleanExpression::FieldEq( - box FieldElementExpression::Number(T::from(i)), - box head.clone(), - ), - match Self::choose_many( - FieldElementExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - tail.clone(), - new_expression.clone(), - statements, - ) { - TypedExpression::FieldElement(e) => e, - e => unreachable!( - "the interior was expected to be a field, was {}", - e.get_type() - ), - }, - FieldElementExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ), - ) - .into(), - Type::Boolean => BooleanExpression::if_else( - BooleanExpression::FieldEq( - box FieldElementExpression::Number(T::from(i)), - box head.clone(), - ), - match Self::choose_many( - BooleanExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - tail.clone(), - new_expression.clone(), - statements, - ) { - TypedExpression::Boolean(e) => e, - e => unreachable!( - "the interior was expected to be a boolean, was {}", - e.get_type() - ), - }, - BooleanExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ), - ) - .into(), - Type::Uint(..) => UExpression::if_else( - BooleanExpression::FieldEq( - box FieldElementExpression::Number(T::from(i)), - box head.clone(), - ), - match Self::choose_many( - UExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - tail.clone(), - new_expression.clone(), - statements, - ) { - TypedExpression::Uint(e) => e, - e => unreachable!( - "the interior was expected to be a uint, was {}", - e.get_type() - ), - }, - UExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ), - ) - .into(), - }) - .collect(), - ) - .annotate(inner_ty.clone(), size) - .into() - } - Access::Member(..) => unreachable!("can't get a member from an array"), - } + // match head { + // Access::Select(head) => { + // statements.insert(TypedStatement::Assertion( + // BooleanExpression::Lt( + // box head.clone(), + // box FieldElementExpression::Number(T::from(size)), + // ) + // .into(), + // )); + + // ArrayExpressionInner::Value( + // (0..size) + // .map(|i| match inner_ty { + // Type::Array(..) => ArrayExpression::if_else( + // BooleanExpression::FieldEq( + // box FieldElementExpression::Number(T::from(i)), + // box head.clone(), + // ), + // match Self::choose_many( + // ArrayExpression::select( + // base.clone(), + // FieldElementExpression::Number(T::from(i)), + // ) + // .into(), + // tail.clone(), + // new_expression.clone(), + // statements, + // ) { + // TypedExpression::Array(e) => e, + // e => unreachable!( + // "the interior was expected to be an array, was {}", + // e.get_type() + // ), + // }, + // ArrayExpression::select( + // base.clone(), + // FieldElementExpression::Number(T::from(i)), + // ), + // ) + // .into(), + // Type::Struct(..) => StructExpression::if_else( + // BooleanExpression::FieldEq( + // box FieldElementExpression::Number(T::from(i)), + // box head.clone(), + // ), + // match Self::choose_many( + // StructExpression::select( + // base.clone(), + // FieldElementExpression::Number(T::from(i)), + // ) + // .into(), + // tail.clone(), + // new_expression.clone(), + // statements, + // ) { + // TypedExpression::Struct(e) => e, + // e => unreachable!( + // "the interior was expected to be a struct, was {}", + // e.get_type() + // ), + // }, + // StructExpression::select( + // base.clone(), + // FieldElementExpression::Number(T::from(i)), + // ), + // ) + // .into(), + // Type::FieldElement => FieldElementExpression::if_else( + // BooleanExpression::FieldEq( + // box FieldElementExpression::Number(T::from(i)), + // box head.clone(), + // ), + // match Self::choose_many( + // FieldElementExpression::select( + // base.clone(), + // FieldElementExpression::Number(T::from(i)), + // ) + // .into(), + // tail.clone(), + // new_expression.clone(), + // statements, + // ) { + // TypedExpression::FieldElement(e) => e, + // e => unreachable!( + // "the interior was expected to be a field, was {}", + // e.get_type() + // ), + // }, + // FieldElementExpression::select( + // base.clone(), + // FieldElementExpression::Number(T::from(i)), + // ), + // ) + // .into(), + // Type::Boolean => BooleanExpression::if_else( + // BooleanExpression::FieldEq( + // box FieldElementExpression::Number(T::from(i)), + // box head.clone(), + // ), + // match Self::choose_many( + // BooleanExpression::select( + // base.clone(), + // FieldElementExpression::Number(T::from(i)), + // ) + // .into(), + // tail.clone(), + // new_expression.clone(), + // statements, + // ) { + // TypedExpression::Boolean(e) => e, + // e => unreachable!( + // "the interior was expected to be a boolean, was {}", + // e.get_type() + // ), + // }, + // BooleanExpression::select( + // base.clone(), + // FieldElementExpression::Number(T::from(i)), + // ), + // ) + // .into(), + // Type::Uint(..) => UExpression::if_else( + // BooleanExpression::FieldEq( + // box FieldElementExpression::Number(T::from(i)), + // box head.clone(), + // ), + // match Self::choose_many( + // UExpression::select( + // base.clone(), + // FieldElementExpression::Number(T::from(i)), + // ) + // .into(), + // tail.clone(), + // new_expression.clone(), + // statements, + // ) { + // TypedExpression::Uint(e) => e, + // e => unreachable!( + // "the interior was expected to be a uint, was {}", + // e.get_type() + // ), + // }, + // UExpression::select( + // base.clone(), + // FieldElementExpression::Number(T::from(i)), + // ), + // ) + // .into(), + // }) + // .collect(), + // ) + // .annotate(inner_ty.clone(), size) + // .into() + // } + // Access::Member(..) => unreachable!("can't get a member from an array"), + // } } TypedExpression::Struct(base) => { let members = match base.get_type() { @@ -355,12 +357,12 @@ impl<'ast> Unroller<'ast> { #[derive(Clone, Debug)] enum Access<'ast, T: Field> { - Select(FieldElementExpression<'ast, T>), + Select(UExpression<'ast, T>), Member(MemberId), } /// Turn an assignee into its representation as a base variable and a list accesses /// a[2][3][4] -> (a, [2, 3, 4]) -fn linear<'ast, T: Field>(a: TypedAssignee<'ast, T>) -> (Variable, Vec>) { +fn linear<'ast, T: Field>(a: TypedAssignee<'ast, T>) -> (Variable<'ast, T>, Vec>) { match a { TypedAssignee::Identifier(v) => (v, vec![]), TypedAssignee::Select(box array, box index) => { @@ -415,7 +417,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> { let indices = indices .into_iter() .map(|a| match a { - Access::Select(i) => Access::Select(self.fold_field_expression(i)), + Access::Select(i) => Access::Select(self.fold_uint_expression(i)), a => a, }) .collect(); diff --git a/zokrates_core/src/static_analysis/variable_access_remover.rs b/zokrates_core/src/static_analysis/variable_access_remover.rs index 991740b1..246e5a71 100644 --- a/zokrates_core/src/static_analysis/variable_access_remover.rs +++ b/zokrates_core/src/static_analysis/variable_access_remover.rs @@ -29,22 +29,27 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> { fn select + IfElse<'ast, T>>( &mut self, a: ArrayExpression<'ast, T>, - i: FieldElementExpression<'ast, T>, + i: UExpression<'ast, T>, ) -> U { - match i { - FieldElementExpression::Number(i) => U::select(a, FieldElementExpression::Number(i)), + match i.into_inner() { + UExpressionInner::Value(i) => { + U::select(a, UExpressionInner::Value(i).annotate(UBitwidth::B32)) + } i => { let size = match a.get_type().clone() { - Type::Array(array_ty) => array_ty.size, + Type::Array(array_ty) => match array_ty.size.into_inner() { + UExpressionInner::Value(size) => size as u32, + _ => unreachable!(), + }, _ => unreachable!(), }; self.statements.push(TypedStatement::Assertion( (0..size) .map(|index| { - BooleanExpression::FieldEq( - box i.clone(), - box FieldElementExpression::Number(index.into()).into(), + BooleanExpression::UintEq( + box i.clone().annotate(UBitwidth::B32), + box index.into(), ) }) .fold(None, |acc, e| match acc { @@ -56,14 +61,19 @@ impl<'ast, T: Field> VariableAccessRemover<'ast, T> { )); (0..size) - .map(|i| U::select(a.clone(), FieldElementExpression::Number(i.into()))) + .map(|i| { + U::select( + a.clone(), + UExpressionInner::Value(i.into()).annotate(UBitwidth::B32), + ) + }) .enumerate() .rev() .fold(None, |acc, (index, res)| match acc { Some(acc) => Some(U::if_else( - BooleanExpression::FieldEq( - box i.clone(), - box FieldElementExpression::Number(index.into()), + BooleanExpression::UintEq( + box i.clone().annotate(UBitwidth::B32), + box (index as u32).into(), ), res, acc, @@ -99,8 +109,8 @@ impl<'ast, T: Field> Folder<'ast, T> for VariableAccessRemover<'ast, T> { fn fold_array_expression_inner( &mut self, - ty: &Type, - size: usize, + ty: &Type<'ast, T>, + size: UExpression<'ast, T>, e: ArrayExpressionInner<'ast, T>, ) -> ArrayExpressionInner<'ast, T> { match e { @@ -113,7 +123,7 @@ impl<'ast, T: Field> Folder<'ast, T> for VariableAccessRemover<'ast, T> { fn fold_struct_expression_inner( &mut self, - ty: &StructType, + ty: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { match e { diff --git a/zokrates_core/src/typed_absy/abi.rs b/zokrates_core/src/typed_absy/abi.rs index 206e802a..22e300e9 100644 --- a/zokrates_core/src/typed_absy/abi.rs +++ b/zokrates_core/src/typed_absy/abi.rs @@ -1,15 +1,16 @@ +use typed_absy::types::ConcreteSignature; +use typed_absy::types::ConcreteType; use typed_absy::types::Signature; -use typed_absy::Type; #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct AbiInput { pub name: String, pub public: bool, #[serde(flatten)] - pub ty: Type, + pub ty: ConcreteType, } -pub type AbiOutput = Type; +pub type AbiOutput = ConcreteType; #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct Abi { @@ -18,7 +19,7 @@ pub struct Abi { } impl Abi { - pub fn signature(&self) -> Signature { + pub fn signature(&self) -> ConcreteSignature { Signature { inputs: self.inputs.iter().map(|i| i.ty.clone()).collect(), outputs: self.outputs.clone(), diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 385c4dce..e751aff8 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -23,9 +23,9 @@ pub trait Folder<'ast, T: Field>: Sized { fold_function(self, f) } - fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { + fn fold_parameter(&mut self, p: DeclarationParameter<'ast>) -> DeclarationParameter<'ast> { Parameter { - id: self.fold_variable(p.id), + id: self.fold_declaration_variable(p.id), ..p } } @@ -34,19 +34,34 @@ pub trait Folder<'ast, T: Field>: Sized { n } - fn fold_variable(&mut self, v: Variable<'ast>) -> Variable<'ast> { + fn fold_variable(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> { Variable { id: self.fold_name(v.id), - ..v + ty: self.fold_type(v.ty), } } + fn fold_declaration_variable(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> { + DeclarationVariable { + id: self.fold_name(v.id), + ty: self.fold_declaration_type(v.ty), + } + } + + fn fold_type(&mut self, t: Type) -> Type { + unimplemented!() + } + + fn fold_declaration_type(&mut self, t: DeclarationType) -> DeclarationType { + unimplemented!() + } + fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { match a { TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)), TypedAssignee::Select(box a, box index) => TypedAssignee::Select( box self.fold_assignee(a), - box self.fold_field_expression(index), + box self.fold_uint_expression(index), ), TypedAssignee::Member(box s, m) => TypedAssignee::Member(box self.fold_assignee(s), m), } @@ -121,15 +136,15 @@ pub trait Folder<'ast, T: Field>: Sized { fn fold_array_expression_inner( &mut self, - ty: &Type, - size: usize, + ty: &Type<'ast, T>, + size: UExpression<'ast, T>, e: ArrayExpressionInner<'ast, T>, ) -> ArrayExpressionInner<'ast, T> { fold_array_expression_inner(self, ty, size, e) } fn fold_struct_expression_inner( &mut self, - ty: &StructType, + ty: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { fold_struct_expression_inner(self, ty, e) @@ -185,8 +200,8 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, - _: &Type, - _: usize, + _: &Type<'ast, T>, + _: UExpression<'ast, T>, e: ArrayExpressionInner<'ast, T>, ) -> ArrayExpressionInner<'ast, T> { match e { @@ -211,7 +226,7 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } ArrayExpressionInner::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); ArrayExpressionInner::Select(box array, box index) } } @@ -219,7 +234,7 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, - _: &StructType, + _: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { match e { @@ -244,7 +259,7 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } StructExpressionInner::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); StructExpressionInner::Select(box array, box index) } } @@ -300,7 +315,7 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( } FieldElementExpression::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); FieldElementExpression::Select(box array, box index) } } @@ -388,7 +403,7 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( } BooleanExpression::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); BooleanExpression::Select(box array, box index) } } @@ -471,7 +486,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } UExpressionInner::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); UExpressionInner::Select(box array, box index) } UExpressionInner::IfElse(box cond, box cons, box alt) => { @@ -510,8 +525,11 @@ pub fn fold_array_expression<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: ArrayExpression<'ast, T>, ) -> ArrayExpression<'ast, T> { + let size = f.fold_uint_expression(e.size); + ArrayExpression { - inner: f.fold_array_expression_inner(&e.ty, e.size, e.inner), + inner: f.fold_array_expression_inner(&e.ty, size, e.inner), + size, ..e } } diff --git a/zokrates_core/src/typed_absy/identifier.rs b/zokrates_core/src/typed_absy/identifier.rs index a22defd2..dda48136 100644 --- a/zokrates_core/src/typed_absy/identifier.rs +++ b/zokrates_core/src/typed_absy/identifier.rs @@ -1,12 +1,12 @@ use std::fmt; -use typed_absy::types::FunctionKey; +use typed_absy::types::ConcreteFunctionKey; use typed_absy::TypedModuleId; #[derive(Debug, PartialEq, Clone, Hash, Eq)] pub enum CoreIdentifier<'ast> { Source(&'ast str), Internal(&'static str, usize), - Call(FunctionKey<'ast>), + Call(ConcreteFunctionKey<'ast>), } impl<'ast> fmt::Display for CoreIdentifier<'ast> { @@ -27,7 +27,7 @@ pub struct Identifier<'ast> { /// the version of the variable, used after SSA transformation pub version: usize, /// the call stack of the variable, used when inlining - pub stack: Vec<(TypedModuleId, FunctionKey<'ast>, usize)>, + pub stack: Vec<(TypedModuleId, ConcreteFunctionKey<'ast>, usize)>, } impl<'ast> fmt::Display for Identifier<'ast> { @@ -78,7 +78,7 @@ impl<'ast> Identifier<'ast> { self } - pub fn stack(mut self, stack: Vec<(TypedModuleId, FunctionKey<'ast>, usize)>) -> Self { + pub fn stack(mut self, stack: Vec<(TypedModuleId, ConcreteFunctionKey<'ast>, usize)>) -> Self { self.stack = stack; self } diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 04484feb..956ab0a5 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -15,16 +15,18 @@ mod uint; mod variable; pub use self::identifier::CoreIdentifier; -pub use self::parameter::Parameter; -pub use self::types::{Signature, StructType, Type, UBitwidth}; -pub use self::variable::Variable; +pub use self::parameter::{DeclarationParameter, GParameter, Parameter}; +use self::types::{DeclarationFunctionKey, DeclarationSignature, GSignature, GStructType, GType}; +pub use self::types::{DeclarationType, Signature, StructType, Type, UBitwidth}; + +pub use self::variable::{DeclarationVariable, GVariable, Variable}; use std::path::PathBuf; pub use typed_absy::uint::{bitwidth, UExpression, UExpressionInner, UMetadata}; use crate::typed_absy::types::{FunctionKey, MemberId}; use embed::FlatEmbed; use std::collections::HashMap; -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use std::fmt; use zokrates_field::Field; @@ -43,7 +45,8 @@ pub type TypedModules<'ast, T> = HashMap>; /// # Remarks /// * It is the role of the semantic checker to make sure there are no duplicates for a given `FunctionKey` /// in a given `TypedModule`, hence the use of a HashMap -pub type TypedFunctionSymbols<'ast, T> = HashMap, TypedFunctionSymbol<'ast, T>>; +pub type TypedFunctionSymbols<'ast, T> = + HashMap, TypedFunctionSymbol<'ast, T>>; /// A typed program as a collection of modules, one of them being the main #[derive(PartialEq, Debug, Clone)] @@ -53,7 +56,7 @@ pub struct TypedProgram<'ast, T> { } impl<'ast, T: Field> TypedProgram<'ast, T> { - pub fn abi(&self) -> Abi { + pub fn abi(&self) -> Result { let main = self.modules[&self.main] .functions .iter() @@ -65,18 +68,25 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { _ => unreachable!(), }; - Abi { + Ok(Abi { inputs: main .arguments .iter() - .map(|p| AbiInput { - public: !p.private, - name: p.id.id.to_string(), - ty: p.id._type.clone(), + .map(|p| { + types::ConcreteType::try_from(p.id._type.clone()).map(|ty| AbiInput { + public: !p.private, + name: p.id.id.to_string(), + ty, + }) }) - .collect(), - outputs: main.signature.outputs.clone(), - } + .collect::>()?, + outputs: main + .signature + .outputs + .iter() + .map(|ty| types::ConcreteType::try_from(ty.clone())) + .collect::>()?, + }) } } @@ -112,7 +122,7 @@ pub struct TypedModule<'ast, T> { #[derive(Clone, PartialEq)] pub enum TypedFunctionSymbol<'ast, T> { Here(TypedFunction<'ast, T>), - There(FunctionKey<'ast>, TypedModuleId), + There(DeclarationFunctionKey<'ast>, TypedModuleId), Flat(FlatEmbed), } @@ -128,7 +138,10 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunctionSymbol<'ast, T> { } impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> { - pub fn signature<'a>(&'a self, modules: &'a TypedModules) -> Signature { + pub fn signature<'a>( + &'a self, + modules: &'a TypedModules<'ast, T>, + ) -> DeclarationSignature<'ast> { match self { TypedFunctionSymbol::Here(f) => f.signature.clone(), TypedFunctionSymbol::There(key, module_id) => modules @@ -139,7 +152,7 @@ impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> { .unwrap() .signature(&modules) .clone(), - TypedFunctionSymbol::Flat(flat_fun) => flat_fun.signature(), + TypedFunctionSymbol::Flat(flat_fun) => flat_fun.signature().try_into().unwrap(), } } } @@ -185,11 +198,11 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedModule<'ast, T> { #[derive(Clone, PartialEq)] pub struct TypedFunction<'ast, T> { /// Arguments of the function - pub arguments: Vec>, + pub arguments: Vec>, /// Vector of statements that are executed when running the function pub statements: Vec>, /// function signature - pub signature: Signature, + pub signature: DeclarationSignature<'ast>, } impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> { @@ -251,16 +264,13 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunction<'ast, T> { /// Something we can assign to. #[derive(Clone, PartialEq, Hash, Eq)] pub enum TypedAssignee<'ast, T> { - Identifier(Variable<'ast>), - Select( - Box>, - Box>, - ), + Identifier(Variable<'ast, T>), + Select(Box>, Box>), Member(Box>, MemberId), } -impl<'ast, T> Typed for TypedAssignee<'ast, T> { - fn get_type(&self) -> Type { +impl<'ast, T> Typed<'ast, T> for TypedAssignee<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { match *self { TypedAssignee::Identifier(ref v) => v.get_type(), TypedAssignee::Select(ref a, _) => { @@ -311,15 +321,15 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedAssignee<'ast, T> { pub enum TypedStatement<'ast, T> { Return(Vec>), Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>), - Declaration(Variable<'ast>), + Declaration(Variable<'ast, T>), Assertion(BooleanExpression<'ast, T>), For( - Variable<'ast>, + Variable<'ast, T>, UExpression<'ast, T>, UExpression<'ast, T>, Vec>, ), - MultipleDefinition(Vec>, TypedExpressionList<'ast, T>), + MultipleDefinition(Vec>, TypedExpressionList<'ast, T>), } impl<'ast, T: fmt::Debug> fmt::Debug for TypedStatement<'ast, T> { @@ -407,8 +417,8 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { } } -pub trait Typed { - fn get_type(&self) -> Type; +pub trait Typed<'ast, T> { + fn get_type(&self) -> Type<'ast, T>; } /// A typed expression @@ -531,8 +541,8 @@ impl<'ast, T: fmt::Debug> fmt::Debug for StructExpression<'ast, T> { } } -impl<'ast, T> Typed for TypedExpression<'ast, T> { - fn get_type(&self) -> Type { +impl<'ast, T> Typed<'ast, T> for TypedExpression<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { match *self { TypedExpression::Boolean(ref e) => e.get_type(), TypedExpression::FieldElement(ref e) => e.get_type(), @@ -543,47 +553,51 @@ impl<'ast, T> Typed for TypedExpression<'ast, T> { } } -impl<'ast, T> Typed for ArrayExpression<'ast, T> { - fn get_type(&self) -> Type { +impl<'ast, T> Typed<'ast, T> for ArrayExpression<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { Type::array(self.ty.clone(), self.size) } } -impl<'ast, T> Typed for StructExpression<'ast, T> { - fn get_type(&self) -> Type { +impl<'ast, T> Typed<'ast, T> for StructExpression<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { Type::Struct(self.ty.clone()) } } -impl<'ast, T> Typed for FieldElementExpression<'ast, T> { - fn get_type(&self) -> Type { +impl<'ast, T> Typed<'ast, T> for FieldElementExpression<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { Type::FieldElement } } -impl<'ast, T> Typed for UExpression<'ast, T> { - fn get_type(&self) -> Type { +impl<'ast, T> Typed<'ast, T> for UExpression<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { Type::Uint(self.bitwidth) } } -impl<'ast, T> Typed for BooleanExpression<'ast, T> { - fn get_type(&self) -> Type { +impl<'ast, T> Typed<'ast, T> for BooleanExpression<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { Type::Boolean } } -pub trait MultiTyped { - fn get_types(&self) -> &Vec; +pub trait MultiTyped<'ast, T> { + fn get_types(&self) -> &Vec>; } #[derive(Clone, PartialEq, Hash, Eq)] pub enum TypedExpressionList<'ast, T> { - FunctionCall(FunctionKey<'ast>, Vec>, Vec), + FunctionCall( + FunctionKey<'ast, T>, + Vec>, + Vec>, + ), } -impl<'ast, T> MultiTyped for TypedExpressionList<'ast, T> { - fn get_types(&self) -> &Vec { +impl<'ast, T> MultiTyped<'ast, T> for TypedExpressionList<'ast, T> { + fn get_types(&self) -> &Vec> { match *self { TypedExpressionList::FunctionCall(_, _, ref types) => types, } @@ -620,12 +634,9 @@ pub enum FieldElementExpression<'ast, T> { Box>, Box>, ), - FunctionCall(FunctionKey<'ast>, Vec>), + FunctionCall(FunctionKey<'ast, T>, Vec>), Member(Box>, MemberId), - Select( - Box>, - Box>, - ), + Select(Box>, Box>), } /// An expression of type `bool` @@ -678,11 +689,8 @@ pub enum BooleanExpression<'ast, T> { Box>, ), Member(Box>, MemberId), - FunctionCall(FunctionKey<'ast>, Vec>), - Select( - Box>, - Box>, - ), + FunctionCall(FunctionKey<'ast, T>, Vec>), + Select(Box>, Box>), } /// An expression of type `array` @@ -692,8 +700,8 @@ pub enum BooleanExpression<'ast, T> { /// type checking #[derive(Clone, PartialEq, Hash, Eq)] pub struct ArrayExpression<'ast, T> { - size: usize, - ty: Type, + size: UExpression<'ast, T>, + ty: Type<'ast, T>, inner: ArrayExpressionInner<'ast, T>, } @@ -701,23 +709,24 @@ pub struct ArrayExpression<'ast, T> { pub enum ArrayExpressionInner<'ast, T> { Identifier(Identifier<'ast>), Value(Vec>), - FunctionCall(FunctionKey<'ast>, Vec>), + FunctionCall(FunctionKey<'ast, T>, Vec>), IfElse( Box>, Box>, Box>, ), Member(Box>, MemberId), - Select( - Box>, - Box>, - ), + Select(Box>, Box>), } impl<'ast, T> ArrayExpressionInner<'ast, T> { - pub fn annotate(self, ty: Type, size: usize) -> ArrayExpression<'ast, T> { + pub fn annotate>>( + self, + ty: Type<'ast, T>, + size: S, + ) -> ArrayExpression<'ast, T> { ArrayExpression { - size, + size: size.into(), ty, inner: self, } @@ -725,11 +734,11 @@ impl<'ast, T> ArrayExpressionInner<'ast, T> { } impl<'ast, T> ArrayExpression<'ast, T> { - pub fn inner_type(&self) -> &Type { + pub fn inner_type(&self) -> &Type<'ast, T> { &self.ty } - pub fn size(&self) -> usize { + pub fn size(&self) -> UExpression<'ast, T> { self.size } @@ -744,12 +753,12 @@ impl<'ast, T> ArrayExpression<'ast, T> { #[derive(Clone, PartialEq, Hash, Eq)] pub struct StructExpression<'ast, T> { - ty: StructType, + ty: StructType<'ast, T>, inner: StructExpressionInner<'ast, T>, } impl<'ast, T> StructExpression<'ast, T> { - pub fn ty(&self) -> &StructType { + pub fn ty(&self) -> &StructType<'ast, T> { &self.ty } @@ -766,21 +775,18 @@ impl<'ast, T> StructExpression<'ast, T> { pub enum StructExpressionInner<'ast, T> { Identifier(Identifier<'ast>), Value(Vec>), - FunctionCall(FunctionKey<'ast>, Vec>), + FunctionCall(FunctionKey<'ast, T>, Vec>), IfElse( Box>, Box>, Box>, ), Member(Box>, MemberId), - Select( - Box>, - Box>, - ), + Select(Box>, Box>), } impl<'ast, T> StructExpressionInner<'ast, T> { - pub fn annotate(self, ty: StructType) -> StructExpression<'ast, T> { + pub fn annotate(self, ty: StructType<'ast, T>) -> StructExpression<'ast, T> { StructExpression { ty, inner: self } } } @@ -1212,23 +1218,23 @@ impl<'ast, T> IfElse<'ast, T> for StructExpression<'ast, T> { } pub trait Select<'ast, T> { - fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self; + fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self; } impl<'ast, T> Select<'ast, T> for FieldElementExpression<'ast, T> { - fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { + fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self { FieldElementExpression::Select(box array, box index) } } impl<'ast, T> Select<'ast, T> for BooleanExpression<'ast, T> { - fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { + fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self { BooleanExpression::Select(box array, box index) } } impl<'ast, T> Select<'ast, T> for UExpression<'ast, T> { - fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { + fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self { let bitwidth = match array.inner_type().clone() { Type::Uint(bitwidth) => bitwidth, _ => unreachable!(), @@ -1239,7 +1245,7 @@ impl<'ast, T> Select<'ast, T> for UExpression<'ast, T> { } impl<'ast, T> Select<'ast, T> for ArrayExpression<'ast, T> { - fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { + fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self { let (ty, size) = match array.inner_type() { Type::Array(array_type) => (array_type.ty.clone(), array_type.size.clone()), _ => unreachable!(), @@ -1250,7 +1256,7 @@ impl<'ast, T> Select<'ast, T> for ArrayExpression<'ast, T> { } impl<'ast, T> Select<'ast, T> for StructExpression<'ast, T> { - fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { + fn select(array: ArrayExpression<'ast, T>, index: UExpression<'ast, T>) -> Self { let members = match array.inner_type().clone() { Type::Struct(members) => members, _ => unreachable!(), diff --git a/zokrates_core/src/typed_absy/parameter.rs b/zokrates_core/src/typed_absy/parameter.rs index 277ac537..c1d45009 100644 --- a/zokrates_core/src/typed_absy/parameter.rs +++ b/zokrates_core/src/typed_absy/parameter.rs @@ -1,15 +1,18 @@ -use crate::typed_absy::Variable; +use crate::typed_absy::GVariable; use std::fmt; +use typed_absy::types::Constant; +use typed_absy::TryFrom; +use typed_absy::UExpression; #[derive(Clone, PartialEq)] -pub struct Parameter<'ast> { - pub id: Variable<'ast>, +pub struct GParameter<'ast, S> { + pub id: GVariable<'ast, S>, pub private: bool, } -impl<'ast> Parameter<'ast> { +impl<'ast, S> GParameter<'ast, S> { #[cfg(test)] - pub fn private(v: Variable<'ast>) -> Self { + pub fn private(v: GVariable<'ast, S>) -> Self { Parameter { id: v, private: true, @@ -17,14 +20,32 @@ impl<'ast> Parameter<'ast> { } } -impl<'ast> fmt::Display for Parameter<'ast> { +pub type DeclarationParameter<'ast> = GParameter<'ast, Constant<'ast>>; +pub type ConcreteParameter<'ast> = GParameter<'ast, usize>; +pub type Parameter<'ast, T> = GParameter<'ast, UExpression<'ast, T>>; + +impl<'ast, T> TryFrom> for ConcreteParameter<'ast> { + type Error = (); + + fn try_from(t: Parameter<'ast, T>) -> Result { + unimplemented!() + } +} + +impl<'ast, T> From> for Parameter<'ast, T> { + fn from(t: ConcreteParameter<'ast>) -> Self { + unimplemented!() + } +} + +impl<'ast, S> fmt::Display for GParameter<'ast, S> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let visibility = if self.private { "private " } else { "" }; write!(f, "{}{} {}", visibility, self.id.get_type(), self.id.id) } } -impl<'ast> fmt::Debug for Parameter<'ast> { +impl<'ast, S> fmt::Debug for GParameter<'ast, S> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Parameter(variable: {:?})", self.id) } diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index 75e597e5..ced286fe 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -1,49 +1,119 @@ +use std::cmp::Ordering; use std::fmt; use std::path::PathBuf; +use typed_absy::TryFrom; use typed_absy::UExpression; pub type Identifier<'ast> = &'ast str; +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Constant<'ast> { + Generic(Identifier<'ast>), + Concrete(u32), +} + +impl<'ast> From for Constant<'ast> { + fn from(e: u32) -> Self { + Constant::Concrete(e) + } +} + +impl<'ast> From> for Constant<'ast> { + fn from(e: Identifier<'ast>) -> Self { + Constant::Generic(e) + } +} + pub type MemberId = String; #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] -pub struct StructMember { +pub struct GStructMember { #[serde(rename = "name")] pub id: MemberId, #[serde(flatten)] - pub ty: Box, + pub ty: Box>, +} + +pub type DeclarationStructMember<'ast> = GStructMember>; +pub type ConcreteStructMember = GStructMember; +pub type StructMember<'ast, T> = GStructMember>; + +impl<'ast, T> TryFrom> for ConcreteStructMember { + type Error = (); + + fn try_from(t: StructMember<'ast, T>) -> Result { + unimplemented!() + } +} + +impl<'ast, T> From for StructMember<'ast, T> { + fn from(t: ConcreteStructMember) -> Self { + unimplemented!() + } } #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] pub struct GArrayType { pub size: S, #[serde(flatten)] - pub ty: Box, + pub ty: Box>, } -pub type ArrayType = GArrayType; +pub type DeclarationArrayType<'ast> = GArrayType>; +pub type ConcreteArrayType = GArrayType; +pub type ArrayType<'ast, T> = GArrayType>; -pub type UArrayType<'ast, T> = GArrayType>; +impl<'ast, T> TryFrom> for ConcreteArrayType { + type Error = (); + + fn try_from(t: ArrayType<'ast, T>) -> Result { + unimplemented!() + } +} + +impl<'ast, T> From for ArrayType<'ast, T> { + fn from(t: ConcreteArrayType) -> Self { + unimplemented!() + } +} #[derive(Clone, Hash, Serialize, Deserialize, PartialOrd, Ord)] -pub struct StructType { +pub struct GStructType { #[serde(skip)] pub module: PathBuf, pub name: String, - pub members: Vec, + pub members: Vec>, } -impl PartialEq for StructType { +pub type DeclarationStructType<'ast> = GStructType>; +pub type ConcreteStructType = GStructType; +pub type StructType<'ast, T> = GStructType>; + +impl<'ast, T> TryFrom> for ConcreteStructType { + type Error = (); + + fn try_from(t: StructType<'ast, T>) -> Result { + unimplemented!() + } +} + +impl<'ast, T> From for StructType<'ast, T> { + fn from(t: ConcreteStructType) -> Self { + unimplemented!() + } +} + +impl PartialEq for GStructType { fn eq(&self, other: &Self) -> bool { self.members.eq(&other.members) } } -impl Eq for StructType {} +impl Eq for GStructType {} -impl StructType { - pub fn new(module: PathBuf, name: String, members: Vec) -> Self { - StructType { +impl GStructType { + pub fn new(module: PathBuf, name: String, members: Vec>) -> Self { + GStructType { module, name, members, @@ -54,13 +124,13 @@ impl StructType { self.members.len() } - pub fn iter(&self) -> std::slice::Iter { + pub fn iter(&self) -> std::slice::Iter> { self.members.iter() } } -impl IntoIterator for StructType { - type Item = StructMember; +impl IntoIterator for GStructType { + type Item = GStructMember; type IntoIter = std::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { @@ -101,47 +171,120 @@ impl fmt::Display for UBitwidth { } } -#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] +#[derive(Clone, Hash, Serialize, Deserialize)] #[serde(tag = "type", content = "components")] -pub enum Type { +pub enum GType { #[serde(rename = "field")] FieldElement, #[serde(rename = "bool")] Boolean, #[serde(rename = "array")] - Array(ArrayType), + Array(GArrayType), #[serde(rename = "struct")] - Struct(StructType), + Struct(GStructType), #[serde(rename = "u")] Uint(UBitwidth), } +pub type DeclarationType<'ast> = GType>; +pub type ConcreteType = GType; +pub type Type<'ast, T> = GType>; + +// we have a looser equality relationship for generic types: an array of unknown size of a given type is equal to any arrays of that type +impl<'ast, T> PartialEq for Type<'ast, T> { + fn eq(&self, other: &Self) -> bool { + unimplemented!() + } +} + +impl PartialEq for ConcreteType { + fn eq(&self, other: &Self) -> bool { + unimplemented!() + } +} + +impl<'ast, T> Eq for Type<'ast, T> {} + +impl Eq for ConcreteType {} + +impl<'ast, T> PartialOrd for Type<'ast, T> { + fn partial_cmp(&self, other: &Self) -> Option { + unimplemented!() + } +} + +impl PartialOrd for ConcreteType { + fn partial_cmp(&self, other: &Self) -> Option { + unimplemented!() + } +} + +impl<'ast, T> Ord for Type<'ast, T> { + fn cmp(&self, other: &Self) -> Ordering { + unimplemented!() + } +} + +impl Ord for ConcreteType { + fn cmp(&self, other: &Self) -> Ordering { + unimplemented!() + } +} + +impl<'ast, T> TryFrom> for ConcreteType { + type Error = (); + + fn try_from(t: Type<'ast, T>) -> Result { + unimplemented!() + } +} + +impl<'ast> TryFrom> for ConcreteType { + type Error = (); + + fn try_from(t: DeclarationType<'ast>) -> Result { + unimplemented!() + } +} + +impl<'ast, T> From for Type<'ast, T> { + fn from(t: ConcreteType) -> Self { + unimplemented!() + } +} + +impl<'ast, T> From> for Type<'ast, T> { + fn from(t: DeclarationType<'ast>) -> Self { + unimplemented!() + } +} + impl GArrayType { - pub fn new(ty: Type, size: S) -> Self { - GArrayType { + pub fn new(ty: GType, size: S) -> Self { + ArrayType { ty: Box::new(ty), size, } } } -impl StructMember { - pub fn new(id: String, ty: Type) -> Self { - StructMember { +impl GStructMember { + pub fn new(id: String, ty: GType) -> Self { + GStructMember { id, ty: Box::new(ty), } } } -impl fmt::Display for Type { +impl fmt::Display for GType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Type::FieldElement => write!(f, "field"), - Type::Boolean => write!(f, "bool"), - Type::Uint(ref bitwidth) => write!(f, "u{}", bitwidth), - Type::Array(ref array_type) => write!(f, "{}[{}]", array_type.ty, array_type.size), - Type::Struct(ref struct_type) => write!( + GType::FieldElement => write!(f, "field"), + GType::Boolean => write!(f, "bool"), + GType::Uint(ref bitwidth) => write!(f, "u{}", bitwidth), + GType::Array(ref array_type) => write!(f, "{}[{}]", array_type.ty, array_type.size), + GType::Struct(ref struct_type) => write!( f, "{} {{{}}}", struct_type.name, @@ -156,14 +299,14 @@ impl fmt::Display for Type { } } -impl fmt::Debug for Type { +impl fmt::Debug for GType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Type::FieldElement => write!(f, "field"), - Type::Boolean => write!(f, "bool"), - Type::Uint(ref bitwidth) => write!(f, "u{}", bitwidth), - Type::Array(ref array_type) => write!(f, "{}[{}]", array_type.ty, array_type.size), - Type::Struct(ref struct_type) => write!( + GType::FieldElement => write!(f, "field"), + GType::Boolean => write!(f, "bool"), + GType::Uint(ref bitwidth) => write!(f, "u{}", bitwidth), + GType::Array(ref array_type) => write!(f, "{}[{}]", array_type.ty, array_type.size), + GType::Struct(ref struct_type) => write!( f, "{} {{{}}}", struct_type.name, @@ -178,26 +321,26 @@ impl fmt::Debug for Type { } } -impl Type { - pub fn array(ty: Type, size: usize) -> Self { - Type::Array(ArrayType::new(ty, size)) +impl GType { + pub fn array>(ty: GType, size: U) -> Self { + GType::Array(ArrayType::new(ty, size.into())) } - pub fn struc(struct_ty: StructType) -> Self { - Type::Struct(struct_ty) + pub fn struc(struct_ty: GStructType) -> Self { + GType::Struct(struct_ty) } pub fn uint>(b: W) -> Self { - Type::Uint(b.into()) + GType::Uint(b.into()) } fn to_slug(&self) -> String { match self { - Type::FieldElement => String::from("f"), - Type::Boolean => String::from("b"), - Type::Uint(bitwidth) => format!("u{}", bitwidth), - Type::Array(array_type) => format!("{}[{}]", array_type.ty.to_slug(), array_type.size), - Type::Struct(struct_type) => format!( + GType::FieldElement => String::from("f"), + GType::Boolean => String::from("b"), + GType::Uint(bitwidth) => format!("u{}", bitwidth), + GType::Array(array_type) => format!("{}[{}]", array_type.ty.to_slug(), array_type.size), + GType::Struct(struct_type) => format!( "{{{}}}", struct_type .iter() @@ -207,15 +350,17 @@ impl Type { ), } } +} +impl ConcreteType { // the number of field elements the type maps to pub fn get_primitive_count(&self) -> usize { match self { - Type::FieldElement => 1, - Type::Boolean => 1, - Type::Uint(_) => 1, - Type::Array(array_type) => array_type.size * array_type.ty.get_primitive_count(), - Type::Struct(struct_type) => struct_type + GType::FieldElement => 1, + GType::Boolean => 1, + GType::Uint(_) => 1, + GType::Array(array_type) => array_type.size * array_type.ty.get_primitive_count(), + GType::Struct(struct_type) => struct_type .iter() .map(|member| member.ty.get_primitive_count()) .sum(), @@ -226,25 +371,51 @@ impl Type { pub type FunctionIdentifier<'ast> = &'ast str; #[derive(PartialEq, Eq, Hash, Debug, Clone)] -pub struct FunctionKey<'ast> { +pub struct GFunctionKey<'ast, S> { pub id: FunctionIdentifier<'ast>, - pub signature: Signature, + pub signature: GSignature, } -impl<'ast> FunctionKey<'ast> { - pub fn with_id>>(id: S) -> Self { - FunctionKey { +pub type DeclarationFunctionKey<'ast> = GFunctionKey<'ast, Constant<'ast>>; +pub type ConcreteFunctionKey<'ast> = GFunctionKey<'ast, usize>; +pub type FunctionKey<'ast, T> = GFunctionKey<'ast, UExpression<'ast, T>>; + +impl<'ast, T> TryFrom> for ConcreteFunctionKey<'ast> { + type Error = (); + + fn try_from(t: FunctionKey<'ast, T>) -> Result { + unimplemented!() + } +} + +impl<'ast> TryFrom> for ConcreteFunctionKey<'ast> { + type Error = (); + + fn try_from(t: DeclarationFunctionKey<'ast>) -> Result { + unimplemented!() + } +} + +impl<'ast, T> From> for FunctionKey<'ast, T> { + fn from(t: ConcreteFunctionKey<'ast>) -> Self { + unimplemented!() + } +} + +impl<'ast, S> GFunctionKey<'ast, S> { + pub fn with_id>>(id: U) -> Self { + GFunctionKey { id: id.into(), - signature: Signature::new(), + signature: GSignature::new(), } } - pub fn signature(mut self, signature: Signature) -> Self { + pub fn signature(mut self, signature: GSignature) -> Self { self.signature = signature; self } - pub fn id>>(mut self, id: S) -> Self { + pub fn id>>(mut self, id: U) -> Self { self.id = id.into(); self } @@ -254,19 +425,81 @@ impl<'ast> FunctionKey<'ast> { } } -pub use self::signature::Signature; +pub use self::signature::{ConcreteSignature, DeclarationSignature, GSignature, Signature}; pub mod signature { use super::*; + use std::cmp::Ordering; use std::fmt; + use std::hash::Hasher; - #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd)] - pub struct Signature { - pub inputs: Vec, - pub outputs: Vec, + #[derive(Clone, Serialize, Deserialize)] + pub struct GSignature { + pub inputs: Vec>, + pub outputs: Vec>, } - impl fmt::Debug for Signature { + impl PartialOrd for GSignature { + fn partial_cmp(&self, other: &Self) -> Option { + unimplemented!() + } + } + + impl Ord for GSignature { + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap() + } + } + + impl PartialEq for GSignature { + fn eq(&self, other: &Self) -> bool { + unimplemented!() + } + } + + impl Eq for GSignature {} + + impl std::hash::Hash for GSignature { + fn hash(&self, state: &mut H) { + unimplemented!() + } + } + + pub type DeclarationSignature<'ast> = GSignature>; + pub type ConcreteSignature = GSignature; + pub type Signature<'ast, T> = GSignature>; + + impl<'ast> TryFrom for DeclarationSignature<'ast> { + type Error = (); + + fn try_from(t: ConcreteSignature) -> Result { + unimplemented!() + } + } + + impl<'ast, T> TryFrom> for ConcreteSignature { + type Error = (); + + fn try_from(t: Signature<'ast, T>) -> Result { + unimplemented!() + } + } + + impl<'ast> TryFrom> for ConcreteSignature { + type Error = (); + + fn try_from(t: DeclarationSignature<'ast>) -> Result { + unimplemented!() + } + } + + impl<'ast, T> From for Signature<'ast, T> { + fn from(t: ConcreteSignature) -> Self { + unimplemented!() + } + } + + impl fmt::Debug for GSignature { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, @@ -276,7 +509,7 @@ pub mod signature { } } - impl fmt::Display for Signature { + impl fmt::Display for GSignature { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "(")?; for (i, t) in self.inputs.iter().enumerate() { @@ -303,7 +536,7 @@ pub mod signature { } } - impl Signature { + impl GSignature { /// Returns a slug for a signature, with the following encoding: /// i{inputs}o{outputs} where {inputs} and {outputs} each encode a list of types. /// A list of types is encoded by compressing sequences of the same type like so: @@ -347,19 +580,19 @@ pub mod signature { format!("i{}o{}", to_slug(&self.inputs), to_slug(&self.outputs)) } - pub fn new() -> Signature { - Signature { + pub fn new() -> GSignature { + Self { inputs: vec![], outputs: vec![], } } - pub fn inputs(mut self, inputs: Vec) -> Self { + pub fn inputs(mut self, inputs: Vec>) -> Self { self.inputs = inputs; self } - pub fn outputs(mut self, outputs: Vec) -> Self { + pub fn outputs(mut self, outputs: Vec>) -> Self { self.outputs = outputs; self } diff --git a/zokrates_core/src/typed_absy/uint.rs b/zokrates_core/src/typed_absy/uint.rs index d5245777..0221ba06 100644 --- a/zokrates_core/src/typed_absy/uint.rs +++ b/zokrates_core/src/typed_absy/uint.rs @@ -82,6 +82,12 @@ pub struct UExpression<'ast, T> { pub inner: UExpressionInner<'ast, T>, } +impl<'ast, T> From for UExpression<'ast, T> { + fn from(u: u32) -> Self { + UExpressionInner::Value(u).annotate(UBitwidth::B32) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum UExpressionInner<'ast, T> { Identifier(Identifier<'ast>), @@ -101,17 +107,14 @@ pub enum UExpressionInner<'ast, T> { Box>, Box>, ), - FunctionCall(FunctionKey<'ast>, Vec>), + FunctionCall(FunctionKey<'ast, T>, Vec>), IfElse( Box>, Box>, Box>, ), Member(Box>, MemberId), - Select( - Box>, - Box>, - ), + Select(Box>, Box>), } impl<'ast, T> UExpressionInner<'ast, T> { diff --git a/zokrates_core/src/typed_absy/variable.rs b/zokrates_core/src/typed_absy/variable.rs index 3db55cf9..124f154b 100644 --- a/zokrates_core/src/typed_absy/variable.rs +++ b/zokrates_core/src/typed_absy/variable.rs @@ -1,76 +1,94 @@ -use crate::typed_absy::types::Type; +use crate::typed_absy::types::GType; use crate::typed_absy::Identifier; use std::fmt; -use typed_absy::types::{StructType, UBitwidth}; +use typed_absy::types::{Constant, GStructType, UBitwidth}; +use typed_absy::TryFrom; +use typed_absy::UExpression; #[derive(Clone, PartialEq, Hash, Eq)] -pub struct Variable<'ast> { +pub struct GVariable<'ast, S> { pub id: Identifier<'ast>, - pub _type: Type, + pub _type: GType, } -impl<'ast> Variable<'ast> { - pub fn field_element>>(id: I) -> Variable<'ast> { - Self::with_id_and_type(id, Type::FieldElement) +pub type DeclarationVariable<'ast> = GVariable<'ast, Constant<'ast>>; +pub type ConcreteVariable<'ast> = GVariable<'ast, usize>; +pub type Variable<'ast, T> = GVariable<'ast, UExpression<'ast, T>>; + +impl<'ast, T> TryFrom> for ConcreteVariable<'ast> { + type Error = (); + + fn try_from(t: Variable<'ast, T>) -> Result { + unimplemented!() + } +} + +impl<'ast> TryFrom> for ConcreteVariable<'ast> { + type Error = (); + + fn try_from(t: DeclarationVariable<'ast>) -> Result { + unimplemented!() + } +} + +impl<'ast, T> From> for Variable<'ast, T> { + fn from(t: ConcreteVariable<'ast>) -> Self { + unimplemented!() + } +} + +impl<'ast, T> From> for Variable<'ast, T> { + fn from(t: DeclarationVariable<'ast>) -> Self { + unimplemented!() + } +} + +impl<'ast, S> GVariable<'ast, S> { + pub fn field_element>>(id: I) -> Self { + Self::with_id_and_type(id, GType::FieldElement) } - pub fn boolean>>(id: I) -> Variable<'ast> { - Self::with_id_and_type(id, Type::Boolean) + pub fn boolean>>(id: I) -> Self { + Self::with_id_and_type(id, GType::Boolean) } - pub fn uint>, W: Into>( - id: I, - bitwidth: W, - ) -> Variable<'ast> { - Self::with_id_and_type(id, Type::uint(bitwidth)) + pub fn uint>, W: Into>(id: I, bitwidth: W) -> Self { + Self::with_id_and_type(id, GType::uint(bitwidth)) } #[cfg(test)] - pub fn field_array>>(id: I, size: usize) -> Variable<'ast> { - Self::array(id, Type::FieldElement, size) + pub fn field_array>>(id: I, size: S) -> Self { + Self::array(id, GType::FieldElement, size) } - pub fn array>>(id: I, ty: Type, size: usize) -> Variable<'ast> { - Self::with_id_and_type(id, Type::array(ty, size)) + pub fn array>>(id: I, ty: GType, size: S) -> Self { + Self::with_id_and_type(id, GType::array(ty, size)) } - pub fn struc>>(id: I, ty: StructType) -> Variable<'ast> { - Self::with_id_and_type(id, Type::Struct(ty)) + pub fn struc>>(id: I, ty: GStructType) -> Self { + Self::with_id_and_type(id, GType::Struct(ty)) } - pub fn with_id_and_type>>(id: I, _type: Type) -> Variable<'ast> { - Variable { + pub fn with_id_and_type>>(id: I, _type: GType) -> Self { + GVariable { id: id.into(), _type, } } - pub fn get_type(&self) -> Type { + pub fn get_type(&self) -> GType { self._type.clone() } } -impl<'ast> fmt::Display for Variable<'ast> { +impl<'ast, S> fmt::Display for GVariable<'ast, S> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{} {}", self._type, self.id,) } } -impl<'ast> fmt::Debug for Variable<'ast> { +impl<'ast, S> fmt::Debug for GVariable<'ast, S> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Variable(type: {:?}, id: {:?})", self._type, self.id,) } } - -// impl<'ast> From> for Variable<'ast> { -// fn from(v: absy::Variable) -> Variable { -// Variable::with_id_and_type( -// Identifier { -// id: v.id, -// version: 0, -// stack: vec![], -// }, -// v._type, -// ) -// } -// } diff --git a/zokrates_core/src/zir/from_typed.rs b/zokrates_core/src/zir/from_typed.rs index a2c1a25f..b7fa30d1 100644 --- a/zokrates_core/src/zir/from_typed.rs +++ b/zokrates_core/src/zir/from_typed.rs @@ -1,16 +1,16 @@ use typed_absy; use zir; -impl<'ast> From> for zir::types::FunctionKey<'ast> { - fn from(k: typed_absy::types::FunctionKey<'ast>) -> zir::types::FunctionKey<'ast> { +impl<'ast> From> for zir::types::FunctionKey<'ast> { + fn from(k: typed_absy::types::ConcreteFunctionKey<'ast>) -> zir::types::FunctionKey<'ast> { zir::types::FunctionKey { id: k.id, signature: k.signature.into(), } } } -impl From for zir::types::Signature { - fn from(s: typed_absy::types::Signature) -> zir::types::Signature { +impl From for zir::types::Signature { + fn from(s: typed_absy::types::ConcreteSignature) -> zir::types::Signature { zir::types::Signature { inputs: s.inputs.into_iter().flat_map(|t| from_type(t)).collect(), outputs: s.outputs.into_iter().flat_map(|t| from_type(t)).collect(), @@ -18,16 +18,18 @@ impl From for zir::types::Signature { } } -fn from_type(t: typed_absy::types::Type) -> Vec { +fn from_type(t: typed_absy::types::ConcreteType) -> Vec { match t { - typed_absy::Type::FieldElement => vec![zir::Type::FieldElement], - typed_absy::Type::Boolean => vec![zir::Type::Boolean], - typed_absy::Type::Uint(bitwidth) => vec![zir::Type::uint(bitwidth.to_usize())], - typed_absy::Type::Array(array_type) => { + typed_absy::types::ConcreteType::FieldElement => vec![zir::Type::FieldElement], + typed_absy::types::ConcreteType::Boolean => vec![zir::Type::Boolean], + typed_absy::types::ConcreteType::Uint(bitwidth) => { + vec![zir::Type::uint(bitwidth.to_usize())] + } + typed_absy::types::ConcreteType::Array(array_type) => { let inner = from_type(*array_type.ty); (0..array_type.size).flat_map(|_| inner.clone()).collect() } - typed_absy::Type::Struct(members) => members + typed_absy::types::ConcreteType::Struct(members) => members .into_iter() .flat_map(|struct_member| from_type(*struct_member.ty)) .collect(), diff --git a/zokrates_core/src/zir/mod.rs b/zokrates_core/src/zir/mod.rs index e95f7687..01c79b61 100644 --- a/zokrates_core/src/zir/mod.rs +++ b/zokrates_core/src/zir/mod.rs @@ -14,7 +14,7 @@ pub use zir::uint::{ShouldReduce, UExpression, UExpressionInner, UMetadata}; use embed::FlatEmbed; use std::collections::HashMap; -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use std::fmt; use zir::types::{FunctionKey, Signature}; use zokrates_field::Field; @@ -90,7 +90,7 @@ impl<'ast, T> ZirFunctionSymbol<'ast, T> { .unwrap() .signature(&modules) .clone(), - ZirFunctionSymbol::Flat(flat_fun) => flat_fun.signature().into(), + ZirFunctionSymbol::Flat(flat_fun) => flat_fun.signature().try_into().unwrap(), } } } diff --git a/zokrates_field/src/lib.rs b/zokrates_field/src/lib.rs index ab5a2a7a..3787b65b 100644 --- a/zokrates_field/src/lib.rs +++ b/zokrates_field/src/lib.rs @@ -35,6 +35,8 @@ pub trait Field: + Ord + Display + Debug + + Default + + Hash + Add + for<'a> Add<&'a Self, Output = Self> + Sub