Skip to content
This repository has been archived by the owner on Oct 20, 2024. It is now read-only.

feat(): allow macros as macro args (rebased) #300

Draft
wants to merge 6 commits into
base: stage
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 79 additions & 7 deletions huff_codegen/src/irgen/arg_calls.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use huff_utils::prelude::*;
use std::str::FromStr;

use crate::Codegen;

// Arguments can be literals, labels, opcodes, or constants
// !! IF THERE IS AMBIGUOUS NOMENCLATURE
// !! (E.G. BOTH OPCODE AND LABEL ARE THE SAME STRING)
Expand All @@ -9,6 +11,7 @@ use std::str::FromStr;
/// Arg Call Bubbling
#[allow(clippy::too_many_arguments)]
pub fn bubble_arg_call(
evm_version: &EVMVersion,
arg_name: &str,
bytes: &mut Vec<(usize, Bytes)>,
macro_def: &MacroDefinition,
Expand All @@ -18,6 +21,9 @@ pub fn bubble_arg_call(
// mis: Parent macro invocations and their indices
mis: &mut [(usize, MacroInvocation)],
jump_table: &mut JumpTable,
circular_codesize_invocations: &mut CircularCodeSizeIndices,
label_indices: &mut LabelIndices,
table_instances: &mut Jumps,
) -> Result<(), CodegenError> {
let starting_offset = *offset;

Expand Down Expand Up @@ -70,6 +76,7 @@ pub fn bubble_arg_call(
let ac_ = &ac.to_string();
return if last_mi.1.macro_name.eq(&macro_def.name) {
bubble_arg_call(
evm_version,
ac_,
bytes,
bubbled_macro_invocation,
Expand All @@ -78,9 +85,13 @@ pub fn bubble_arg_call(
offset,
&mut mis[..mis_len.saturating_sub(1)],
jump_table,
circular_codesize_invocations,
label_indices,
table_instances,
)
} else {
bubble_arg_call(
evm_version,
ac_,
bytes,
bubbled_macro_invocation,
Expand All @@ -89,14 +100,24 @@ pub fn bubble_arg_call(
offset,
mis,
jump_table,
circular_codesize_invocations,
label_indices,
table_instances,
)
}
}
MacroArg::Ident(iden) => {
tracing::debug!(target: "codegen", "Found MacroArg::Ident IN \"{}\" Macro Invocation: \"{}\"!", macro_invoc.1.macro_name, iden);

// Check for a constant first
if let Some(constant) = contract
// The opcode check needs to happens before the constants lookup
// because otherwise the mutex can deadlock when bubbling up to
// resolve macros as arguments.
if let Ok(o) = Opcode::from_str(iden) {
tracing::debug!(target: "codegen", "Found Opcode: {}", o);
let b = Bytes(o.to_string());
*offset += b.0.len() / 2;
bytes.push((starting_offset, b));
} else if let Some(constant) = contract
.constants
.lock()
.map_err(|_| {
Expand Down Expand Up @@ -130,11 +151,62 @@ pub fn bubble_arg_call(
*offset += push_bytes.len() / 2;
tracing::info!(target: "codegen", "OFFSET: {}, PUSH BYTES: {:?}", offset, push_bytes);
bytes.push((starting_offset, Bytes(push_bytes)));
} else if let Ok(o) = Opcode::from_str(iden) {
tracing::debug!(target: "codegen", "Found Opcode: {}", o);
let b = Bytes(o.to_string());
*offset += b.0.len() / 2;
bytes.push((starting_offset, b));
} else if let Some(ir_macro) = contract.find_macro_by_name(iden) {
tracing::debug!(target: "codegen", "ARG CALL IS MACRO: {}", iden);
tracing::debug!(target: "codegen", "CURRENT MACRO DEF: {}", macro_def.name);

let mut new_scopes = scope.to_vec();
new_scopes.push(ir_macro);
let mut new_mis = mis.to_vec();
new_mis.push((
*offset,
MacroInvocation {
macro_name: iden.to_string(),
args: vec![],
span: AstSpan(vec![]),
},
));

let mut res: BytecodeRes = match Codegen::macro_to_bytecode(
evm_version,
ir_macro,
contract,
&mut new_scopes,
*offset,
&mut new_mis,
false,
Some(circular_codesize_invocations),
) {
Ok(r) => r,
Err(e) => {
tracing::error!(
target: "codegen",
"FAILED TO RECURSE INTO MACRO \"{}\"",
ir_macro.name
);
return Err(e)
}
};

for j in res.unmatched_jumps.iter_mut() {
let new_index = j.bytecode_index;
j.bytecode_index = 0;
let mut new_jumps = if let Some(jumps) = jump_table.get(&new_index)
{
jumps.clone()
} else {
vec![]
};
new_jumps.push(j.clone());
jump_table.insert(new_index, new_jumps);
}
table_instances.extend(res.table_instances);
label_indices.extend(res.label_indices);

// Increase offset by byte length of recursed macro
*offset += res.bytes.iter().map(|(_, b)| b.0.len()).sum::<usize>() / 2;
// Add the macro's bytecode to the final result
res.bytes.iter().for_each(|(a, b)| bytes.push((*a, b.clone())));
} else {
tracing::debug!(target: "codegen", "Found Label Call: {}", iden);

Expand Down
4 changes: 4 additions & 0 deletions huff_codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ impl Codegen {
// Bubble up arg call by looking through the previous scopes.
// Once the arg value is found, add it to `bytes`
bubble_arg_call(
evm_version,
arg_name,
&mut bytes,
macro_def,
Expand All @@ -348,6 +349,9 @@ impl Codegen {
&mut offset,
mis,
&mut jump_table,
circular_codesize_invocations,
&mut label_indices,
&mut table_instances,
)?
}
}
Expand Down
90 changes: 90 additions & 0 deletions huff_core/tests/macro_invoc_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,93 @@ fn test_bubbled_arg_with_different_name() {
// Check the bytecode
assert_eq!(main_bytecode, expected_bytecode);
}

#[test]
fn test_macro_macro_arg() {
let source = r#"
#define constant TWO = 0x02

#define macro MUL_BY_10() = takes(1) returns (1) {
0x0a mul
}

#define macro EXEC_WITH_VALUE(value, macro) = takes(0) returns(1) {
<value> <macro>
}

#define macro MAIN() = takes(0) returns(0) {
EXEC_WITH_VALUE(TWO, MUL_BY_10)
}
"#;

// Lex + Parse
let flattened_source = FullFileSource { source, file: None, spans: vec![] };
let lexer = Lexer::new(flattened_source.source);
let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::<Vec<Token>>();
let mut parser = Parser::new(tokens, None);
let mut contract = parser.parse().unwrap();
contract.derive_storage_pointers();

let evm_version = EVMVersion::default();

// Create main and constructor bytecode
let main_bytecode = Codegen::generate_main_bytecode(&evm_version, &contract, None).unwrap();

// Full expected bytecode output (generated from huffc) (placed here as a reference)
let expected_bytecode = "6002600a02";

// Check the bytecode
assert_eq!(main_bytecode.to_lowercase(), expected_bytecode.to_lowercase());
}

#[test]
fn test_bubbled_macro_macro_arg() {
let source = r#"
#define constant TWO = 0x02

#define macro MUL_BY_10() = takes(1) returns (1) {
0x0a mul
}

#define macro DO_OP(op) = takes(0) returns(0) {
<op>
}

#define macro DIV_BY_5() = takes(1) returns (1) {
0x05 swap1 DO_OP(div)
}

#define macro EXEC_WITH_VALUE(value, macro) = takes(0) returns(1) {
<value> <macro>
}

#define macro SUM_RESULTS(value, macro1, macro2) = takes(0) returns (1) {
EXEC_WITH_VALUE(<value>, <macro1>)
EXEC_WITH_VALUE(<value>, <macro2>)
add
}

#define macro MAIN() = takes(0) returns(0) {
SUM_RESULTS(TWO, MUL_BY_10, DIV_BY_5)
}
"#;

// Lex + Parse
let flattened_source = FullFileSource { source, file: None, spans: vec![] };
let lexer = Lexer::new(flattened_source.source);
let tokens = lexer.into_iter().map(|x| x.unwrap()).collect::<Vec<Token>>();
let mut parser = Parser::new(tokens, None);
let mut contract = parser.parse().unwrap();
contract.derive_storage_pointers();

let evm_version = EVMVersion::default();

// Create main and constructor bytecode
let main_bytecode = Codegen::generate_main_bytecode(&evm_version, &contract, None).unwrap();

// Full expected bytecode output (generated from huffc) (placed here as a reference)
let expected_bytecode = "6002600a0260026005900401";

// Check the bytecode
assert_eq!(main_bytecode.to_lowercase(), expected_bytecode.to_lowercase());
}
Loading