1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

Merge pull request #913 from Zokrates/use-global-constant-map

Use global constant map for constant inlining
This commit is contained in:
Thibaut Schaeffer 2021-06-17 12:13:29 +02:00 committed by GitHub
commit 17095e966c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 430 additions and 240 deletions

View file

@ -0,0 +1 @@
Fix crash on import of functions containing constants

View file

@ -4,7 +4,8 @@ use crate::flat_absy::{
};
use crate::solvers::Solver;
use crate::typed_absy::types::{
ConcreteGenericsAssignment, Constant, DeclarationSignature, DeclarationType, GenericIdentifier,
ConcreteGenericsAssignment, DeclarationConstant, DeclarationSignature, DeclarationType,
GenericIdentifier,
};
use std::collections::HashMap;
use zokrates_field::{Bn128Field, Field};
@ -43,10 +44,12 @@ impl FlatEmbed {
.inputs(vec![DeclarationType::uint(32)])
.outputs(vec![DeclarationType::FieldElement]),
FlatEmbed::Unpack => DeclarationSignature::new()
.generics(vec![Some(Constant::Generic(GenericIdentifier {
name: "N",
index: 0,
}))])
.generics(vec![Some(DeclarationConstant::Generic(
GenericIdentifier {
name: "N",
index: 0,
},
))])
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::array((
DeclarationType::Boolean,
@ -122,7 +125,7 @@ impl FlatEmbed {
.generics
.into_iter()
.map(|c| match c.unwrap() {
Constant::Generic(g) => g,
DeclarationConstant::Generic(g) => g,
_ => unreachable!(),
});

View file

@ -19,9 +19,9 @@ use crate::parser::Position;
use crate::absy::types::{UnresolvedSignature, UnresolvedType, UserTypeId};
use crate::typed_absy::types::{
ArrayType, Constant, DeclarationArrayType, DeclarationFunctionKey, DeclarationSignature,
DeclarationStructMember, DeclarationStructType, DeclarationType, GenericIdentifier,
StructLocation,
ArrayType, DeclarationArrayType, DeclarationConstant, DeclarationFunctionKey,
DeclarationSignature, DeclarationStructMember, DeclarationStructType, DeclarationType,
GenericIdentifier, StructLocation,
};
use std::hash::{Hash, Hasher};
@ -350,7 +350,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
fn check_constant_definition(
&mut self,
id: &'ast str,
id: ConstantIdentifier<'ast>,
c: ConstantDefinitionNode<'ast>,
module_id: &ModuleId,
state: &State<'ast, T>,
@ -445,8 +445,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
declaration: SymbolDeclarationNode<'ast>,
module_id: &ModuleId,
state: &mut State<'ast, T>,
functions: &mut HashMap<DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>>,
constants: &mut HashMap<ConstantIdentifier<'ast>, TypedConstantSymbol<'ast, T>>,
functions: &mut TypedFunctionSymbols<'ast, T>,
constants: &mut TypedConstantSymbols<'ast, T>,
symbol_unifier: &mut SymbolUnifier<'ast>,
) -> Result<(), Vec<Error>> {
let mut errors: Vec<Error> = vec![];
@ -506,8 +506,13 @@ impl<'ast, T: Field> Checker<'ast, T> {
.in_file(module_id),
),
true => {
constants
.insert(declaration.id, TypedConstantSymbol::Here(c.clone()));
constants.push((
CanonicalConstantIdentifier::new(
declaration.id,
module_id.into(),
),
TypedConstantSymbol::Here(c.clone()),
));
self.insert_into_scope(Variable::with_id_and_type(
declaration.id,
c.get_type(),
@ -600,7 +605,9 @@ impl<'ast, T: Field> Checker<'ast, T> {
.constants
.entry(import.module_id.to_path_buf())
.or_default()
.get(import.symbol_id)
.iter()
.find(|(i, _)| *i == &import.symbol_id)
.map(|(_, c)| c)
.cloned();
match (function_candidates.len(), type_candidate, const_candidate) {
@ -653,7 +660,10 @@ impl<'ast, T: Field> Checker<'ast, T> {
}});
}
true => {
constants.insert(declaration.id, TypedConstantSymbol::There(import.module_id.to_path_buf(), import.symbol_id));
let imported_id = CanonicalConstantIdentifier::new(import.symbol_id, import.module_id);
let id = CanonicalConstantIdentifier::new(declaration.id, module_id.into());
constants.push((id.clone(), TypedConstantSymbol::There(imported_id)));
self.insert_into_scope(Variable::with_id_and_type(declaration.id, ty.clone()));
state
@ -750,8 +760,8 @@ impl<'ast, T: Field> Checker<'ast, T> {
module_id: &ModuleId,
state: &mut State<'ast, T>,
) -> Result<(), Vec<Error>> {
let mut checked_functions = HashMap::new();
let mut checked_constants = HashMap::new();
let mut checked_functions = TypedFunctionSymbols::new();
let mut checked_constants = TypedConstantSymbols::new();
// check if the module was already removed from the untyped ones
let to_insert = match state.modules.remove(module_id) {
@ -856,7 +866,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let v = Variable::with_id_and_type(
match generic {
Constant::Generic(g) => g.name,
DeclarationConstant::Generic(g) => g.name,
_ => unreachable!(),
},
Type::Uint(UBitwidth::B32),
@ -996,7 +1006,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
} else {
match generics_map.insert(g.value, index).is_none() {
true => {
generics.push(Some(Constant::Generic(GenericIdentifier {
generics.push(Some(DeclarationConstant::Generic(GenericIdentifier {
name: g.value,
index,
})));
@ -1112,16 +1122,17 @@ impl<'ast, T: Field> Checker<'ast, T> {
fn check_generic_expression(
&mut self,
expr: ExpressionNode<'ast>,
module_id: &ModuleId,
constants_map: &HashMap<ConstantIdentifier<'ast>, Type<'ast, T>>,
generics_map: &HashMap<Identifier<'ast>, usize>,
) -> Result<Constant<'ast>, ErrorInner> {
) -> Result<DeclarationConstant<'ast>, ErrorInner> {
let pos = expr.pos();
match expr.value {
Expression::U32Constant(c) => Ok(Constant::Concrete(c)),
Expression::U32Constant(c) => Ok(DeclarationConstant::Concrete(c)),
Expression::IntConstant(c) => {
if c <= BigUint::from(2u128.pow(32) - 1) {
Ok(Constant::Concrete(
Ok(DeclarationConstant::Concrete(
u32::from_str_radix(&c.to_str_radix(16), 16).unwrap(),
))
} else {
@ -1138,7 +1149,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
match (constants_map.get(name), generics_map.get(&name)) {
(Some(ty), None) => {
match ty {
Type::Uint(UBitwidth::B32) => Ok(Constant::Identifier(name, 32usize)),
Type::Uint(UBitwidth::B32) => Ok(DeclarationConstant::Constant(CanonicalConstantIdentifier::new(name, module_id.into()))),
_ => Err(ErrorInner {
pos: Some(pos),
message: format!(
@ -1148,7 +1159,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
})
}
}
(None, Some(index)) => Ok(Constant::Generic(GenericIdentifier { name, index: *index })),
(None, Some(index)) => Ok(DeclarationConstant::Generic(GenericIdentifier { name, index: *index })),
_ => Err(ErrorInner {
pos: Some(pos),
message: format!("Undeclared symbol `{}` in function definition", name)
@ -1182,6 +1193,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
UnresolvedType::Array(t, size) => {
let checked_size = self.check_generic_expression(
size.clone(),
module_id,
state.constants.get(module_id).unwrap_or(&HashMap::new()),
generics_map,
)?;

View file

@ -1,161 +1,153 @@
use crate::static_analysis::propagation::Propagator;
use crate::static_analysis::Propagator;
use crate::typed_absy::folder::*;
use crate::typed_absy::result_folder::ResultFolder;
use crate::typed_absy::types::{Constant, DeclarationStructType, GStructMember};
use crate::typed_absy::types::DeclarationConstant;
use crate::typed_absy::*;
use std::collections::HashMap;
use std::convert::TryInto;
use zokrates_field::Field;
pub struct ConstantInliner<'ast, 'a, T: Field> {
type ProgramConstants<'ast, T> =
HashMap<OwnedTypedModuleId, HashMap<Identifier<'ast>, TypedExpression<'ast, T>>>;
pub struct ConstantInliner<'ast, T> {
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
propagator: Propagator<'ast, 'a, T>,
constants: ProgramConstants<'ast, T>,
}
impl<'ast, 'a, T: Field> ConstantInliner<'ast, 'a, T> {
impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
pub fn new(
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
propagator: Propagator<'ast, 'a, T>,
constants: ProgramConstants<'ast, T>,
) -> Self {
ConstantInliner {
modules,
location,
propagator,
constants,
}
}
pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
let mut constants = HashMap::new();
let mut inliner = ConstantInliner::new(
p.modules.clone(),
p.main.clone(),
Propagator::with_constants(&mut constants),
);
let constants = ProgramConstants::new();
let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone(), constants);
inliner.fold_program(p)
}
fn module(&self) -> &TypedModule<'ast, T> {
self.modules.get(&self.location).unwrap()
}
fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
let prev = self.location.clone();
self.location = location;
self.constants.entry(self.location.clone()).or_default();
prev
}
fn get_constant(&mut self, id: &Identifier) -> Option<TypedConstant<'ast, T>> {
self.modules
.get(&self.location)
.unwrap()
.constants
.get(id.clone().try_into().unwrap())
.cloned()
.map(|symbol| self.get_canonical_constant(symbol))
fn treated(&self, id: &TypedModuleId) -> bool {
self.constants.contains_key(id)
}
fn get_canonical_constant(
&mut self,
symbol: TypedConstantSymbol<'ast, T>,
) -> TypedConstant<'ast, T> {
match symbol {
TypedConstantSymbol::There(module_id, id) => {
let location = self.change_location(module_id);
let symbol = self.module().constants.get(id).cloned().unwrap();
fn get_constant(
&self,
id: &CanonicalConstantIdentifier<'ast>,
) -> Option<TypedExpression<'ast, T>> {
self.constants
.get(&id.module)
.and_then(|constants| constants.get(&id.id.into()))
.cloned()
}
let symbol = self.get_canonical_constant(symbol);
let _ = self.change_location(location);
symbol
}
TypedConstantSymbol::Here(tc) => {
let tc: TypedConstant<T> = self.fold_constant(tc);
TypedConstant {
expression: self.propagator.fold_expression(tc.expression).unwrap(),
..tc
}
}
}
fn get_constant_for_identifier(
&self,
id: &Identifier<'ast>,
) -> Option<TypedExpression<'ast, T>> {
self.constants
.get(&self.location)
.and_then(|constants| constants.get(&id))
.cloned()
}
}
impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> {
fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
TypedProgram {
modules: p
.modules
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId {
// anytime we encounter a module id, visit the corresponding module if it hasn't been done yet
if !self.treated(&id) {
let current_m_id = self.change_location(id.clone());
let m = self.modules.remove(&id).unwrap();
let m = self.fold_module(m);
self.modules.insert(id.clone(), m);
self.change_location(current_m_id);
}
id
}
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
TypedModule {
constants: m
.constants
.into_iter()
.map(|(module_id, module)| {
self.change_location(module_id.clone());
(module_id, self.fold_module(module))
.map(|(id, tc)| {
let constant = match tc {
TypedConstantSymbol::There(imported_id) => {
// visit the imported symbol. This triggers visiting the corresponding module if needed
let imported_id = self.fold_canonical_constant_identifier(imported_id);
// after that, the constant must have been defined defined in the global map. It is already reduced
// to a literal, so running propagation isn't required
self.get_constant(&imported_id).unwrap()
}
TypedConstantSymbol::Here(c) => {
let non_propagated_constant = fold_constant(self, c).expression;
// folding the constant above only reduces it to an expression containing only literals, not to a single literal.
// propagating with an empty map of constants reduces it to a single literal
Propagator::with_constants(&mut HashMap::default())
.fold_expression(non_propagated_constant)
.unwrap()
}
};
// add to the constant map. The value added is always a single litteral
self.constants
.get_mut(&self.location)
.unwrap()
.insert(id.id.into(), constant.clone());
(
id,
TypedConstantSymbol::Here(TypedConstant {
ty: constant.get_type().clone(),
expression: constant,
}),
)
})
.collect(),
functions: m
.functions
.into_iter()
.map(|(key, fun)| {
(
self.fold_declaration_function_key(key),
self.fold_function_symbol(fun),
)
})
.collect(),
main: p.main,
}
}
fn fold_declaration_type(&mut self, t: DeclarationType<'ast>) -> DeclarationType<'ast> {
match t {
DeclarationType::Array(ref array_ty) => match array_ty.size {
Constant::Identifier(name, _) => {
let tc = self.get_constant(&name.into()).unwrap();
let expression: UExpression<'ast, T> = tc.expression.try_into().unwrap();
match expression.inner {
UExpressionInner::Value(v) => DeclarationType::array((
self.fold_declaration_type(*array_ty.ty.clone()),
Constant::Concrete(v as u32),
)),
_ => unreachable!("expected u32 value"),
}
}
_ => t,
},
DeclarationType::Struct(struct_ty) => DeclarationType::struc(DeclarationStructType {
members: struct_ty
.members
.into_iter()
.map(|m| GStructMember::new(m.id, self.fold_declaration_type(*m.ty)))
.collect(),
..struct_ty
}),
_ => t,
}
}
fn fold_type(&mut self, t: Type<'ast, T>) -> Type<'ast, T> {
use self::GType::*;
match t {
Array(ref array_type) => match &array_type.size.inner {
UExpressionInner::Identifier(v) => match self.get_constant(v) {
Some(tc) => {
let expression: UExpression<'ast, T> = tc.expression.try_into().unwrap();
Type::array(GArrayType::new(
self.fold_type(*array_type.ty.clone()),
expression,
))
}
None => t,
},
_ => t,
},
Struct(struct_type) => Type::struc(GStructType {
members: struct_type
.members
.into_iter()
.map(|m| GStructMember::new(m.id, self.fold_type(*m.ty)))
.collect(),
..struct_type
}),
_ => t,
}
}
fn fold_constant_symbol(
fn fold_declaration_constant(
&mut self,
s: TypedConstantSymbol<'ast, T>,
) -> TypedConstantSymbol<'ast, T> {
let tc = self.get_canonical_constant(s);
TypedConstantSymbol::Here(tc)
c: DeclarationConstant<'ast>,
) -> DeclarationConstant<'ast> {
match c {
// replace constants by their concrete value in declaration types
DeclarationConstant::Constant(id) => {
DeclarationConstant::Concrete(match self.get_constant(&id).unwrap() {
TypedExpression::Uint(UExpression {
inner: UExpressionInner::Value(v),
..
}) => v as u32,
_ => unreachable!("all constants found in declaration types should be reduceable to u32 literals"),
})
}
c => c,
}
}
fn fold_field_expression(
@ -163,10 +155,12 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> {
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Identifier(ref id) => match self.get_constant(id) {
Some(c) => self.fold_constant(c).try_into().unwrap(),
None => fold_field_expression(self, e),
},
FieldElementExpression::Identifier(ref id) => {
match self.get_constant_for_identifier(id) {
Some(c) => c.try_into().unwrap(),
None => fold_field_expression(self, e),
}
}
e => fold_field_expression(self, e),
}
}
@ -176,8 +170,8 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> {
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::Identifier(ref id) => match self.get_constant(id) {
Some(c) => self.fold_constant(c).try_into().unwrap(),
BooleanExpression::Identifier(ref id) => match self.get_constant_for_identifier(id) {
Some(c) => c.try_into().unwrap(),
None => fold_boolean_expression(self, e),
},
e => fold_boolean_expression(self, e),
@ -190,9 +184,9 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> {
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Identifier(ref id) => match self.get_constant(id) {
UExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) {
Some(c) => {
let e: UExpression<'ast, T> = self.fold_constant(c).try_into().unwrap();
let e: UExpression<'ast, T> = c.try_into().unwrap();
e.into_inner()
}
None => fold_uint_expression_inner(self, size, e),
@ -207,13 +201,15 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> {
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
ArrayExpressionInner::Identifier(ref id) => match self.get_constant(id) {
Some(c) => {
let e: ArrayExpression<'ast, T> = self.fold_constant(c).try_into().unwrap();
e.into_inner()
ArrayExpressionInner::Identifier(ref id) => {
match self.get_constant_for_identifier(id) {
Some(c) => {
let e: ArrayExpression<'ast, T> = c.try_into().unwrap();
e.into_inner()
}
None => fold_array_expression_inner(self, ty, e),
}
None => fold_array_expression_inner(self, ty, e),
},
}
e => fold_array_expression_inner(self, ty, e),
}
}
@ -224,9 +220,10 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ConstantInliner<'ast, 'a, T> {
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Identifier(ref id) => match self.get_constant(id) {
StructExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id)
{
Some(c) => {
let e: StructExpression<'ast, T> = self.fold_constant(c).try_into().unwrap();
let e: StructExpression<'ast, T> = c.try_into().unwrap();
e.into_inner()
}
None => fold_struct_expression_inner(self, ty, e),
@ -265,7 +262,7 @@ mod tests {
};
let constants: TypedConstantSymbols<_> = vec![(
const_id,
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(Bn128Field::from(1))),
@ -353,7 +350,7 @@ mod tests {
};
let constants: TypedConstantSymbols<_> = vec![(
const_id,
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::Boolean,
TypedExpression::Boolean(BooleanExpression::Value(true)),
@ -442,7 +439,7 @@ mod tests {
};
let constants: TypedConstantSymbols<_> = vec![(
const_id,
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::Uint(UBitwidth::B32),
UExpressionInner::Value(1u128)
@ -543,7 +540,7 @@ mod tests {
};
let constants: TypedConstantSymbols<_> = vec![(
const_id,
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::Array(
@ -682,7 +679,7 @@ mod tests {
.collect(),
constants: vec![
(
const_a_id,
CanonicalConstantIdentifier::new(const_a_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
@ -691,7 +688,7 @@ mod tests {
)),
),
(
const_b_id,
CanonicalConstantIdentifier::new(const_b_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Add(
@ -740,7 +737,7 @@ mod tests {
.collect(),
constants: vec![
(
const_a_id,
CanonicalConstantIdentifier::new(const_a_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
@ -749,7 +746,7 @@ mod tests {
)),
),
(
const_b_id,
CanonicalConstantIdentifier::new(const_b_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
@ -801,7 +798,7 @@ mod tests {
.into_iter()
.collect(),
constants: vec![(
foo_const_id,
CanonicalConstantIdentifier::new(foo_const_id, "foo".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
@ -833,8 +830,11 @@ mod tests {
.into_iter()
.collect(),
constants: vec![(
foo_const_id,
TypedConstantSymbol::There(OwnedTypedModuleId::from("foo"), foo_const_id),
CanonicalConstantIdentifier::new(foo_const_id, "main".into()),
TypedConstantSymbol::There(CanonicalConstantIdentifier::new(
foo_const_id,
"foo".into(),
)),
)]
.into_iter()
.collect(),
@ -871,7 +871,7 @@ mod tests {
.into_iter()
.collect(),
constants: vec![(
foo_const_id,
CanonicalConstantIdentifier::new(foo_const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(

View file

@ -84,6 +84,7 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
} else {
r
};
// reduce the program to a single function
let r = reduce_program(r).map_err(Error::from)?;
// generate abi

View file

@ -614,7 +614,7 @@ fn compute_hash<T: Field>(f: &TypedFunction<T>) -> u64 {
#[cfg(test)]
mod tests {
use super::*;
use crate::typed_absy::types::Constant;
use crate::typed_absy::types::DeclarationConstant;
use crate::typed_absy::types::DeclarationSignature;
use crate::typed_absy::{
ArrayExpression, ArrayExpressionInner, DeclarationFunctionKey, DeclarationType,
@ -834,11 +834,11 @@ mod tests {
)])
.inputs(vec![DeclarationType::array((
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
))])
.outputs(vec![DeclarationType::array((
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
))]);
let foo: TypedFunction<Bn128Field> = TypedFunction {
@ -1053,11 +1053,11 @@ mod tests {
)])
.inputs(vec![DeclarationType::array((
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
))])
.outputs(vec![DeclarationType::array((
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
))]);
let foo: TypedFunction<Bn128Field> = TypedFunction {
@ -1285,11 +1285,11 @@ mod tests {
let foo_signature = DeclarationSignature::new()
.inputs(vec![DeclarationType::array((
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
))])
.outputs(vec![DeclarationType::array((
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
))])
.generics(vec![Some(
GenericIdentifier::with_name("K").index(0).into(),
@ -1299,7 +1299,7 @@ mod tests {
arguments: vec![DeclarationVariable::array(
"a",
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
)
.into()],
statements: vec![
@ -1363,7 +1363,7 @@ mod tests {
arguments: vec![DeclarationVariable::array(
"a",
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
)
.into()],
statements: vec![TypedStatement::Return(vec![

View file

@ -1,6 +1,6 @@
// Generic walk through a typed AST. Not mutating in place
use crate::typed_absy::types::{ArrayType, StructMember, StructType};
use crate::typed_absy::types::*;
use crate::typed_absy::*;
use zokrates_field::Field;
@ -80,6 +80,13 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_signature(self, s)
}
fn fold_declaration_constant(
&mut self,
c: DeclarationConstant<'ast>,
) -> DeclarationConstant<'ast> {
fold_declaration_constant(self, c)
}
fn fold_parameter(&mut self, p: DeclarationParameter<'ast>) -> DeclarationParameter<'ast> {
DeclarationParameter {
id: self.fold_declaration_variable(p.id),
@ -144,7 +151,40 @@ pub trait Folder<'ast, T: Field>: Sized {
}
fn fold_declaration_type(&mut self, t: DeclarationType<'ast>) -> DeclarationType<'ast> {
t
use self::GType::*;
match t {
Array(array_type) => Array(self.fold_declaration_array_type(array_type)),
Struct(struct_type) => Struct(self.fold_declaration_struct_type(struct_type)),
t => t,
}
}
fn fold_declaration_array_type(
&mut self,
t: DeclarationArrayType<'ast>,
) -> DeclarationArrayType<'ast> {
DeclarationArrayType {
ty: box self.fold_declaration_type(*t.ty),
size: self.fold_declaration_constant(t.size),
}
}
fn fold_declaration_struct_type(
&mut self,
t: DeclarationStructType<'ast>,
) -> DeclarationStructType<'ast> {
DeclarationStructType {
members: t
.members
.into_iter()
.map(|m| DeclarationStructMember {
ty: box self.fold_declaration_type(*m.ty),
..m
})
.collect(),
..t
}
}
fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> {
@ -175,6 +215,20 @@ pub trait Folder<'ast, T: Field>: Sized {
}
}
fn fold_canonical_constant_identifier(
&mut self,
i: CanonicalConstantIdentifier<'ast>,
) -> CanonicalConstantIdentifier<'ast> {
CanonicalConstantIdentifier {
module: self.fold_module_id(i.module),
id: i.id,
}
}
fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> OwnedTypedModuleId {
i
}
fn fold_expression(&mut self, e: TypedExpression<'ast, T>) -> TypedExpression<'ast, T> {
match e {
TypedExpression::FieldElement(e) => self.fold_field_expression(e).into(),
@ -316,7 +370,12 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>(
constants: m
.constants
.into_iter()
.map(|(key, tc)| (key, f.fold_constant_symbol(tc)))
.map(|(id, tc)| {
(
f.fold_canonical_constant_identifier(id),
f.fold_constant_symbol(tc),
)
})
.collect(),
functions: m
.functions
@ -901,6 +960,13 @@ fn fold_signature<'ast, T: Field, F: Folder<'ast, T>>(
}
}
fn fold_declaration_constant<'ast, T: Field, F: Folder<'ast, T>>(
_: &mut F,
c: DeclarationConstant<'ast>,
) -> DeclarationConstant<'ast> {
c
}
pub fn fold_array_expression<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
e: ArrayExpression<'ast, T>,
@ -988,7 +1054,9 @@ pub fn fold_constant_symbol<'ast, T: Field, F: Folder<'ast, T>>(
) -> TypedConstantSymbol<'ast, T> {
match s {
TypedConstantSymbol::Here(tc) => TypedConstantSymbol::Here(f.fold_constant(tc)),
there => there,
TypedConstantSymbol::There(id) => {
TypedConstantSymbol::There(f.fold_canonical_constant_identifier(id))
}
}
}
@ -998,7 +1066,10 @@ pub fn fold_function_symbol<'ast, T: Field, F: Folder<'ast, T>>(
) -> TypedFunctionSymbol<'ast, T> {
match s {
TypedFunctionSymbol::Here(fun) => TypedFunctionSymbol::Here(f.fold_function(fun)),
there => there, // by default, do not fold modules recursively
TypedFunctionSymbol::There(key) => {
TypedFunctionSymbol::There(f.fold_declaration_function_key(key))
}
s => s,
}
}
@ -1023,8 +1094,8 @@ pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>(
modules: p
.modules
.into_iter()
.map(|(module_id, module)| (module_id, f.fold_module(module)))
.map(|(module_id, module)| (f.fold_module_id(module_id), f.fold_module(module)))
.collect(),
main: p.main,
main: f.fold_module_id(p.main),
}
}

View file

@ -19,9 +19,10 @@ mod variable;
pub use self::identifier::CoreIdentifier;
pub use self::parameter::{DeclarationParameter, GParameter};
pub use self::types::{
ConcreteFunctionKey, ConcreteSignature, ConcreteType, DeclarationFunctionKey,
DeclarationSignature, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier,
IntoTypes, Signature, StructType, Type, Types, UBitwidth,
CanonicalConstantIdentifier, ConcreteFunctionKey, ConcreteSignature, ConcreteType,
ConstantIdentifier, DeclarationFunctionKey, DeclarationSignature, DeclarationType, GArrayType,
GStructType, GType, GenericIdentifier, IntoTypes, Signature, StructType, Type, Types,
UBitwidth,
};
use crate::typed_absy::types::ConcreteGenericsAssignment;
@ -62,17 +63,18 @@ pub type TypedModules<'ast, T> = HashMap<OwnedTypedModuleId, TypedModule<'ast, T
pub type TypedFunctionSymbols<'ast, T> =
HashMap<DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>>;
pub type ConstantIdentifier<'ast> = &'ast str;
#[derive(Clone, Debug, PartialEq)]
pub enum TypedConstantSymbol<'ast, T> {
Here(TypedConstant<'ast, T>),
There(OwnedTypedModuleId, ConstantIdentifier<'ast>),
There(CanonicalConstantIdentifier<'ast>),
}
/// A collection of `TypedConstantSymbol`s
pub type TypedConstantSymbols<'ast, T> =
HashMap<ConstantIdentifier<'ast>, TypedConstantSymbol<'ast, T>>;
/// It is still ordered, as we inline the constants in the order they are declared
pub type TypedConstantSymbols<'ast, T> = Vec<(
CanonicalConstantIdentifier<'ast>,
TypedConstantSymbol<'ast, T>,
)>;
/// A typed program as a collection of modules, one of them being the main
#[derive(PartialEq, Debug, Clone)]
@ -188,12 +190,17 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> {
let res = self
.constants
.iter()
.map(|(key, symbol)| match symbol {
.map(|(id, symbol)| match symbol {
TypedConstantSymbol::Here(ref tc) => {
format!("const {} {} = {}", tc.ty, key, tc.expression)
format!("const {} {} = {}", tc.ty, id.id, tc.expression)
}
TypedConstantSymbol::There(ref module_id, ref id) => {
format!("from \"{}\" import {} as {}", module_id.display(), id, key)
TypedConstantSymbol::There(ref imported_id) => {
format!(
"from \"{}\" import {} as {}",
imported_id.module.display(),
imported_id.id,
id.id
)
}
})
.chain(self.functions.iter().map(|(key, symbol)| match symbol {
@ -291,6 +298,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> {
#[derive(Clone, PartialEq, Debug)]
pub struct TypedConstant<'ast, T> {
// the type is already stored in the TypedExpression, but we want to avoid awkward trait bounds in `fmt::Display`
pub ty: Type<'ast, T>,
pub expression: TypedExpression<'ast, T>,
}
@ -303,6 +311,7 @@ impl<'ast, T> TypedConstant<'ast, T> {
impl<'ast, T: fmt::Display> fmt::Display for TypedConstant<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// using `self.expression.get_type()` would be better here but ends up requiring stronger trait bounds
write!(f, "const {}({})", self.ty, self.expression)
}
}

View file

@ -1,4 +1,4 @@
use crate::typed_absy::types::Constant;
use crate::typed_absy::types::DeclarationConstant;
use crate::typed_absy::GVariable;
use std::fmt;
@ -18,7 +18,7 @@ impl<'ast, S> From<GVariable<'ast, S>> for GParameter<'ast, S> {
}
}
pub type DeclarationParameter<'ast> = GParameter<'ast, Constant<'ast>>;
pub type DeclarationParameter<'ast> = GParameter<'ast, DeclarationConstant<'ast>>;
impl<'ast, S: fmt::Display + Clone> fmt::Display for GParameter<'ast, S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {

View file

@ -1,6 +1,6 @@
// Generic walk through a typed AST. Not mutating in place
use crate::typed_absy::types::{ArrayType, StructMember, StructType};
use crate::typed_absy::types::*;
use crate::typed_absy::*;
use zokrates_field::Field;
@ -97,6 +97,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_signature(self, s)
}
fn fold_declaration_constant(
&mut self,
c: DeclarationConstant<'ast>,
) -> Result<DeclarationConstant<'ast>, Self::Error> {
fold_declaration_constant(self, c)
}
fn fold_parameter(
&mut self,
p: DeclarationParameter<'ast>,
@ -107,6 +114,20 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
})
}
fn fold_canonical_constant_identifier(
&mut self,
i: CanonicalConstantIdentifier<'ast>,
) -> Result<CanonicalConstantIdentifier<'ast>, Self::Error> {
Ok(CanonicalConstantIdentifier {
module: self.fold_module_id(i.module)?,
id: i.id,
})
}
fn fold_module_id(&mut self, i: OwnedTypedModuleId) -> Result<OwnedTypedModuleId, Self::Error> {
Ok(i)
}
fn fold_name(&mut self, n: Identifier<'ast>) -> Result<Identifier<'ast>, Self::Error> {
Ok(n)
}
@ -224,6 +245,34 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
Ok(t)
}
fn fold_declaration_array_type(
&mut self,
t: DeclarationArrayType<'ast>,
) -> Result<DeclarationArrayType<'ast>, Self::Error> {
Ok(DeclarationArrayType {
ty: box self.fold_declaration_type(*t.ty)?,
size: self.fold_declaration_constant(t.size)?,
})
}
fn fold_declaration_struct_type(
&mut self,
t: DeclarationStructType<'ast>,
) -> Result<DeclarationStructType<'ast>, Self::Error> {
Ok(DeclarationStructType {
members: t
.members
.into_iter()
.map(|m| {
let id = m.id;
self.fold_declaration_type(*m.ty)
.map(|ty| DeclarationStructMember { ty: box ty, id })
})
.collect::<Result<_, _>>()?,
..t
})
}
fn fold_assignee(
&mut self,
a: TypedAssignee<'ast, T>,
@ -909,6 +958,7 @@ pub fn fold_declaration_function_key<'ast, T: Field, F: ResultFolder<'ast, T>>(
key: DeclarationFunctionKey<'ast>,
) -> Result<DeclarationFunctionKey<'ast>, F::Error> {
Ok(DeclarationFunctionKey {
module: f.fold_module_id(key.module)?,
signature: f.fold_signature(key.signature)?,
..key
})
@ -955,6 +1005,13 @@ fn fold_signature<'ast, T: Field, F: ResultFolder<'ast, T>>(
})
}
fn fold_declaration_constant<'ast, T: Field, F: ResultFolder<'ast, T>>(
_: &mut F,
c: DeclarationConstant<'ast>,
) -> Result<DeclarationConstant<'ast>, F::Error> {
Ok(c)
}
pub fn fold_array_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
e: ArrayExpression<'ast, T>,
@ -1046,7 +1103,9 @@ pub fn fold_constant_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
) -> Result<TypedConstantSymbol<'ast, T>, F::Error> {
match s {
TypedConstantSymbol::Here(tc) => Ok(TypedConstantSymbol::Here(f.fold_constant(tc)?)),
there => Ok(there),
TypedConstantSymbol::There(id) => Ok(TypedConstantSymbol::There(
f.fold_canonical_constant_identifier(id)?,
)),
}
}
@ -1056,7 +1115,10 @@ pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>(
) -> Result<TypedFunctionSymbol<'ast, T>, F::Error> {
match s {
TypedFunctionSymbol::Here(fun) => Ok(TypedFunctionSymbol::Here(f.fold_function(fun)?)),
there => Ok(there), // by default, do not fold modules recursively
TypedFunctionSymbol::There(key) => Ok(TypedFunctionSymbol::There(
f.fold_declaration_function_key(key)?,
)),
s => Ok(s),
}
}
@ -1088,6 +1150,6 @@ pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>(
.into_iter()
.map(|(module_id, module)| f.fold_module(module).map(|m| (module_id, m)))
.collect::<Result<_, _>>()?,
main: p.main,
main: f.fold_module_id(p.main)?,
})
}

View file

@ -101,37 +101,51 @@ impl<'ast> fmt::Display for GenericIdentifier<'ast> {
#[derive(Debug)]
pub struct SpecializationError;
pub type ConstantIdentifier<'ast> = &'ast str;
#[derive(Clone, PartialEq, Eq, Debug, Hash, PartialOrd, Ord)]
pub struct CanonicalConstantIdentifier<'ast> {
pub module: OwnedTypedModuleId,
pub id: ConstantIdentifier<'ast>,
}
impl<'ast> CanonicalConstantIdentifier<'ast> {
pub fn new(id: ConstantIdentifier<'ast>, module: OwnedTypedModuleId) -> Self {
CanonicalConstantIdentifier { module, id }
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Constant<'ast> {
pub enum DeclarationConstant<'ast> {
Generic(GenericIdentifier<'ast>),
Concrete(u32),
Identifier(&'ast str, usize),
Constant(CanonicalConstantIdentifier<'ast>),
}
impl<'ast> From<u32> for Constant<'ast> {
impl<'ast> From<u32> for DeclarationConstant<'ast> {
fn from(e: u32) -> Self {
Constant::Concrete(e)
DeclarationConstant::Concrete(e)
}
}
impl<'ast> From<usize> for Constant<'ast> {
impl<'ast> From<usize> for DeclarationConstant<'ast> {
fn from(e: usize) -> Self {
Constant::Concrete(e as u32)
DeclarationConstant::Concrete(e as u32)
}
}
impl<'ast> From<GenericIdentifier<'ast>> for Constant<'ast> {
impl<'ast> From<GenericIdentifier<'ast>> for DeclarationConstant<'ast> {
fn from(e: GenericIdentifier<'ast>) -> Self {
Constant::Generic(e)
DeclarationConstant::Generic(e)
}
}
impl<'ast> fmt::Display for Constant<'ast> {
impl<'ast> fmt::Display for DeclarationConstant<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Constant::Generic(i) => write!(f, "{}", i),
Constant::Concrete(v) => write!(f, "{}", v),
Constant::Identifier(v, _) => write!(f, "{}", v),
DeclarationConstant::Generic(i) => write!(f, "{}", i),
DeclarationConstant::Concrete(v) => write!(f, "{}", v),
DeclarationConstant::Constant(v) => write!(f, "{}/{}", v.module.display(), v.id),
}
}
}
@ -142,15 +156,17 @@ impl<'ast, T> From<usize> for UExpression<'ast, T> {
}
}
impl<'ast, T> From<Constant<'ast>> for UExpression<'ast, T> {
fn from(c: Constant<'ast>) -> Self {
impl<'ast, T> From<DeclarationConstant<'ast>> for UExpression<'ast, T> {
fn from(c: DeclarationConstant<'ast>) -> Self {
match c {
Constant::Generic(i) => {
DeclarationConstant::Generic(i) => {
UExpressionInner::Identifier(i.name.into()).annotate(UBitwidth::B32)
}
Constant::Concrete(v) => UExpressionInner::Value(v as u128).annotate(UBitwidth::B32),
Constant::Identifier(v, size) => {
UExpressionInner::Identifier(Identifier::from(v)).annotate(UBitwidth::from(size))
DeclarationConstant::Concrete(v) => {
UExpressionInner::Value(v as u128).annotate(UBitwidth::B32)
}
DeclarationConstant::Constant(v) => {
UExpressionInner::Identifier(Identifier::from(v.id)).annotate(UBitwidth::B32)
}
}
}
@ -169,12 +185,12 @@ impl<'ast, T> TryInto<usize> for UExpression<'ast, T> {
}
}
impl<'ast> TryInto<usize> for Constant<'ast> {
impl<'ast> TryInto<usize> for DeclarationConstant<'ast> {
type Error = SpecializationError;
fn try_into(self) -> Result<usize, Self::Error> {
match self {
Constant::Concrete(v) => Ok(v as usize),
DeclarationConstant::Concrete(v) => Ok(v as usize),
_ => Err(SpecializationError),
}
}
@ -190,7 +206,7 @@ pub struct GStructMember<S> {
pub ty: Box<GType<S>>,
}
pub type DeclarationStructMember<'ast> = GStructMember<Constant<'ast>>;
pub type DeclarationStructMember<'ast> = GStructMember<DeclarationConstant<'ast>>;
pub type ConcreteStructMember = GStructMember<usize>;
pub type StructMember<'ast, T> = GStructMember<UExpression<'ast, T>>;
@ -242,7 +258,7 @@ pub struct GArrayType<S> {
pub ty: Box<GType<S>>,
}
pub type DeclarationArrayType<'ast> = GArrayType<Constant<'ast>>;
pub type DeclarationArrayType<'ast> = GArrayType<DeclarationConstant<'ast>>;
pub type ConcreteArrayType = GArrayType<usize>;
pub type ArrayType<'ast, T> = GArrayType<UExpression<'ast, T>>;
@ -250,7 +266,7 @@ impl<'ast, T: PartialEq> PartialEq<DeclarationArrayType<'ast>> for ArrayType<'as
fn eq(&self, other: &DeclarationArrayType<'ast>) -> bool {
*self.ty == *other.ty
&& match (self.size.as_inner(), &other.size) {
(UExpressionInner::Value(l), Constant::Concrete(r)) => *l as u32 == *r,
(UExpressionInner::Value(l), DeclarationConstant::Concrete(r)) => *l as u32 == *r,
_ => true,
}
}
@ -349,7 +365,7 @@ pub struct GStructType<S> {
pub members: Vec<GStructMember<S>>,
}
pub type DeclarationStructType<'ast> = GStructType<Constant<'ast>>;
pub type DeclarationStructType<'ast> = GStructType<DeclarationConstant<'ast>>;
pub type ConcreteStructType = GStructType<usize>;
pub type StructType<'ast, T> = GStructType<UExpression<'ast, T>>;
@ -588,7 +604,7 @@ impl<'de, S: Deserialize<'de>> Deserialize<'de> for GType<S> {
}
}
pub type DeclarationType<'ast> = GType<Constant<'ast>>;
pub type DeclarationType<'ast> = GType<DeclarationConstant<'ast>>;
pub type ConcreteType = GType<usize>;
pub type Type<'ast, T> = GType<UExpression<'ast, T>>;
@ -711,7 +727,7 @@ impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> {
// check the size if types match
match (&l.size.as_inner(), &r.size) {
// compare the sizes for concrete ones
(UExpressionInner::Value(v), Constant::Concrete(c)) => {
(UExpressionInner::Value(v), DeclarationConstant::Concrete(c)) => {
(*v as u32) == *c
}
_ => true,
@ -772,7 +788,7 @@ pub struct GFunctionKey<'ast, S> {
pub signature: GSignature<S>,
}
pub type DeclarationFunctionKey<'ast> = GFunctionKey<'ast, Constant<'ast>>;
pub type DeclarationFunctionKey<'ast> = GFunctionKey<'ast, DeclarationConstant<'ast>>;
pub type ConcreteFunctionKey<'ast> = GFunctionKey<'ast, usize>;
pub type FunctionKey<'ast, T> = GFunctionKey<'ast, UExpression<'ast, T>>;
@ -948,7 +964,7 @@ pub mod signature {
}
}
pub type DeclarationSignature<'ast> = GSignature<Constant<'ast>>;
pub type DeclarationSignature<'ast> = GSignature<DeclarationConstant<'ast>>;
pub type ConcreteSignature = GSignature<usize>;
pub type Signature<'ast, T> = GSignature<UExpression<'ast, T>>;
@ -968,15 +984,17 @@ pub mod signature {
&& match &t0.size {
// if the declared size is an identifier, we insert into the map, or check if the concrete size
// matches if this identifier is already in the map
Constant::Generic(id) => match constants.0.entry(id.clone()) {
DeclarationConstant::Generic(id) => match constants.0.entry(id.clone()) {
Entry::Occupied(e) => *e.get() == s1,
Entry::Vacant(e) => {
e.insert(s1);
true
}
},
Constant::Concrete(s0) => s1 == *s0 as usize,
Constant::Identifier(_, s0) => s1 == *s0,
DeclarationConstant::Concrete(s0) => s1 == *s0 as usize,
// in the case of a constant, we do not know the value yet, so we optimistically assume it's correct
// if it does not match, it will be caught during inlining
DeclarationConstant::Constant(..) => true,
}
}
(DeclarationType::FieldElement, GType::FieldElement)
@ -1000,9 +1018,11 @@ pub mod signature {
let ty = box specialize_type(*t0.ty, &constants)?;
let size = match t0.size {
Constant::Generic(s) => constants.0.get(&s).cloned().ok_or(s),
Constant::Concrete(s) => Ok(s.into()),
Constant::Identifier(_, s) => Ok((s as u32).into()),
DeclarationConstant::Generic(s) => constants.0.get(&s).cloned().ok_or(s),
DeclarationConstant::Concrete(s) => Ok(s.into()),
DeclarationConstant::Constant(..) => {
unreachable!("identifiers should have been removed in constant inlining")
}
}?;
GType::Array(GArrayType { size, ty })
@ -1053,7 +1073,7 @@ pub mod signature {
assert_eq!(self.generics.len(), values.len());
let decl_generics = self.generics.iter().map(|g| match g.clone().unwrap() {
Constant::Generic(g) => g,
DeclarationConstant::Generic(g) => g,
_ => unreachable!(),
});
@ -1096,7 +1116,7 @@ pub mod signature {
v.map(|v| {
(
match g.clone().unwrap() {
Constant::Generic(g) => g,
DeclarationConstant::Generic(g) => g,
_ => unreachable!(),
},
v,

View file

@ -1,4 +1,4 @@
use crate::typed_absy::types::{Constant, GStructType, UBitwidth};
use crate::typed_absy::types::{DeclarationConstant, GStructType, UBitwidth};
use crate::typed_absy::types::{GType, SpecializationError};
use crate::typed_absy::Identifier;
use crate::typed_absy::UExpression;
@ -11,7 +11,7 @@ pub struct GVariable<'ast, S> {
pub _type: GType<S>,
}
pub type DeclarationVariable<'ast> = GVariable<'ast, Constant<'ast>>;
pub type DeclarationVariable<'ast> = GVariable<'ast, DeclarationConstant<'ast>>;
pub type ConcreteVariable<'ast> = GVariable<'ast, usize>;
pub type Variable<'ast, T> = GVariable<'ast, UExpression<'ast, T>>;

View file

@ -0,0 +1,4 @@
{
"entry_point": "./tests/tests/constants/import/destination.zok",
"tests": []
}

View file

@ -0,0 +1,4 @@
from "./origin.zok" import foo
def main():
assert(foo([1, 1]))
return

View file

@ -0,0 +1,3 @@
const u32 N = 1 + 1
def foo(field[N] a) -> bool:
return true