A tutorial on building a Merkle tree AIR script in Plonky3

Alex

Cryptography Engineer

Plonky3 is a toolkit of primitives, such as polynomial commitment schemes, for implementing polynomial IOPs (PIOPs). Although it is mainly designed for building STARK-based zkVMs, we found that implementing a simple AIR script is a helpful way to understand how Plonky3 works. To that end, we implemented a Merkle tree AIR script on top of the Poseidon2 hash AIR script provided in the Plonky3 repo. The implementation of the Merkle tree AIR script can be found here.

In this write-up, we mainly walk through the implementation of the Merkle tree AIR script and how to use SubAirBuilder to incorporate another AIR script (Poseidon2 in this case) into ours. For readers who want to understand other parts of Plonky3, we also recommend the Fibonacci AIR example by Brian Seong and the other resources at Awesome Plonky3.

What is a Merkle Tree?

First, let's define the Merkle tree used in this example. Our AIR script proves that a node is a member of the Merkle tree by verifying a Merkle proof, i.e. the sibling hashes along the path from that node up to the root.

For example, to prove that L2 is part of the Merkle tree, the prover shows that it knows a Merkle proof (Hash0-0 and Hash1) that, starting from the public input Hash0-1 = hash(L2) and hashing up from the leaves, produces a root equal to the Top Hash (also a public input). Verifying the Merkle proof therefore amounts to checking H(H(Hash0-0, Hash0-1), Hash1) == Top Hash.
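To make the check concrete, here is a minimal, self-contained sketch of Merkle proof verification in plain Rust. It does not use Plonky3: `toy_hash` is a hypothetical, insecure 2-to-1 compression standing in for Poseidon2, and `compute_root`/`verify` are illustrative names. Bit i of the leaf index plays the same role as the is_odd_index flag introduced later: it decides whether the current node is the left or right hash input at level i.

```rust
/// BabyBear prime, used here only to keep the toy values in a field-sized range.
const P: u64 = 2_013_265_921;

/// Toy 2-to-1 compression standing in for Poseidon2 (NOT cryptographically secure).
fn toy_hash(left: u64, right: u64) -> u64 {
    let x = (left * 31 + right * 17 + 5) % P;
    x * x % P
}

/// Recompute the root from a leaf digest, its index and the sibling path.
/// At each level, an odd index means the current node is the right child.
fn compute_root(leaf_digest: u64, mut index: usize, siblings: &[u64]) -> u64 {
    let mut node = leaf_digest;
    for &sibling in siblings {
        node = if index % 2 == 1 {
            toy_hash(sibling, node) // current node is the right input
        } else {
            toy_hash(node, sibling) // current node is the left input
        };
        index /= 2;
    }
    node
}

fn verify(leaf_digest: u64, index: usize, siblings: &[u64], root: u64) -> bool {
    compute_root(leaf_digest, index, siblings) == root
}

fn main() {
    // Toy leaf digests in place of Hash0-0 .. Hash1-1 = hash(L1) .. hash(L4).
    let (h00, h01, h10, h11) = (11, 22, 33, 44);
    let h0 = toy_hash(h00, h01);
    let h1 = toy_hash(h10, h11);
    let top = toy_hash(h0, h1);

    // Prove L2: its digest Hash0-1 sits at index 1 and the proof is [Hash0-0, Hash1].
    assert!(verify(h01, 1, &[h00, h1], top));
    // The same path cannot match a different root.
    assert!(!verify(h01, 1, &[h00, h1], top + 1));
    println!("L2 membership verified");
}
```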

Defining the AIR Script

Now let's move on to the Plonky3 AIR script, where we define our main program. Think of the script as a proving machine with memory (like the registers of a CPU) that holds all of our proving data, including the intermediate results. Our goal is to design the memory layout that best suits our needs. The memory layout is a matrix (the execution trace) in which each row represents one iteration of our program. We will also define constraints that apply to each row, as we will see later.

There are three main components in our AIR script. First, a main struct that defines all the data (witness, public inputs, ...) we need for the proof; our memory layout is an n x m matrix in which each row is one hashing step up the Merkle tree. Second, an eval() function that defines the constraints for each iteration. Third, a generate_merkle_proof_trace() function that generates the Merkle proof and fills the memory matrix with the witness.

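use p3_air::BaseAir;
use p3_baby_bear::{BabyBear, GenericPoseidon2LinearLayersBabyBear};
use p3_field::Field;
use p3_poseidon2_air::Poseidon2Air;

// WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS and PARTIAL_ROUNDS are
// constants for the Poseidon2 instance, defined elsewhere in the crate.
pub struct MerkleTreeAir {
    // public inputs, hard-coded into the constraints
    pub root: u32,
    pub leaf_value: u32,
    // Poseidon2 AIR used to constrain every hash along the Merkle path
    pub poseidon2_air: Poseidon2Air<
        BabyBear,
        GenericPoseidon2LinearLayersBabyBear,
        WIDTH,
        SBOX_DEGREE,
        SBOX_REGISTERS,
        HALF_FULL_ROUNDS,
        PARTIAL_ROUNDS,
    >,
}

impl<F: Field> BaseAir<F> for MerkleTreeAir {
    fn width(&self) -> usize {
        // 1 column for is_odd_index + the Poseidon2Cols columns
        // (which include the current node and its Merkle-proof sibling)
        self.poseidon2_air.width() + 1
    }
}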

Here we define the main struct, MerkleTreeAir. It holds root and leaf_value as public inputs, which we simply hard-code into the constraints. Alternatively, we could read them through the builder's public_values() method (from the AirBuilderWithPublicValues trait), so that the prover can supply different public inputs for each proof. The struct also contains the Poseidon2Air script, which we use to prove that every hash in the Merkle tree is computed correctly. The width() function returns the width of our memory matrix. Each row of our program is laid out as follows:

The first column is a boolean flag, is_odd_index, indicating whether the current node (the leaf or one of its ancestors) that we hash with the Merkle proof has an odd or even index. For example, if the leaf's index is odd, the hash input is H(merkle_proof[0], leaf). The remaining columns, from the second to the last, belong to the Poseidon2 AIR script and hold the left input, the right input, the hash output, and all of the intermediate values of the hash function. Hence the row width is exactly poseidon2_air.width() + 1.
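The row layout and the role of the flag can be sketched in plain Rust. This is an illustration only: the column constants, `hash_inputs`, `row_prefix` and `select` are hypothetical names, and placing the left and right inputs in columns 1 and 2 follows the column order assumed by the constraints later in this post. `select` shows the branch-free form `flag * a + (1 - flag) * b` that the AIR constraints use in place of an if/else.

```rust
/// Column 0 of each row holds the is_odd_index flag.
const IS_ODD_COL: usize = 0;
/// Poseidon2 columns start at column 1; in the layout assumed here,
/// the left and right hash inputs sit in columns 1 and 2.
const LEFT_COL: usize = 1;
const RIGHT_COL: usize = 2;

/// Order (current, sibling) into (left, right) for the hash at this level.
fn hash_inputs(is_odd_index: bool, current: u64, sibling: u64) -> (u64, u64) {
    if is_odd_index {
        (sibling, current) // current node is a right child: H(sibling, current)
    } else {
        (current, sibling) // current node is a left child: H(current, sibling)
    }
}

/// First three columns of a row: [is_odd_index, left, right].
fn row_prefix(is_odd_index: bool, current: u64, sibling: u64) -> [u64; 3] {
    let (left, right) = hash_inputs(is_odd_index, current, sibling);
    let mut row = [0u64; 3];
    row[IS_ODD_COL] = is_odd_index as u64;
    row[LEFT_COL] = left;
    row[RIGHT_COL] = right;
    row
}

/// Branch-free selection with a 0/1 flag, the arithmetic shape of the constraints.
fn select(flag: i64, if_one: i64, if_zero: i64) -> i64 {
    flag * if_one + (1 - flag) * if_zero
}

fn main() {
    let row = row_prefix(true, 22, 11);
    assert_eq!(row, [1, 11, 22]);
    // Recover the current node from the row with the same selector the constraints use.
    let current = select(row[IS_ODD_COL] as i64, row[RIGHT_COL] as i64, row[LEFT_COL] as i64);
    assert_eq!(current, 22);
}
```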

Constraints

Next, let's define the constraints in the eval() function. The AirBuilder trait gives us access to the memory matrix and lets us define constraints over it. We use main.row_slice() to access rows: index 0 is the row of the current iteration, and index 1 is the row of the next iteration.

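// BabyBear, MerkleTreeAir and the Poseidon2 constants come from the previous snippet.
use p3_air::{Air, AirBuilder};
use p3_matrix::Matrix; // brings row_slice() into scope

impl<AB: AirBuilder<F = BabyBear>> Air<AB> for MerkleTreeAir {
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        // current row
        let local = main.row_slice(0).unwrap();
        // next row
        let next = main.row_slice(1).unwrap();

        // add constraints here (see below)
    }
}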
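The local/next access pattern can be mimicked outside Plonky3 with a toy row-major matrix. `ToyTrace` below is a hypothetical type, not Plonky3's RowMajorMatrix. It assumes the standard STARK convention that the trace wraps around cyclically, so the "next" row of the last row is row 0; that is why transition constraints are switched off on the last row, which is what builder.when_transition() does.

```rust
/// A minimal row-major matrix (a toy, NOT Plonky3's RowMajorMatrix).
struct ToyTrace {
    values: Vec<u64>,
    width: usize,
}

impl ToyTrace {
    fn height(&self) -> usize {
        self.values.len() / self.width
    }

    fn row(&self, r: usize) -> &[u64] {
        &self.values[r * self.width..(r + 1) * self.width]
    }

    /// Visit every (row index, local, next) window; the last row wraps to row 0.
    fn windows(&self) -> impl Iterator<Item = (usize, &[u64], &[u64])> + '_ {
        let h = self.height();
        (0..h).map(move |r| (r, self.row(r), self.row((r + 1) % h)))
    }
}

fn main() {
    // 4 rows x 2 columns: column 0 is a flag, column 1 counts up by one per row.
    let trace = ToyTrace { values: vec![0, 0, 1, 1, 0, 2, 1, 3], width: 2 };
    let h = trace.height();
    for (r, local, next) in trace.windows() {
        let is_transition = r + 1 < h; // mirrors builder.when_transition()
        if is_transition {
            assert_eq!(next[1], local[1] + 1);
        }
        assert!(local[0] <= 1); // a boolean check applied on every row
    }
    // Without the transition filter, the last row would be compared with row 0.
    let (_, last, wrapped) = trace.windows().last().unwrap();
    assert_ne!(wrapped[1], last[1] + 1);
}
```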

The constraints we need for our AIR script are as below:

  1. The hash output of the current row must equal the left or right input of the next row's hash. If is_odd_index is 1, row i hashes H(merkle_proof[i], hash_output[i-1]); otherwise it hashes H(hash_output[i-1], merkle_proof[i]).
    As seen below, we first reinterpret the Poseidon2 columns of the row as Poseidon2Cols to read the hash output. Then we encode the if/else as next[0] * (constraint) + (1 - next[0]) * (constraint), where next[0] is the is_odd_index flag of the next iteration. builder.when_transition() applies the constraint to every row except the last.

// Reinterpret columns 1.. as Poseidon2Cols so we can read the hash result by field
// name (requires `use core::borrow::Borrow;` and `use p3_poseidon2_air::Poseidon2Cols;`)
let p2_cols: &Poseidon2Cols<
    _,
    WIDTH,
    SBOX_DEGREE,
    SBOX_REGISTERS,
    HALF_FULL_ROUNDS,
    PARTIAL_ROUNDS,
> = local[1..].borrow();
// The hash output is the first state element after the last full round
let hash_output = p2_cols.ending_full_rounds[HALF_FULL_ROUNDS - 1].post[0];

// Constrain: hash output == next row's left (col[1]) or right (col[2])
// input based on is_odd_index (col[0])
builder.when_transition().assert_zero(
    (next[0] * (hash_output - next[2]))
        + ((AB::Expr::ONE - next[0]) * (hash_output - next[1])),
);
  2. In the first row, the left or right input (selected by is_odd_index) must equal the public leaf_value.

// Constrain: first left (col[1]) or right (col[2]) input should be the leaf_value based on
// is_odd_index (col[0])
builder.when_first_row().assert_zero(
    (local[0] * (AB::Expr::from_u32(self.leaf_value) - local[2]))
        + ((AB::Expr::ONE - local[0])
        * (AB::Expr::from_u32(self.leaf_value) - local[1])),
);
  3. is_odd_index must be 0 or 1 in every row.

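// Constrain: is_odd_index must be boolean. The transition rows plus the last row
// cover every row, so this is equivalent to builder.assert_bool(local[0]).
builder.when_transition().assert_bool(local[0].clone());
builder.when_last_row().assert_bool(local[0].clone());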
  4. The final hash output must equal the public root.

// Constrain: final output == root
let root = AB::Expr::from_u32(self.root);
builder.when_last_row().assert_eq(root, hash_output);
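To see the four constraints working together, here is a plain-Rust simulation that checks them over a toy trace with arithmetic mod the BabyBear prime. It is a sketch under simplifying assumptions: each row is shortened to [is_odd_index, left, right, out], whereas in the real AIR `out` lives inside the Poseidon2 columns; `toy_hash` is a hypothetical stand-in for Poseidon2, and the `out == toy_hash(left, right)` check stands in for the Poseidon2 sub-AIR. `check_constraints` is an illustrative name, not part of Plonky3.

```rust
const P: u64 = 2_013_265_921; // BabyBear prime

/// Toy 2-to-1 compression standing in for Poseidon2 (NOT secure).
fn toy_hash(left: u64, right: u64) -> u64 {
    let x = (left * 31 + right * 17 + 5) % P;
    x * x % P
}
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { a * b % P }

/// Row layout for this toy: [is_odd_index, left, right, out].
fn check_constraints(rows: &[[u64; 4]], leaf: u64, root: u64) -> bool {
    let n = rows.len();
    for (i, local) in rows.iter().enumerate() {
        let [is_odd, left, right, out] = *local;
        // Stand-in for the Poseidon2 sub-AIR: out must be the hash of the inputs.
        if out != toy_hash(left, right) {
            return false;
        }
        // 3. is_odd_index is boolean on every row: x * (x - 1) == 0.
        if mul(is_odd, sub(is_odd, 1)) != 0 {
            return false;
        }
        // 2. First row: the input selected by is_odd_index equals the leaf.
        if i == 0 {
            let c = (mul(is_odd, sub(leaf, right)) + mul(sub(1, is_odd), sub(leaf, left))) % P;
            if c != 0 {
                return false;
            }
        }
        // 1. Transition: this row's output is the next row's selected input.
        if i + 1 < n {
            let next = rows[i + 1];
            let c = (mul(next[0], sub(out, next[2])) + mul(sub(1, next[0]), sub(out, next[1]))) % P;
            if c != 0 {
                return false;
            }
        }
        // 4. Last row: output equals the root.
        if i + 1 == n && out != root {
            return false;
        }
    }
    true
}

fn main() {
    let (h00, h01, h10, h11) = (11, 22, 33, 44);
    let h0 = toy_hash(h00, h01);
    let h1 = toy_hash(h10, h11);
    let top = toy_hash(h0, h1);
    // Proving the leaf digest h01 (index 1) with the proof [h00, h1].
    let rows = [
        [1, h00, h01, h0], // index 1 is odd: H(sibling, current)
        [0, h0, h1, top],  // index 0 is even: H(current, sibling)
    ];
    assert!(check_constraints(&rows, h01, top));
    // Flipping a flag without swapping the inputs breaks the transition constraint.
    let mut bad = rows;
    bad[1][0] = 1;
    assert!(!check_constraints(&bad, h01, top));
}
```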

Besides the Merkle tree constraints, we also need to enforce the constraints of the Poseidon2 hash function. We use SubAirBuilder to evaluate another AIR script while exposing only part of the memory matrix to it (columns 1.. in our case). In the eval() function, we add:

// create Poseidon2 sub air to evaluate Poseidon2 hash
// given input as column range [1..]
let p2_col_count = self.poseidon2_air.width();
let mut sub: SubAirBuilder<AB, Poseidon2AirType, AB::Var> =
    SubAirBuilder::new(builder, 1..1 + p2_col_count);
self.poseidon2_air.eval(&mut sub);
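The idea behind SubAirBuilder, handing a sub-AIR only a column range of the parent trace re-indexed from 0, can be illustrated without Plonky3. Everything here is hypothetical: `ToyAir`, `ToyHashAir` (an arbitrary toy relation out = left + 2 * right, not a hash) and `eval_sub` are illustrative names, and the real SubAirBuilder works on the builder's trace view rather than on plain slices.

```rust
/// A toy "AIR": constraints written against its own columns, starting at index 0.
trait ToyAir {
    fn width(&self) -> usize;
    /// Returns true if `row` (exactly `width()` columns) satisfies the constraints.
    fn eval(&self, row: &[u64]) -> bool;
}

/// Toy stand-in for the hash sub-AIR: columns [left, right, out], out == left + 2 * right.
struct ToyHashAir;

impl ToyAir for ToyHashAir {
    fn width(&self) -> usize {
        3
    }
    fn eval(&self, row: &[u64]) -> bool {
        row[2] == row[0] + 2 * row[1]
    }
}

/// Mimics the idea of SubAirBuilder: the sub-AIR only sees
/// columns offset..offset + sub.width() of the parent row.
fn eval_sub<A: ToyAir>(sub: &A, parent_row: &[u64], offset: usize) -> bool {
    let cols = &parent_row[offset..offset + sub.width()];
    sub.eval(cols)
}

fn main() {
    // Parent row: [is_odd_index, left, right, out]; the sub-AIR starts at column 1.
    let row: [u64; 4] = [1, 4, 5, 14];
    assert!(eval_sub(&ToyHashAir, &row, 1));
    // With the wrong offset, is_odd_index would be read as `left`.
    assert!(!eval_sub(&ToyHashAir, &row, 0));
}
```

The benefit of this design is that the sub-AIR's constraints never need to know where their columns sit in the parent trace; only the offset (column 1 here) is chosen by the parent.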

Drawn as a memory matrix, the layout looks like this:

Generate Trace

After defining the AIR script, we now generate the proof witness and fill the memory matrix described above. The generate_merkle_proof_trace() function first generates the Merkle proof and then returns a RowMajorMatrix filled with a witness that satisfies the constraints defined above.

Besides the Merkle proof, we also need to generate the trace of the Poseidon2 hash to fill the second through last columns of the matrix. For this we use the generate_trace_rows() function from the poseidon2_air library, which computes the hash and returns the trace rows we need to fill the matrix.

See our GitHub repo for the complete function.
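As a rough illustration of the Merkle-path part of trace generation only, the sketch below walks from a leaf to the root and emits one row per tree level, flattened row-major. It is not the authors' function: `generate_toy_trace` is hypothetical, rows are shortened to [is_odd_index, left, right, out] instead of full Poseidon2 columns, and `toy_hash` stands in for Poseidon2. A real Plonky3 trace must also have a power-of-two height, so a tree whose depth is not a power of two would need padding rows, which this sketch omits.

```rust
const P: u64 = 2_013_265_921; // BabyBear prime

/// Toy 2-to-1 compression standing in for Poseidon2 (NOT secure).
fn toy_hash(left: u64, right: u64) -> u64 {
    let x = (left * 31 + right * 17 + 5) % P;
    x * x % P
}

/// Columns per toy row: [is_odd_index, left, right, out].
const TRACE_WIDTH: usize = 4;

/// Hypothetical helper: builds the tree over `leaves` (length a power of two)
/// and emits one row per level on the path of `index`. Returns (row-major values, root).
fn generate_toy_trace(leaves: &[u64], mut index: usize) -> (Vec<u64>, u64) {
    assert!(leaves.len().is_power_of_two() && index < leaves.len());
    let mut level = leaves.to_vec();
    let mut values = Vec::new();
    let mut current = leaves[index];
    while level.len() > 1 {
        let sibling = level[index ^ 1]; // the Merkle-proof element at this level
        let is_odd = index % 2 == 1;
        let (left, right) = if is_odd { (sibling, current) } else { (current, sibling) };
        let out = toy_hash(left, right);
        values.extend_from_slice(&[is_odd as u64, left, right, out]);
        // Hash pairs to obtain the next level up the tree.
        level = level.chunks(2).map(|pair| toy_hash(pair[0], pair[1])).collect();
        current = out;
        index /= 2;
    }
    (values, level[0])
}

fn main() {
    let (values, root) = generate_toy_trace(&[11, 22, 33, 44], 1);
    let rows: Vec<&[u64]> = values.chunks(TRACE_WIDTH).collect();
    assert_eq!(rows.len(), 2); // one row per tree level
    assert_eq!(rows[0][..3].to_vec(), vec![1, 11, 22]); // index 1 is odd: H(sibling, leaf)
    assert_eq!(rows[1][3], root); // the last output is the root
}
```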

Run the ZK System

Now that everything is ready, we can configure and run the ZK system. We won't go into the details of Plonky3's configuration here; please refer to the Fibonacci AIR example for more on the setup. In our example, we use the configuration from Plonky3's poseidon2-air example.


HashCloak specializes in enabling teams with the expertise to build privacy infrastructure leveraging the state of the art in advanced cryptography.
