Back in the early days of LLMs, it was a struggle to get structured output out of models, which made it hard to consume the results programmatically -- there was no well-defined interface. That problem was solved a while back, and now most inference systems support JSON schema outputs. As part of building my educational inference runtime, I implemented support for this as well. My code leverages the outlines-core library, which does a lot of the heavy lifting. I'll go over what the library does and how it integrates into the inference runtime.
Let's start with an example JSON schema:
{
  "type": "object",
  "properties": {
    "foo": {"type": "string"},
    "bar": {"type": "integer"},
    "baz": {"enum": ["a", "b", "c"]}
  },
  "required": ["foo"]
}
There are two key observations that allow us to make JSON schemas usable in the LLM sampling process.
- Most JSON schemas can be turned into regular expressions (see supported features here).
- A regular expression can be turned into a deterministic finite automaton (DFA), i.e. a state machine that transitions based on bytes.
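To make the second observation concrete, here's a toy byte-level DFA for the tiny regex `a(b|c)` -- a stand-in for the much larger automaton that gets built from the schema regex. The representation (a dict of `(state, byte)` transitions) is purely illustrative, not how any real regex engine stores its DFA:

```python
# Toy byte-level DFA for the regex a(b|c).
# States are ints; transitions map (state, byte) -> next state.
TRANSITIONS = {
    (0, ord("a")): 1,
    (1, ord("b")): 2,
    (1, ord("c")): 2,
}
ACCEPT = {2}  # final (matching) states
DEAD = -1    # any undefined transition falls into the dead state

def step(state: int, byte: int) -> int:
    """Advance the DFA by one byte."""
    return TRANSITIONS.get((state, byte), DEAD)

def matches(s: bytes) -> bool:
    state = 0
    for b in s:
        state = step(state, b)
        if state == DEAD:
            return False
    return state in ACCEPT

print(matches(b"ab"))  # True
print(matches(b"ax"))  # False
```

The key property is that matching proceeds one byte at a time with no backtracking, which is what makes the per-step "which bytes are valid next?" question cheap to answer.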
Using outlines-core, we can see that the above JSON schema turns into the following regex:
Admittedly, not the prettiest thing in the world, but a regex nonetheless. A lot of the ugliness comes from handling optional whitespace, so if you strip that away it's fairly straightforward and what you would expect, i.e.
\{"foo":"<string>"(,"bar":<int>)?(,"baz":("a"|"b"|"c"))?\}
It's worth noting that, by design, this regex fixes the order of the fields in the JSON object, so it's a strict subset of valid outputs, which is okay.
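You can check both points against the simplified regex with Python's `re` module. The `<string>` and `<int>` placeholders are filled with simple stand-ins here (`[^"]*` with no escape handling, and `-?\d+`), which is cruder than what outlines-core actually emits:

```python
import re

# Simplified schema regex, with illustrative stand-ins for the placeholders.
PATTERN = re.compile(
    r'\{"foo":"[^"]*"(,"bar":-?\d+)?(,"baz":("a"|"b"|"c"))?\}'
)

print(bool(PATTERN.fullmatch('{"foo":"hello","bar":42}')))   # True
print(bool(PATTERN.fullmatch('{"foo":"hello","baz":"b"}')))  # True
print(bool(PATTERN.fullmatch('{"bar":42}')))                 # False: "foo" is required
print(bool(PATTERN.fullmatch('{"baz":"a","foo":"x"}')))      # False: field order is fixed
```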
Next, outlines-core uses the regex-automata crate to convert the regex to a DFA. This DFA encodes the valid per-byte transitions from each state, as well as the final states in which the regex matches. But recall that LLMs do not output single bytes at a time; they output tokens. So the outlines-core library goes through a process to convert this DFA into an alternate DFA over the same states that has per-token transitions instead of per-byte transitions (saved in the Index object).
Converting a regex byte-level DFA to a token-level DFA (simplified).
To do this, it performs a breadth-first search over the DFA states, using the vocabulary (the set of all tokens) as the possible transitions. For each token, it runs the token's bytes through the DFA and checks whether the resulting state is live (not a dead state). If it is, it records the token as a valid transition and continues the search from the resulting state (if it hasn't been visited before). Building this entire DFA upfront can be expensive, but once it's available, looking up the transitions at a given state is constant time. That lets us easily answer the question: at any point in the LLM's output, what is the set of valid tokens that can be produced next?
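The breadth-first search can be sketched in a few lines of Python. This is the spirit of outlines-core's Index construction, not its actual implementation -- the byte-level DFA here (for the regex `a(b|c)`) and the five-token vocabulary are made up for the example:

```python
from collections import deque

# Byte-level DFA for the regex a(b|c): (state, byte) -> next state.
BYTE_DFA = {
    (0, ord("a")): 1,
    (1, ord("b")): 2,
    (1, ord("c")): 2,
}

def run_token(state: int, token: bytes):
    """Feed a token's bytes through the DFA; None means a dead end."""
    for b in token:
        state = BYTE_DFA.get((state, b))
        if state is None:
            return None
    return state

def build_token_index(vocab: dict[int, bytes], start: int = 0):
    """BFS over DFA states, recording (state, token_id) -> next_state."""
    index: dict[tuple[int, int], int] = {}
    seen = {start}
    queue = deque([start])
    while queue:
        state = queue.popleft()
        for token_id, token in vocab.items():
            nxt = run_token(state, token)
            if nxt is None:
                continue  # token hits a dead state: not a valid transition
            index[(state, token_id)] = nxt
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return index

vocab = {0: b"a", 1: b"b", 2: b"ab", 3: b"ac", 4: b"z"}
index = build_token_index(vocab)
# From the start state, tokens "a", "ab", and "ac" are valid; "b" and "z" are not.
print(sorted(tid for (s, tid) in index if s == 0))  # [0, 2, 3]
```

Note that multi-byte tokens like `ab` can traverse several byte transitions in one step, which is exactly why the token-level index is worth precomputing.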
LLM logits are masked based on valid token transitions provided by the DFA.
Now that we have the valid token transitions, we're ready to integrate into the inference runtime. The integration happens at the logits layer, right before we apply softmax and sample the output token. We keep track of the current state within the DFA and advance it for each generated token. Then, based on that state, it's a simple masking process: for any token that isn't in the valid set of transitions, we set its logit to negative infinity so it can never be sampled. That's it!
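The masking step itself is tiny. A minimal sketch, independent of any particular runtime (plain Python lists stand in for the logits tensor):

```python
import math

def mask_logits(logits: list[float], allowed: set[int]) -> list[float]:
    """Set every disallowed token's logit to -inf so softmax assigns it 0."""
    return [x if i in allowed else float("-inf") for i, x in enumerate(logits)]

def softmax(xs: list[float]) -> list[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Suppose the DFA says only tokens 1 and 3 are valid from the current state.
logits = [2.0, 0.5, 1.0, 0.5]
probs = softmax(mask_logits(logits, allowed={1, 3}))
print(probs[0], probs[2])  # both 0.0: masked tokens can never be sampled
print(round(probs[1] + probs[3], 6))  # 1.0: all mass on valid tokens
```

In a real runtime this is a vectorized tensor operation over the whole batch, but the logic is the same: `exp(-inf) == 0`, so the sampler renormalizes over exactly the tokens the DFA allows.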
What I've described is just one way to handle structured outputs for an LLM. There are various approaches that support different types of structures with different performance characteristics. One particularly interesting one is llguidance, which supports context-free grammars (in addition to JSON schema and regex) and is also super fast thanks to using token tries and sparse masks. Perhaps an interesting topic for a future post.