1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

propagate embed call in reducer

This commit is contained in:
schaeff 2023-05-23 10:08:21 +02:00
parent 633ee82547
commit 066d3533d8
6 changed files with 44 additions and 14 deletions

View file

@ -0,0 +1 @@
Propagate embed call

View file

@ -55,7 +55,7 @@ impl fmt::Display for Error {
pub struct Propagator<'ast, T> { pub struct Propagator<'ast, T> {
// constants keeps track of constant expressions // constants keeps track of constant expressions
// we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is // we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is
constants: Constants<'ast, T>, pub constants: Constants<'ast, T>,
} }
impl<'ast, T: Field> Propagator<'ast, T> { impl<'ast, T: Field> Propagator<'ast, T> {

View file

@ -213,6 +213,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
let return_value = self.fold_expression(return_value)?; let return_value = self.fold_expression(return_value)?;
let return_value = self.propagator.fold_expression(return_value)?;
Ok(FunctionCallOrExpression::Expression( Ok(FunctionCallOrExpression::Expression(
E::from(return_value).into_inner(), E::from(return_value).into_inner(),
)) ))
@ -226,22 +228,28 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
FunctionCallExpression::<_, E>::new(e.function_key, generics, arguments) FunctionCallExpression::<_, E>::new(e.function_key, generics, arguments)
))), ))),
Err(InlineError::Flat(embed, generics, output_type)) => { Err(InlineError::Flat(embed, generics, output_type)) => {
let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call(0)); let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call);
let var = Variable::new(identifier.clone(), output_type); let var = Variable::new(identifier.clone(), output_type);
let v: TypedAssignee<'ast, T> = var.clone().into(); let v: TypedAssignee<'ast, T> = var.clone().into();
self.statement_buffer.push( let definition = TypedStatement::embed_call_definition(
TypedStatement::embed_call_definition( v,
v, EmbedCall::new(embed, generics, arguments),
EmbedCall::new(embed, generics, arguments), )
) .span(span);
.span(span),
); let definition = self.propagator.fold_statement(definition)?;
Ok(FunctionCallOrExpression::Expression(
E::identifier(identifier).span(span), self.statement_buffer.extend(definition);
))
let e = match self.propagator.constants.get(&identifier) {
Some(v) => E::try_from(v.clone()).unwrap().into_inner(),
None => E::identifier(identifier),
};
Ok(FunctionCallOrExpression::Expression(e.span(span)))
} }
}; };

View file

@ -8,7 +8,7 @@ pub type SourceIdentifier<'ast> = std::borrow::Cow<'ast, str>;
pub enum CoreIdentifier<'ast> { pub enum CoreIdentifier<'ast> {
#[serde(borrow)] #[serde(borrow)]
Source(ShadowedIdentifier<'ast>), Source(ShadowedIdentifier<'ast>),
Call(usize), Call,
Constant(CanonicalConstantIdentifier<'ast>), Constant(CanonicalConstantIdentifier<'ast>),
Condition(usize), Condition(usize),
} }
@ -17,7 +17,7 @@ impl<'ast> fmt::Display for CoreIdentifier<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
CoreIdentifier::Source(s) => write!(f, "{}", s), CoreIdentifier::Source(s) => write!(f, "{}", s),
CoreIdentifier::Call(i) => write!(f, "#CALL_RETURN_AT_INDEX_{}", i), CoreIdentifier::Call => write!(f, "#CALL_RETURN"),
CoreIdentifier::Constant(c) => write!(f, "{}/{}", c.module.display(), c.id), CoreIdentifier::Constant(c) => write!(f, "{}/{}", c.module.display(), c.id),
CoreIdentifier::Condition(i) => write!(f, "#CONDITION_{}", i), CoreIdentifier::Condition(i) => write!(f, "#CONDITION_{}", i),
} }

View file

@ -0,0 +1,5 @@
{
"entry_point": "./tests/tests/constants/propagate_embed.zok",
"max_constraint_count": 2,
"tests": []
}

View file

@ -0,0 +1,16 @@
import "utils/casts/field_to_u32";
from "EMBED" import unpack;
def foo<N>() -> field {
return 1;
}
def main() -> field {
u32 N = field_to_u32(1);
for u32 i in 0..N {
log("{}", i);
}
bool[1] B = unpack(1);
u32 P = B[0] ? 1 : 0;
return foo::<N>() + foo::<P>();
}