Models That Prove Their Own Correctness
A review of https://eccc.weizmann.ac.il/report/2024/098/
We live in an age where (1) entire industries focus on generating mathematical proofs through formal verification, theorem proving and interactive proving [1], and (2) many people believe AI will be eating the world, perhaps metaphorically, perhaps literally.
It is only natural that some of the brightest minds alive are working to prove properties about these AI systems, to make them safer, more robust, and perhaps even more effective. In this note, I'll discuss two different ways in which interactive proof systems can be used for this purpose.
Technique 1: Creating proofs of knowledge about an AI model's training or inference run
Much recent effort has focused on creating interactive proofs of a model's training or inference run. That could allow a model provider to succinctly prove to a user that the model has certain desirable properties. For example, you could:
- Allow a model provider to prove that the same model was used for different inferences, so the user is really getting their money's worth of AI.
- Allow a model provider to prove that the same model was used during inference as was audited by a trusted third party.
- Allow a model provider to prove that the model is not secretly trying to take over the world.
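The "same model" properties above usually start from a public commitment to the model weights. The sketch below shows only that commitment step, assuming the weights can be serialized to bytes; the function names and the byte strings are hypothetical. The hard part, which this sketch does not attempt, is proving that a given inference was actually computed with weights matching the commitment without revealing them.

```python
import hashlib


def commit_to_weights(weights: bytes) -> str:
    """Return a SHA-256 digest that acts as a public fingerprint of the weights."""
    return hashlib.sha256(weights).hexdigest()


def matches_audited_model(served_weights: bytes, audited_commitment: str) -> bool:
    """Check whether some weights hash to the commitment published by the auditor."""
    return commit_to_weights(served_weights) == audited_commitment


# Hypothetical serialized weights, standing in for a real checkpoint file.
audited = commit_to_weights(b"weights-v1")
same = matches_audited_model(b"weights-v1", audited)
different = matches_audited_model(b"weights-v2", audited)
```

Here `same` is true and `different` is false. A proof system for technique 1 would have to establish the same equality for weights the user never gets to see.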
However, this approach has challenges. Generating an interactive proof will likely carry at least 1000x overhead over the original computation for the foreseeable future, even in the most optimistic engineering scenarios.
Technique 2: Letting an AI model's inference prove its own correctness
In contrast, a different setup was recently introduced by Amit, Goldwasser, Paradise and Rothblum. Their contribution allows a model to prove its own correctness. Amazingly, the model can do this using interactive proofs (which can be made non-interactive), and we know how to construct interactive proofs for a very large class of statements.
There is a subtle limitation to be aware of. If your goal is just to create an interactive proof of some computation (technique 1), you can do so for arbitrary computations, even if the computation is not legible or usable in any way. The purpose of self-proving models, however, is to actually prove correctness of the result. The verifier must thus have a precise notion of what correctness means. We know how to compute and prove the greatest common divisor of two numbers, whereas we don't and can't know exactly what the weather will be 10 years from now.
So, how does a model go about proving correctness? The basic technique is conceptually quite simple:
- During training, a model does not just train on correct examples; it additionally trains on correct proving transcripts. For example, if a model is trained to output greatest common divisors, it should additionally train on transcripts that can convince a verifier that a number is indeed the greatest common divisor. To make this more effective and efficient, the transcript can also be annotated. Succinctness is not our friend at this point.
- During inference, the model outputs not just the answer to a query, but also a transcript that should convince a verifier of the answer's correctness. Taking the above example again, ...
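To make the greatest common divisor example concrete, here is a minimal sketch of an answer that arrives with a checkable transcript. It assumes positive integer inputs and uses Bézout coefficients as the certificate, which is one natural choice for this problem. The `Transcript`, `prover` and `verifier` names are hypothetical, and `prover` is an ordinary extended Euclidean algorithm standing in for a trained model, not the training setup from the paper.

```python
from dataclasses import dataclass


@dataclass
class Transcript:
    a: int
    b: int
    claimed_gcd: int
    x: int  # Bezout coefficient for a
    y: int  # Bezout coefficient for b


def prover(a: int, b: int) -> Transcript:
    """Stand-in for the model: extended Euclid yields the answer and a certificate."""
    old_r, r = a, b
    old_x, x = 1, 0
    old_y, y = 0, 1
    while r != 0:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_x, x = x, old_x - q * x
        old_y, y = y, old_y - q * y
    return Transcript(a, b, old_r, old_x, old_y)


def verifier(t: Transcript) -> bool:
    """Accept only if the claimed value divides both inputs and equals a*x + b*y."""
    g = t.claimed_gcd
    if g <= 0:
        return False
    divides_both = t.a % g == 0 and t.b % g == 0
    bezout_holds = t.a * t.x + t.b * t.y == g
    return divides_both and bezout_holds


honest = prover(48, 18)
accepted = verifier(honest)
rejected = verifier(Transcript(48, 18, 3, honest.x, honest.y))
```

The verifier never recomputes the answer. Any common divisor of `a` and `b` also divides `a*x + b*y`, so a value that divides both inputs and equals that sum must be the greatest one. A wrong claim such as 3 for the pair (48, 18) fails the second check.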
There are many scenarios where ZK proofs can further compress the interaction.
Now, how exciting is this? While the generality of this recent insight is intriguing, there are reasons to be skeptical, because the approach seems incredibly inefficient.
Most ML is used for probabilistic functions, and it can be extremely efficient and powerful at that. Most formal verification in the context of ML is geared towards ensuring that the model is unbiased, or that it has limited bias in its probabilities with regard to some hypothetical ground truth. In turn, ML can be used to "speed up" the analysis by templating most of the required formal verification.
The big question is: for which classes of functions will the transcript that verifies the correctness of the output (and training on it) not be orders of magnitude larger than computing and formally verifying the function directly? There may be function-specific transcripts which are very efficient, but that eats away at the generality of this technique. A more general way would be to train a model on transcripts of formal verification languages like Lean or ACL2, but this would be extremely expensive. Training on small functions (and their correctness) will quickly get way too expensive.
To conclude, this is very exciting work, but it will need fundamental improvements with regard to creating succinct transcripts in an automated way.
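As a rough illustration of the size question, the sketch below counts how many integers an annotated Euclidean transcript would contain for a greatest common divisor query. The size measure (four integers per step) and the function names are assumptions made for illustration. A function-specific certificate for the same problem, such as a pair of Bézout coefficients, stays at two integers no matter how many steps the computation takes.

```python
def euclid_steps(a: int, b: int):
    """Run Euclid's algorithm and record every (a, b, quotient, remainder) step."""
    steps = []
    while b != 0:
        q, r = divmod(a, b)
        steps.append((a, b, q, r))
        a, b = b, r
    return a, steps


def annotated_transcript_length(a: int, b: int) -> int:
    """Hypothetical size measure: four integers per recorded step."""
    _, steps = euclid_steps(a, b)
    return 4 * len(steps)


# Consecutive Fibonacci numbers are the slowest inputs for Euclid's algorithm.
g, steps = euclid_steps(89, 55)
annotated_size = annotated_transcript_length(89, 55)
succinct_size = 2  # a pair of Bezout coefficients
```

For these inputs the annotated transcript records 9 steps and therefore 36 integers, against 2 for the certificate. The gap is modest for this problem; the open question is what happens for functions where no such short certificate is known.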
Footnotes
[1] By no means is this an exhaustive or mutually exclusive list of proving methods.