Identifying the ordering of state access across contract calls

An interesting problem with public functions that @alvaro just raised with me (after chatting with Guillaume, Maxim, and Adam).

tl;dr the Public Inputs ABI needs to contain a counter to help the kernel circuit to validate public state changes (and indeed private state changes).

Example:

contract A {
    // Imagine x starts life as 0:
    int x;

    fn foo() {
        x++; // 0 -> 1
        B::bar();
    }
    
    fn baz() {
        x++; // 1 -> 2
    }
}

contract B {
    fn bar() {
        A::baz();
    }
}

What do the state_transition objects that the Public Kernel Circuit sees look like, in the order it processes them?

  • A::foo: state_transitions = [ { slot: x, old_value: 0, new_value: 1 }, empty,... ]
  • B::bar: state_transitions = [ empty,... ]
  • A::baz: state_transitions = [ { slot: x, old_value: 1, new_value: 2 }, empty,... ]

At the moment, our approach is to ‘squash’ state transitions for x into a single state transition on the Public Kernel Circuit’s own state_transitions array. So, ultimately, it would be squashed to:

  • kernel_state_transitions = [ { slot: x, old_value: 0, new_value: 2 }, empty,… ]
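To make the squashing concrete, here is a small Python model of it (purely illustrative — not the actual circuit code): transitions for the same slot, taken in execution order, are merged into one, keeping the first old_value and the last new_value.

```python
# Hypothetical Python model of the 'squash' step (not the actual circuit code):
# transitions for the same slot, processed in execution order, are merged
# into one, keeping the first old_value and the last new_value.
def squash(transitions):
    merged = {}
    for t in transitions:
        slot = t["slot"]
        if slot in merged:
            merged[slot]["new_value"] = t["new_value"]
        else:
            merged[slot] = dict(t)
    return list(merged.values())

# The two transitions of x from the example squash to a single 0 -> 2:
kernel_state_transitions = squash([
    {"slot": "x", "old_value": 0, "new_value": 1},
    {"slot": "x", "old_value": 1, "new_value": 2},
])
```

Note this model silently assumes the transitions arrive in execution order — which is exactly the assumption the next example breaks.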

BUT WHAT IF WE MODIFIED A::foo TO CALL B::bar FIRST, before incrementing x:

contract A {
    // Imagine x starts life as 0:
    int x;

    fn foo() {
        B::bar();
        x++; // 1 -> 2
    }
    
    fn baz() {
        x++; // 0 -> 1
    }
}

contract B {
    fn bar() {
        A::baz();
    }
}

Notice now the ordering of the transitions has been switched. But the Kernel Circuit has no information to be able to figure out which ordering to use!

Repeating the above… What do the state_transition objects that the Public Kernel Circuit sees look like, in the order it processes them?

  • A::foo: state_transitions = [ { slot: x, old_value: 1, new_value: 2 }, empty,... ] <— NO LONGER THE FIRST TRANSITION
  • B::bar: state_transitions = [ empty,... ]
  • A::baz: state_transitions = [ { slot: x, old_value: 0, new_value: 1 }, empty,... ]

Ultimately, this SHOULD be squashed to:

  • kernel_state_transitions = [ { slot: x, old_value: 0, new_value: 2 }, empty,… ]
    … just like the first example. But the Kernel Circuit doesn’t know which transition happened first!

Suggestion to fix this problem:

Each function also tracks a counter (I’m open to suggestions for names :sweat_smile:). Maybe a ‘clause counter’, because it tracks clauses of the body of each function.

fn foo () {
    // counter = 0
    bar() {
        // counter++ // = 1
        baz() {
            // counter++ // = 2
        }
        // counter++ // =  3
    }
    // counter++ // = 4
}

Each function’s Public Inputs ABI then would need to expose its start_counter and end_counter. These values would go nicely inside the call_context of the ABI. Our Noir Contract standard libraries (which Alvaro has been writing) can correctly track these counters every time a call is made.

Any state_read that happens when the counter is n would need to contain this extra piece of information:
state_read = { counter, storage_slot, value }, and similarly:
state_transitions = { counter, storage_slot, old_value, new_value }


Now if we repeat our 2nd example:

  • A::foo: state_transitions = [ { counter: 4, slot: x, old_value: 1, new_value: 2 }, empty,... ]
  • B::bar: state_transitions = [ empty,... ]
  • A::baz: state_transitions = [ { counter: 2, slot: x, old_value: 0, new_value: 1 }, empty,... ]

Great! The Kernel Circuit can now see that the final-listed transition of x was the first to occur. So it can double-check that the new_value = 1 of the first-ordered transition of x (from 0 → 1) matches the old_value = 1 of the next-ordered transition of x (from 1 → 2).
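As a hedged sketch (illustrative names, not the real kernel logic), the check the kernel performs once counters exist could look like this in Python: sort all transitions by counter, then require that, per slot, each transition's old_value matches the previous transition's new_value.

```python
# Hypothetical sketch of the counter-based ordering check (illustrative names):
# sort all transitions by counter, then require that, per slot, each
# transition's old_value equals the previous transition's new_value.
def verify_stitching(transitions):
    last_by_slot = {}
    for t in sorted(transitions, key=lambda t: t["counter"]):
        prev = last_by_slot.get(t["slot"])
        if prev is not None and prev["new_value"] != t["old_value"]:
            return False
        last_by_slot[t["slot"]] = t
    return True

# The 2nd example's transitions arrive 'out of order' but still verify:
ok = verify_stitching([
    {"counter": 4, "slot": "x", "old_value": 1, "new_value": 2},
    {"counter": 2, "slot": "x", "old_value": 0, "new_value": 1},
])
```

A tampered ordering (e.g. an old_value that doesn't stitch onto the prior new_value) would now be rejected.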


Good catch, and good solution! I think the counter approach works fine, though I’m not sure how much complexity is introduced into the kernel circuit having to “navigate” through the counters. Maybe we’d need to provide additional inputs so we don’t have to actually sort within the circuit? But I’m out of my depth here.


EDIT: Broken by @alvaro below.

As another approach, would it work if each iteration of the public circuit reported the state changes caused by it and all nested calls? So in your example what the kernel circuit would see is:

  • A::foo: state_transitions = [ { slot: x, old_value: 0, new_value: 2 }, empty,... ]
  • B::bar: state_transitions = [ { slot: x, old_value: 1, new_value: 2 }, empty,... ]
  • A: :baz state_transitions = [ { slot: x, old_value: 1, new_value: 2 }, empty,... ]

This would mean that a nested public circuit would return the list of transitions to its caller (similarly to how it returns the function return values) which the caller then “merges” into its own state. Does this make sense?

So in the first example, as the public circuit for A::foo is running, it first tracks the 0->1 transition, calls into B::bar and gets back the 1->2 transition (along with any results from bar) which merges into its own set of transitions.


EDIT: Also broken by @alvaro below :slight_smile:

Alternatively, would it be too crazy to change how the app and kernel circuits are invoked, so we do one iteration of the circuits for each different value of counter? This would mean that the circuits would not verify entire function runs, but rather chunks of code without any nested calls in them.

In your example, the kernel circuit would have five iterations instead of 3:

  • A::foo (before calling bar)
  • B::bar (before calling baz)
  • A::baz
  • B::bar (after returning from baz)
  • A::foo (after returning from bar)

This feels more natural to me as it just follows the flow of the program, rather than having to reassemble the side effects after the fact. On the other hand, it also makes app circuits more annoying to deal with, since each function is now composed of multiple chunks rather than a single circuit.

@spalladino I think the slicing solution would increase kernel complexity by a lot, since we’d have to save in the public function’s public inputs the current state of the VM memory at the point of slicing (you would be continuing from a given snapshot instead of starting clean)…
And I think the first option, aggregating the state changes of the nested calls, wouldn’t fix the case where a downstream call reads a value that was updated by an upstream function after calling it.

True, you’d need to snapshot the current state to be able to go back and resume from it.

You’re right, the kernel is still missing info to know whether the transition in the nested function is legit or not, and we’re back to adding counters!

Looking through the current implementation I think we are susceptible to the problem you have highlighted @Mike. As it stands, state transitions are squashed at the function level by the public VM circuit (currently simulated by the acir simulator) but no consideration is paid to possible interleaving of state updates and function calls.

I think the counter solution will work. A per-slot transition counter might be easier than a global ‘execution frame’ counter as above. It would mean there weren’t holes in the counter values for a given slot. Is this a problem? Maybe the fact that this requires state across calls causes problems?

The public kernel circuit can perform a kind of ‘add, sort, verify and squash’ for each iteration. This is where having an unbroken sequence of counter values would be beneficial. If the circuit sees counter==0 and counter==1 on the first iteration (for the same slot), it can squash at that point, reducing the accumulated state transitions. But if they were counter==0 and counter==4, the circuit doesn’t know if there might be 1, 2 and/or 3 coming in a later iteration.


To expand on @PhilWindle’s idea, this would mean adding a new field to the PublicVM circuit inputs and outputs, something like:

state_transition_counters: Array<(slot, counter), STATE_TRANSITIONS_LENGTH>

Whenever the public circuit emits a new transition, it increments the counter for the corresponding slot. And when it calls a nested function, it receives (along with the return_values) the updated counters from it, and carries on from there. This would give us all the information we need for squashing on every iteration of the kernel circuit, whenever it’s possible.


However, if we’re passing a data structure that has one element per updated slot, wouldn’t it make sense to just pass around the actual values, and handle everything at the public circuit level? It’d be similar to the first proposal here, only that each public circuit iteration also receives the “starting” set of state transitions from its caller. I believe this should fix the issue that Alvaro pointed out.

So the oracle only provides the starting value for each slot the first time it appears in the tx. The VM circuit proves that every state transition was done properly, starting from those initial values. And the kernel circuit proves that 1) the starting values belong to the state tree, and 2) the starting and end state transitions in a nested call “stitch” together properly, exactly the same as the args and return values.

In Mike’s first example, the resulting fields would be:

  • A::foo .starting_transitions = [], .end_transitions = [x: 0->2]
  • B::bar .starting_transitions = [], .end_transitions = [x: 0->1]
  • A::baz .starting_transitions = [], .end_transitions = [x: 0->1]

And in the second one:

  • A::foo .starting_transitions = [], .end_transitions = [x: 0->2]
  • B::bar .starting_transitions = [x: 0->1], .end_transitions = [x: 1->2]
  • A::baz .starting_transitions = [x: 0->1], .end_transitions = [x: 1->2]

Does this make sense, or am I falling into the same errors again? In the meantime, we can get started with the original global counter idea.

To recap, these are the options we’re considering right now, discussed with @PhilWindle and @Mike over a call:

1. Global call counter

Mike’s proposal from the first post. This requires adding a frame_counter (for lack of a better name) to the public VM circuit inputs and outputs. The VM circuit takes care of incrementing it right before making a call and immediately after returning.

This lets the kernel circuit build a sorted list of state transitions across iterations. This list will contain one item for each state transition done in the transaction. Once all iterations of the kernel circuit are done, the list can be checked so that end values and start values of consecutive transitions for the same slot do match. In Mike’s 2nd example:

- A::foo state_transitions = [ { counter: 4, slot: x, old_value: 1, new_value: 2 }, empty,... ]
- B::bar state_transitions = [ empty,... ]
- A::baz state_transitions = [ { counter: 2, slot: x, old_value: 0, new_value: 1 }, empty,... ]

=> list = [ { counter: 2, slot: x, old_value: 0, new_value: 1 }, { counter: 4, slot: x, old_value: 1, new_value: 2 }]
=> assert list[0].new_value == list[1].old_value

Note that, since the counter skips from 2 to 4, the kernel circuit cannot know (until all iterations are complete) if there’ll be any items in between, so it cannot squash them into a single transition of 0->2.

Implementing this implies:

  • VM circuit: one extra field of input and output to track the frame_counter, plus logic to increment it immediately before and after each call.
  • Kernel: a list of MAX_TRANSITIONS_PER_TX (note that two writes to the same slot count as two separate items, since they cannot be squashed)
  • Base rollup: as the first circuit that runs after all iterations of the kernel are done, it needs to validate that consecutive transitions for each slot can be properly stitched together.

2. Global transition counter

Instead of keeping a frame_counter, we keep a transition_counter that tracks the number of transitions done so far in the tx. This still allows us to sort the list, but it also lets us know when there are no “holes” to be expected, so we can sometimes squash it as we process it in the successive kernel iterations.

In Mike’s 2nd example, this means:

- A::foo state_transitions = [ { counter: 1, slot: x, old_value: 1, new_value: 2 }, empty,... ]
- B::bar state_transitions = [ empty,... ]
- A::baz state_transitions = [ { counter: 0, slot: x, old_value: 0, new_value: 1 }, empty,... ]

=> list = [ { counter: 0, slot: x, old_value: 0, new_value: 1 }, { counter: 1, slot: x, old_value: 1, new_value: 2 }]
=> squashed = [{ counter: 0-1, slot: x, old_value: 0, new_value: 2 }]

The logic for squashing may be a bit difficult, since it requires sorting by slot first and then by counter, and checking across all items whether there are holes, to determine if squashing is possible. Let’s bring up another example, built by Phil:

foo() { 
    x = 5
    bar() {         
        baz() {             
            y = 9
            x = 7
            bat() { z = 4 }
        }         
    }     
    x = 9
}

In this example, we’d have:

- A::foo state_transitions = [ { #0 x: 0 -> 5 }, { #4 x: 7 -> 9} ]
- B::bar state_transitions = []
- A::baz state_transitions = [ { #1 y: 0 -> 9}, { #2 x: 5 -> 7 } ]
- A::bat state_transitions = [ { #3 z: 0 -> 4 } ]

When the kernel circuit reaches baz, it has built the following structure:

x: [{#0 0 -> 5}, {#2 5 -> 7}, {#4 7 -> 9}]
y: [{#1 0 -> 9}]

With that information, it can safely squash the first two entries for x, but it needs to wait another iteration before it can squash the last one, since index #3 could be for slot x:

x: [{#0-2 0 -> 7}, {#4 7 -> 9}]
y: [{#1 0 -> 9}]

This approach allows squashing all the transitions for a slot that occur in each function before it jumps into a nested call. However, it’s unclear to me how complex this squashing is to implement in a zk circuit, so the tradeoff is hard to assess.
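A possible (purely hypothetical) model of this incremental squashing: two adjacent transitions for a slot may be merged only when every counter value between them has already been seen, so no unprocessed write to that slot can be hiding in the gap.

```python
# Hypothetical model of option 2's incremental squashing. per_slot maps each
# slot to its transitions sorted by counter; seen is the set of all counter
# values observed so far, across every slot. Adjacent transitions merge only
# when every counter in the gap between them is already accounted for.
def squash_step(per_slot, seen):
    for slot, ts in per_slot.items():
        merged = [dict(ts[0])]
        for t in ts[1:]:
            prev = merged[-1]
            gap = range(prev["counter"] + 1, t["counter"])
            if all(c in seen for c in gap):
                prev["new_value"] = t["new_value"]
                prev["counter"] = t["counter"]  # remember last merged counter
            else:
                merged.append(dict(t))
        per_slot[slot] = merged
    return per_slot

# Phil's example, at the point the kernel has processed baz (#3 not yet seen):
state = squash_step(
    {"x": [{"counter": 0, "old_value": 0, "new_value": 5},
           {"counter": 2, "old_value": 5, "new_value": 7},
           {"counter": 4, "old_value": 7, "new_value": 9}],
     "y": [{"counter": 1, "old_value": 0, "new_value": 9}]},
    seen={0, 1, 2, 4},
)
# x squashes to [#0-2: 0 -> 7, #4: 7 -> 9]; #4 must wait, since #3 could be x.
```

The `#0` and `#2` entries for x merge because `#1` is known to belong to y, whereas `#4` stays separate because `#3` is still unaccounted for.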

3. Per-slot transition counter

Similar to the approach above, only that instead of keeping a single transition_counter, it keeps a counter for each slot affected in the tx. This opens up more opportunities to squash in the kernel circuit. In the foo-bar-baz-bat example, the outputs would be:

- A::foo state_transitions = [ { #0 x: 0 -> 5 }, { #2 x: 7 -> 9} ]
- B::bar state_transitions = []
- A::baz state_transitions = [ { #0 y: 0 -> 9}, { #1 x: 5 -> 7 } ]
- A::bat state_transitions = [ { #0 z: 0 -> 4 } ]

Which means that now, by the time the kernel circuit iterations reach baz, we’d have assembled:

x: [{#0 0 -> 5}, {#1 5 -> 7}, {#2 7 -> 9}]
y: [{#0 0 -> 9}]

Here we can squash all of x’s transitions, and it’s also easier to check (since we don’t need to see y’s list to know which indices have been used).

However, this approach has a big downside: we now need to pass a lot more information to and from each VM circuit run. Rather than passing and returning a single transition_counter, we now need to pass an array of {slot, counter} pairs. The length of the array will be the maximum number of different slots that can be affected in a single tx (nevertheless, and unlike option 1, multiple writes to the same slot take up only one item in that list). It’ll be the responsibility of the VM circuit to increase the counter for a slot whenever it writes to it.
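As a rough sketch of option 3 (hypothetical helper names, not the real VM interface): the counters map travels into and out of every nested call, and each write stamps the slot's current counter before incrementing it.

```python
# Hypothetical sketch of per-slot counters (option 3). The counters map
# travels into and out of every nested call; each write stamps the slot's
# current counter onto the transition and then increments it.
def sstore(counters, transitions, slot, old_value, new_value):
    c = counters.get(slot, 0)
    transitions.append({"slot": slot, "counter": c,
                        "old_value": old_value, "new_value": new_value})
    counters[slot] = c + 1

# Replaying the foo-bar-baz-bat example reproduces the per-slot numbering:
counters = {}
foo, baz, bat = [], [], []
sstore(counters, foo, "x", 0, 5)   # foo: x #0
sstore(counters, baz, "y", 0, 9)   # baz: y #0
sstore(counters, baz, "x", 5, 7)   # baz: x #1
sstore(counters, bat, "z", 0, 4)   # bat: z #0
sstore(counters, foo, "x", 7, 9)   # foo: x #2
```

Replaying the writes in program order yields exactly the per-function transition lists shown above (foo holds x #0 and #2, baz holds y #0 and x #1, bat holds z #0).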

4. Per-slot state

Building on option 3, if we’re now passing a lot of information to and from each VM circuit run, we can just pass the current value for all slots modified in the tx so far, and return the updated state to the caller. This moves the validation of state transitions from the kernel circuit to the public VM circuit.

In the foo-bar-baz-bat example, when foo calls bar, it’d pass [x=5] along with any arguments for the function call. And when bar returns to foo, it’d return x=7, y=9, z=4 along with any function return values. The kernel circuit now just checks that the state snapshots are passed correctly between calls, same as it checks that arguments are passed correctly and return values are returned correctly.

The key insight here is that the public VM circuit, unlike the private app circuit, is controlled by the network, so we can trust that it behaves properly when interpreting brillig. Upon an SSTORE instruction, the VM circuit would update its internal state snapshot with the new value. This snapshot would be received from the caller and returned to it updated.

The big advantage here is that we don’t have to assemble any sorted list in the kernel circuit or validate any state transition stitching. The kernel circuit just proves the initial values for each slot affected in the tx, and grabs the resulting state snapshot from the topmost call stack item to pass to the base rollup circuit, which executes these updates. Each iteration of the kernel circuit sees the following, but doesn’t need to validate anything.

- A::foo init_state=[ ] end_state=[ x=9, y=9, z=4 ]
- B::bar init_state=[ x=5 ] end_state=[ x=7, y=9, z=4 ]
- A::baz init_state=[ x=5 ] end_state=[ x=7, y=9, z=4 ]
- A::bat init_state=[ x=7, y=9 ] end_state=[ x=7, y=9, z=4 ]

Note that the resulting state is in the topmost item in the call stack, which condenses all state transitions to be done to the tree.

As an optimization, note that z=4 is not really needed outside of `bat`, so we could avoid passing it back to `baz` when the call ends. The problem is that we won’t know whether it’ll be needed later until we finish the tx execution, but if we can run this optimization after execution and before proving, we can reduce the number of public inputs to the VM circuit.
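Option 4 can be modelled as plain snapshot threading (a hypothetical sketch, not the VM circuit itself): each call receives the caller's snapshot of modified slots, applies its own writes, and hands the updated snapshot back alongside its return values.

```python
# Hypothetical model of option 4's snapshot threading: a call receives the
# caller's snapshot of modified slots, applies its own SSTOREs, and returns
# the updated snapshot along with its return values.
def run_call(snapshot, writes):
    snapshot = dict(snapshot)       # work on a copy of the caller's view
    for slot, value in writes:
        snapshot[slot] = value      # each SSTORE updates the snapshot
    return snapshot

# foo-bar-baz-bat, flattened into the order the writes actually happen:
snap = run_call({}, [("x", 5)])              # foo, before calling bar
snap = run_call(snap, [("y", 9), ("x", 7)])  # baz
snap = run_call(snap, [("z", 4)])            # bat
snap = run_call(snap, [("x", 9)])            # foo, after bar returns
```

The final snapshot condenses every transition of the tx, matching the end_state of the topmost call stack item above.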


Given the pros/cons for each option, it’d seem like option 2 is strictly better than option 1 (since it opens the door for optimizations by squashing elements in the list more easily), and option 4 is strictly better than 3 (since it removes effort in the kernel circuit using the same amount of additional fields in the public inputs).


In debugging another issue, I’m thinking that this problem may actually extend to any execution side-effect where ordering is important. For instance, we’re losing track of the ordering of emitted events across nested calls, or enqueued public function calls.

I may be mistaken though, as I need to dig a bit more, but I wanted to raise this soon since it is related to what @david-banks and team are working on regarding ordering of read requests. Maybe this is another reason to move to the one-kernel-per-frame approach?


@jean made a good argument for the per-frame approach as well. It was something like… We will be more confident that our solution is comprehensive if we take an approach where the structure of the kernel more closely matches the program execution.

Edit: but @alvaro’s comment still holds true:

I think the slicing solution would increase kernel complexity by a lot by having to save in public function public inputs the current state of the VM memory at the point of the slicing (you would be continuing from a given snapshot instead of starting clean)…

And per-frame would work for the public VM, but this problem is also present in private functions, where I don’t know if it would even be possible to have a per-frame approach…

Can confirm this affects enqueued public function calls. Today the private kernel, on each iteration, just pushes the public_call_stack of the current app circuit to the accumulated stack:

const auto& this_public_call_stack = private_call_public_inputs.public_call_stack;
push_array_to_array(composer, this_public_call_stack, public_inputs.end.public_call_stack);

So enqueued public function calls will always be executed in the order of the private kernel circuit iterations. And we cannot fix this with the information we get today from the circuits, since the kernel cannot distinguish between these two scenarios:

// Scenario 1: pub1 is enqueued before the nested call
fn foo() {
  enqueue call to pub1();
  bar();
}

fn bar() {
  enqueue call to pub2();
}

// Scenario 2: pub1 is enqueued after the nested call
fn foo() {
  bar();
  enqueue call to pub1();
}

fn bar() {
  enqueue call to pub2();
}

Example of why state ordering matters for private calls

This simple example below shows why state access ordering matters for private calls. In this example, the function bar() can read commitment B before it has been created.

contract SimpleExample {
    X: Set<Note>;
    Y: Set<Note>;

    fn call_bar_increment_x() {
        bar();
        let x_note = X.get(); // read request (gets commitment A)
        let new_x = Note::new(x_note.value + 1);
        X.remove(x_note); // nullify A
        X.insert(new_x); // create commitment B
    }

    // 100th caller gets some reward
    fn bar() {
        let x_note = X.get(); // read request (should get A, gets B)

        if (x_note.value > 100) {
            Y.insert("SENDER WINS ALL TOKENS"); // or whatever
        }
    }
}

Transition counters for private calls

Most of the discussion in this thread so far has been for public calls. As shown above, this issue still exists for private calls. Here I have attempted to adapt the transition counter proposal for private:

App circuits / call_stack_items

  • App circuit accepts a start_transition_counter as input and increments it for each transition that occurs within the call or nested calls. The end_transition_counter is then an output of the app circuit.
    • These two values will be stored in a call’s public_inputs
    • Each read_request within a call will be assigned a transition_counter value
      • This kind of informs the kernel “what was the last state transition before this read request?”
    • Each new_commitment and new_nullifier will be assigned a transition_counter and will then trigger an increment of the counter
    • Each kernel iteration will have a start_transition_counter and end_transition_counter

Private kernel

  • Initial kernel will enforce the following constraints:
    • call.public_inputs.start_transition_counter == 0
  • Inner kernel will enforce the following constraints:
    • start_transition_counter == previous_kernel.end_transition_counter
  • Initial and inner kernel will both enforce the following constraints:
    • end_transition_counter >= start_transition_counter
    • All read_requests must have a transition_counter within start_transition_counter <= transition_counter <= end_transition_counter
      • In the context of read_requests, this counter is really meant as an ordering stamp so it is clear what commitments and nullifiers exist at the time of the read
    • All new_commitments and new_nullifiers must have a transition_counter within start_transition_counter <= transition_counter < end_transition_counter
      • Note that unlike read_requests, here the transition counter must be strictly less than the end counter, since the last transition will be end_transition_counter-1 and will then increment the counter to end_transition_counter
    • Less important / possibly excessive rules:
      • Each state transition (new commitment or nullifier) within a single call must have a counter greater than the previous
      • Each read request within a single call must have a counter greater than or equal to the previous
  • Rules relating to pending commitments / transient storage
    • A read_request can only be matched with a new_commitment if the read’s counter is >= that of the commitment
    • A new_nullifier can only be matched with (and therefore squashed with) a new_commitment if the nullifier’s counter is > that of the commitment
    • At the end of a kernel iteration, all unmatched “transient” state accesses/transitions must be forwarded to the next kernel to hopefully be matched in a later iteration.

Abbreviations:

  • call.public_inputs abbreviates private_inputs.private_call.call_stack_item.public_inputs
  • previous_kernel.*_transition_counter abbreviates previous_kernel.public_inputs.end.*_transition_counter
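The constraints above could be prototyped roughly as follows (a Python sketch with illustrative field names — the real kernel expresses these as circuit constraints, and the ABI fields may differ):

```python
# Rough prototype of the private-kernel counter constraints (illustrative
# field names, not the real ABI). Returns the call's end counter so the
# next iteration can stitch onto it.
def check_kernel_iteration(call, prev_end_counter, is_initial):
    start = call["start_transition_counter"]
    end = call["end_transition_counter"]
    assert start == (0 if is_initial else prev_end_counter)
    assert end >= start
    # Read requests may sit anywhere in [start, end]: ordering stamps only.
    for rr in call["read_requests"]:
        assert start <= rr["counter"] <= end
    # Transitions must be strictly below end: the last transition carries
    # end - 1 and then increments the counter up to end.
    for t in call["new_commitments"] + call["new_nullifiers"]:
        assert start <= t["counter"] < end
    return end

# A valid initial iteration with one read, one commitment, one nullifier:
end = check_kernel_iteration(
    {"start_transition_counter": 0, "end_transition_counter": 2,
     "read_requests": [{"counter": 2}],
     "new_commitments": [{"counter": 0}],
     "new_nullifiers": [{"counter": 1}]},
    prev_end_counter=None, is_initial=True,
)
```

A transition stamped with counter equal to end_transition_counter would fail the strict inequality, mirroring the rule called out above.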

Program counters instead?

This was an idea of @cheethas / @Maddiaa’s.

Basically we could just associate every acir opcode with a program counter / index. Along with acir_hash a circuit would need to provide max_program_counter. Every foreign call would have an associated counter value which it will output. We would then be able to use this counter to order all foreign calls (state accesses/transitions, events, …).