To recap, these are the options we’re considering right now, discussed with @PhilWindle and @Mike over a call:
1. Global call counter
Mike’s proposal from the first post. This requires adding a `frame_counter` (for lack of a better name) to the public VM circuit inputs and outputs. The VM circuit takes care of incrementing it right before making a call and immediately after returning.
This lets the kernel circuit build a sorted list of state transitions across iterations, with one item for each state transition done in the transaction. Once all iterations of the kernel circuit are done, the list can be checked to ensure that the end value of each transition matches the start value of the next transition for the same slot. In Mike’s 2nd example:
- A::foo state_transitions = [ { counter: 4, slot: x, old_value: 1, new_value: 2 }, empty,... ]
- B::bar state_transitions = [ empty,... ]
- A::baz state_transitions = [ { counter: 2, slot: x, old_value: 0, new_value: 1 }, empty,... ]
=> list = [ { counter: 2, slot: x, old_value: 0, new_value: 1 }, { counter: 4, slot: x, old_value: 1, new_value: 2 }]
=> assert list[0].new_value == list[1].old_value
Note that, since the counter skips from 2 to 4, the kernel circuit cannot know (until all iterations are complete) if there’ll be any items in between, so it cannot squash them into a single transition of 0->2.
Implementing this implies:
- VM circuit: one extra field of input and output to track the `frame_counter`, plus logic to increment it immediately before and after each call.
- Kernel circuit: a list of `MAX_TRANSITIONS_PER_TX` entries (note that two writes to the same slot count as two separate items, since they cannot be squashed).
- Base rollup circuit: as the first circuit that runs after all iterations of the kernel are done, it needs to validate that consecutive transitions for each slot can be properly stitched together, as in the sketch below.
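To make the stitching check concrete, here’s a minimal out-of-circuit sketch in TypeScript (all names are hypothetical; the real thing would be a fixed set of constraints over a bounded array in the base rollup circuit):

```typescript
interface StateTransition {
  counter: number; // global frame_counter value when the write happened
  slot: bigint;    // storage slot written to
  oldValue: bigint;
  newValue: bigint;
}

// Checks that consecutive transitions for the same slot stitch together:
// each transition's old value must equal the previous transition's new value.
// Assumes `transitions` is already sorted by slot, then by counter.
function validateStitching(transitions: StateTransition[]): void {
  for (let i = 1; i < transitions.length; i++) {
    const prev = transitions[i - 1];
    const curr = transitions[i];
    if (curr.slot === prev.slot && curr.oldValue !== prev.newValue) {
      throw new Error(`Broken transition chain for slot ${curr.slot}`);
    }
  }
}
```

In Mike’s 2nd example this reduces to the single `assert list[0].new_value == list[1].old_value` check above.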
2. Global transition counter
Instead of keeping a `frame_counter`, we keep a `transition_counter` that tracks the number of transitions done so far in the tx. This still allows us to sort the list, but it also lets us know when there are no “holes” to be expected, so we can sometimes squash it as we process it in the successive kernel iterations.
In Mike’s 2nd example, this means:
- A::foo state_transitions = [ { counter: 1, slot: x, old_value: 1, new_value: 2 }, empty,... ]
- B::bar state_transitions = [ empty,... ]
- A::baz state_transitions = [ { counter: 0, slot: x, old_value: 0, new_value: 1 }, empty,... ]
=> list = [ { counter: 0, slot: x, old_value: 0, new_value: 1 }, { counter: 1, slot: x, old_value: 1, new_value: 2 }]
=> squashed = [{ counter: 0-1, slot: x, old_value: 0, new_value: 2 }]
The logic for squashing may be a bit difficult, since it requires sorting by slot first and then by counter, and checking across all items whether there are holes, to determine if squashing is possible. Let’s bring up another example, built by Phil:
foo() {
  x = 5
  bar() {
    baz() {
      y = 9
      x = 7
      bat() { z = 4 }
    }
  }
  x = 9
}
In this example, we’d have:
- A::foo state_transitions = [ { #0 x: 0 -> 5 }, { #4 x: 7 -> 9} ]
- B::bar state_transitions = []
- A::baz state_transitions = [ { #1 y: 0 -> 9}, { #2 x: 5 -> 7 } ]
- A::bat state_transitions = [ { #3 z: 0 -> 4 } ]
When the kernel circuit reaches `baz`, it has built the following structure:
x: [{#0 0 -> 5}, {#2 5 -> 7}, {#4 7 -> 9}]
y: [{#1 0 -> 9}]
With that information, it can safely squash the first two entries for `x`, but it needs to wait another iteration before it can squash the last one, since index #3 could be for slot `x`:
x: [{#0-2 0 -> 7}, {#4 7 -> 9}]
y: [{#1 0 -> 9}]
This approach allows squashing all transitions for a slot that occur in each function before jumping into a nested call. However, it’s unclear to me how complex this squashing is to implement in a zk circuit, so the tradeoff is hard to evaluate.
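For clarity, here’s a rough out-of-circuit sketch of that hole-checking squash in TypeScript (names hypothetical; the in-circuit version would need to work over fixed-size arrays with constraints, which is exactly the complexity in question):

```typescript
interface Transition {
  counter: number; // global transition_counter value for this write
  slot: bigint;
  oldValue: bigint;
  newValue: bigint;
}

// Squash adjacent transitions for the same slot when every counter between
// them has already been seen (for any slot), so no unseen write can land
// in between.
function trySquash(transitions: Transition[]): Transition[] {
  const seen = new Set(transitions.map(t => t.counter));
  const bySlot = new Map<bigint, Transition[]>();
  for (const t of transitions) {
    const list = bySlot.get(t.slot) ?? [];
    list.push(t);
    bySlot.set(t.slot, list);
  }

  const result: Transition[] = [];
  for (const list of bySlot.values()) {
    list.sort((a, b) => a.counter - b.counter);
    let current = list[0];
    for (let i = 1; i < list.length; i++) {
      const next = list[i];
      // Check for holes: every counter strictly between the two
      // transitions must have been seen already.
      let noHoles = true;
      for (let c = current.counter + 1; c < next.counter; c++) {
        if (!seen.has(c)) noHoles = false;
      }
      if (noHoles && current.newValue === next.oldValue) {
        // Squash: keep the earliest old value and the latest new value.
        current = { ...current, counter: next.counter, newValue: next.newValue };
      } else {
        result.push(current);
        current = next;
      }
    }
    result.push(current);
  }
  return result;
}
```

Running this after the `baz` iteration of Phil’s example squashes #0 and #2 for `x` into a single 0 -> 7 transition, but keeps #4 separate because counter #3 hasn’t been seen yet.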
3. Per-slot transition counter
Similar to the approach above, only that instead of keeping a single `transition_counter`, it keeps a counter for each slot affected in the tx. This opens up more opportunities to squash in the kernel circuit. In the foo-bar-baz-bat example, the outputs would be:
- A::foo state_transitions = [ { #0 x: 0 -> 5 }, { #2 x: 7 -> 9} ]
- B::bar state_transitions = []
- A::baz state_transitions = [ { #0 y: 0 -> 9}, { #1 x: 5 -> 7 } ]
- A::bat state_transitions = [ { #0 z: 0 -> 4 } ]
Which means that now, by the time the kernel circuit iterations reach `baz`, we’d have assembled:
x: [{#0 0 -> 5}, {#1 5 -> 7}, {#2 7 -> 9}]
y: [{#0 0 -> 9}]
Here we can squash all of `x`’s transitions, and it’s also easier to check (since we don’t need to see `y`’s list to know which indices have been used).
However, this approach has a big downside: we now need to pass a lot more information to and from each VM circuit run. Rather than passing and returning a single `transition_counter`, we now need to pass an array of `{slot, counter}` pairs, as sketched below. The length of the array will be the maximum number of different slots that can be affected in a single tx (nevertheless, and unlike option 1, multiple writes to the same slot take up only one item in that list). It’ll be the responsibility of the VM circuit to increase the counter for a slot whenever it writes to it.
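For illustration, the extra public inputs under this option could look something like the following (constant and names are hypothetical):

```typescript
// Hypothetical bound on the number of distinct slots a single tx can touch.
const MAX_SLOTS_PER_TX = 16;

// One entry per distinct slot touched so far in the tx. Unlike option 1,
// multiple writes to the same slot share a single entry here.
interface SlotCounter {
  slot: bigint;    // storage slot (a zeroed entry means unused)
  counter: number; // number of transitions for this slot so far in the tx
}

// Passed into and returned from every VM circuit run; fixed length
// MAX_SLOTS_PER_TX. On each write, the VM circuit bumps the counter for
// the written slot, inserting a new entry if the slot wasn't there yet.
type PerSlotCounters = SlotCounter[];
```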
4. Per-slot state
Building on option 3, if we’re now passing a lot of information to and from each VM circuit run, we can just pass the current value for all slots modified in the tx so far, and return the updated state to the caller. This moves the validation of state transitions from the kernel circuit to the public VM circuit.
In the foo-bar-baz-bat example, when foo calls bar, it’d pass `[x=5]` along with any arguments for the function call. And when bar returns to foo, it’d return `[x=7, y=9, z=4]` along with any function return values. The kernel circuit now just checks that the state snapshots are passed correctly between calls, same as it checks that arguments are passed correctly and return values are returned correctly.
The key insight here is that the public VM circuit, unlike the private app circuit, is controlled by the network, so we can trust that it behaves properly when interpreting brillig. Upon an `SSTORE` instruction, the VM circuit would update its internal state snapshot with the new value. This snapshot is received from the caller and returned to it updated.
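As a minimal sketch of that behavior (ordinary TypeScript standing in for circuit constraints; names hypothetical):

```typescript
// Snapshot of every slot modified so far in the tx: slot -> current value.
// Received from the caller at the start of the call and returned to it,
// updated, when the call ends.
type StateSnapshot = Map<bigint, bigint>;

// On SSTORE the VM just overwrites the slot's entry. No old-value
// bookkeeping is needed: the snapshot always holds the latest value,
// so intermediate transitions are squashed implicitly.
function onSstore(snapshot: StateSnapshot, slot: bigint, value: bigint): void {
  snapshot.set(slot, value);
}

// On a read, the slot comes from the snapshot if present, falling back to
// the value proven from the public data tree otherwise (an assumption on
// how reads would be resolved).
function onSload(snapshot: StateSnapshot, slot: bigint, treeValue: bigint): bigint {
  return snapshot.get(slot) ?? treeValue;
}
```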
The big advantage here is that we don’t have to assemble any sorted list in the kernel circuit or validate any state transition stitching. The kernel circuit just proves the initial values for each slot affected in the tx, and grabs the resulting state snapshot from the topmost call stack item to pass to the base rollup circuit, which executes these updates. Each iteration of the kernel circuit sees the following, but doesn’t need to validate anything:
- A::foo init_state=[ ] end_state=[ x=9, y=9, z=4 ]
- B::bar init_state=[ x=5 ] end_state=[ x=7, y=9, z=4 ]
- A::baz init_state=[ x=5 ] end_state=[ x=7, y=9, z=4 ]
- A::bat init_state=[ x=7, y=9 ] end_state=[ x=7, y=9, z=4 ]
Note that the resulting state is in the topmost item in the call stack, which condenses all the state updates that need to be applied to the tree.
As an optimization, note that `z=4` is not really needed outside of `bat`, so we could avoid passing it back to `baz` when the call ends. The problem is that we won’t know whether it’ll be needed later until we finish the tx execution, but if we can run this optimization after the execution and before the proving, we can reduce the number of public inputs to the VM circuit.
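A possible shape for that post-execution pass, heavily simplified (names hypothetical, and glossing over how the pruned values would surface to the base rollup):

```typescript
interface CallFrame {
  endState: Map<bigint, bigint>; // snapshot this call returns to its caller
}

// After execution (but before proving) we know the full trace, so for each
// frame we can precompute which slots any later execution step still needs.
// Dropping the rest shrinks the VM circuit's public inputs: e.g. z would be
// pruned from bat's return, since nothing after bat reads z.
function pruneEndState(frame: CallFrame, neededLater: Set<bigint>): void {
  for (const slot of [...frame.endState.keys()]) {
    if (!neededLater.has(slot)) {
      frame.endState.delete(slot);
    }
  }
}
```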
Given the pros/cons for each option, it’d seem like option 2 is strictly better than option 1 (since it opens the door to optimizations by squashing elements in the list more easily), and option 4 is strictly better than option 3 (since it removes effort in the kernel circuit while using the same number of additional fields in the public inputs).