Structured Generative AI. How to constrain your model to output… | by Oren Matar | Apr, 2024


How to constrain your model to output defined formats

In this post I’ll explain and demonstrate the concept of “structured generative AI”: generative AI constrained to defined formats. By the end of the post, you’ll understand where and when it can be used, and how to implement it whether you’re crafting a transformer model from scratch or using Hugging Face’s models. Additionally, we’ll cover an important tip for tokenization that’s especially relevant for structured languages.

One of the many uses of generative AI is as a translation tool. This often involves translating between two human languages, but it can also involve computer languages or formats. For example, your application may need to translate natural (human) language to SQL:

Natural language: “Get customer names and emails of customers from the US”

SQL: "SELECT name, email FROM customers WHERE country = 'USA'"

Or to convert text data into a JSON format:

Natural language: “I am John Doe, phone number is 555-123-4567,
my friends are Anna and Sara”

JSON: {"name": "John Doe",
"phone_number": "555-123-4567",
"friends": {
"name": [["Anna", "Sara"]]}
}

Naturally, many more applications are possible for other structured languages. The training process for such tasks involves feeding examples of natural language alongside structured formats to an encoder-decoder model. Alternatively, leveraging a pre-trained large language model (LLM) can suffice.

While achieving 100% accuracy is impossible, there is one class of errors that we can eliminate: syntax errors. These are violations of the format of the language, such as replacing commas with dots, using table names that aren’t present in the SQL schema, or omitting closing brackets. Errors like these make SQL non-executable or JSON unparsable.

The fact that we’re translating into a structured language means that the list of legitimate tokens at every generation step is limited and pre-determined. If we can insert this knowledge into the generative AI process, we can avoid a wide range of incorrect results. That is the idea behind structured generative AI: constrain it to a list of legitimate tokens.

A quick reminder on how tokens are generated

Whether using an encoder-decoder or a GPT architecture, token generation operates sequentially. Each token’s selection depends on both the input and the previously generated tokens. Generation continues until an <end> token is produced, which signals that the sequence is complete. At each step, a classifier assigns a logit value to every token in the vocabulary. After a softmax, these logits give the probability of each token being the next selection, and the next token is sampled based on them.
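To make this loop concrete, here is a minimal sketch in plain Python. It assumes a hypothetical `toy_logits` function that stands in for the model’s classifier, plus a four-token vocabulary made up for illustration. This is not how Hugging Face implements `generate`. It only shows the sequential structure: score the vocabulary, pick one token, append it, and repeat until the end token.

```python
import math
import random

# Hypothetical four-token vocabulary, made up for illustration
VOCAB = ['<s>', '</s>', 'SELECT', 'email']
BOS_ID, EOS_ID = 0, 1

def toy_logits(ids):
    """Stand-in for the model's classifier: one logit per vocabulary entry.

    A real model conditions on the input and on all previous tokens; this toy
    version only looks up the last generated token in a fixed table.
    """
    table = {
        0: [-5.0, 0.0, 2.0, 1.0],  # after <s>: SELECT scores highest
        2: [-5.0, 0.0, 0.5, 3.0],  # after SELECT: email scores highest
        3: [-5.0, 2.0, 0.0, 1.0],  # after email: </s> scores highest
    }
    return table[ids[-1]]

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def generate(logits_fn, bos_id=BOS_ID, eos_id=EOS_ID, max_new_tokens=20, greedy=True, rng=None):
    """Sequential decoding: score the vocabulary, pick one token, append it, repeat."""
    ids = [bos_id]
    for _ in range(max_new_tokens):
        probs = softmax(logits_fn(ids))
        if greedy:
            next_id = max(range(len(probs)), key=probs.__getitem__)
        else:
            next_id = (rng or random).choices(range(len(probs)), weights=probs)[0]
        ids.append(next_id)
        if next_id == eos_id:  # the end token signals completion
            break
    return ids

print([VOCAB[i] for i in generate(toy_logits)])  # ['<s>', 'SELECT', 'email', '</s>']
```

Greedy decoding always takes the highest-probability token. Passing `greedy=False` samples from the softmax distribution instead, which is closer to the “sampled based on these logits” description above.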

The decoder classifier assigns a logit to every token in the vocabulary (Image by author)

Constraining token generation

To constrain token generation, we incorporate knowledge of the output language’s structure. Illegitimate tokens have their logits set to -inf, which excludes them from selection: since exp(-inf) = 0, they receive zero probability after the softmax. For instance, if only a comma or “FROM” is valid after “SELECT name”, all other token logits are set to -inf.
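Here is a minimal sketch of the masking step in plain Python. The vocabulary and logit values are made up for illustration. Tokens outside the allowed list get -inf, so they end up with exactly zero probability after the softmax and can never be chosen, whether by greedy decoding or by sampling.

```python
import math

def mask_logits(logits, allowed_ids):
    """Copy of `logits` with every token outside `allowed_ids` set to -inf."""
    allowed = set(allowed_ids)
    return [x if i in allowed else -math.inf for i, x in enumerate(logits)]

def softmax(logits):
    m = max(logits)  # finite as long as at least one token is allowed
    exps = [math.exp(x - m) for x in logits]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical vocabulary and made-up logits for the step after "SELECT name"
vocab = ['SELECT', 'name', ',', 'FROM', 'WHERE', '.']
logits = [0.3, 1.2, 0.8, 0.5, 2.0, 1.5]
allowed = [vocab.index(','), vocab.index('FROM')]

probs = softmax(mask_logits(logits, allowed))
best = max(range(len(probs)), key=probs.__getitem__)
print(vocab[best])  # ,
```

Without the mask, the highest logit (2.0, for “WHERE”) would win. With the mask, only “,” and “FROM” keep non-zero probability, and “,” is picked.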

If you’re using Hugging Face, this can be implemented with a “logits processor”. To use one, you need to implement a class with a __call__ method. This method is called after the logits are calculated but before sampling. It receives the generated input IDs and the logits of all tokens, and returns modified logits for all tokens.

The logits returned from the logits processor: all illegitimate tokens get a value of -inf (Image by author)

I’ll demonstrate the code with a simplified example. First, we initialize the model. We’ll use Bart in this case, but this should work with any model.

from transformers import BartForConditionalGeneration, BartTokenizerFast, PreTrainedTokenizer
from transformers.generation.logits_process import LogitsProcessorList, LogitsProcessor
import torch

name = 'facebook/bart-large'
# add_prefix_space=True is required by the fast tokenizer when is_split_into_words=True
tokenizer = BartTokenizerFast.from_pretrained(name, add_prefix_space=True)
pretrained_model = BartForConditionalGeneration.from_pretrained(name)

If we want to generate a translation from natural language to SQL, we can run:

to_translate = 'customers emails from the us'
words = to_translate.split()
tokenized_text = tokenizer([words], is_split_into_words=True)

out = pretrained_model.generate(
    torch.tensor(tokenized_text["input_ids"]),
    max_new_tokens=20,
)
print(tokenizer.convert_tokens_to_string(
    tokenizer.convert_ids_to_tokens(
        out[0], skip_special_tokens=True)))

Returning

'Extra emails from the us'

Since we didn’t fine-tune the model for text-to-SQL tasks, the output doesn’t resemble SQL. We won’t train the model in this tutorial, but we will guide it to generate an SQL query. We’ll do this with a function that maps each generated token to a list of permissible next tokens. For simplicity, we’ll consider only the immediately preceding token, but more complicated mechanisms are easy to implement. We’ll use a dictionary that defines, for each token, which tokens are allowed to follow it. For example, the query must begin with “SELECT” or “DELETE”. After “SELECT”, only “name”, “email”, or “id” are allowed, since these are the columns in our schema.

rules = {'<s>': ['SELECT', 'DELETE'],  # beginning of the generation
         'SELECT': ['name', 'email', 'id'],  # names of columns in our schema
         'DELETE': ['name', 'email', 'id'],
         'name': [',', 'FROM'],
         'email': [',', 'FROM'],
         'id': [',', 'FROM'],
         ',': ['name', 'email', 'id'],
         'FROM': ['customers', 'vendors'],  # names of tables in our schema
         'customers': ['</s>'],
         'vendors': ['</s>'],  # end of the generation
         }

Now we need to convert these tokens to the IDs used by the model. This will happen inside a class inheriting from LogitsProcessor.

def convert_token_to_id(token):
    # ID of the token's first sub-token; assumes each rule token is a single BPE token
    return tokenizer(token, add_special_tokens=False)['input_ids'][0]

class SQLLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer
        # translate the string rules into token-ID rules
        self.rules = {convert_token_to_id(k): [convert_token_to_id(v0) for v0 in v] for k, v in rules.items()}

Finally, we’ll implement the __call__ function, which is called after the logits are calculated. The function creates a new tensor of -infs, checks which IDs are legitimate according to the rules (the dictionary), and places their scores in the new tensor. The result is a tensor that has valid values only for the valid tokens.

class SQLLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer
        self.rules = {convert_token_to_id(k): [convert_token_to_id(v0) for v0 in v] for k, v in rules.items()}

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # create a new tensor of -inf with the same shape, dtype and device as scores
        new_scores = torch.full_like(scores, float('-inf'))
        # handle each row separately (beam search or batching adds one row per sequence)
        for row in range(input_ids.shape[0]):
            last_id = int(input_ids[row, -1])
            if last_id not in self.rules:
                # e.g. the start token <s> hasn't been generated yet: leave the row unconstrained
                new_scores[row] = scores[row]
                continue
            # ids of legitimate tokens after the last generated token
            legit_ids = self.rules[last_id]
            # place their values in the new tensor
            new_scores[row, legit_ids] = scores[row, legit_ids]
        return new_scores

And that’s it! We can now run a generation with the logits processor:

to_translate = 'customers emails from the us'
words = to_translate.split()
tokenized_text = tokenizer([words], is_split_into_words=True, return_offsets_mapping=True)

logits_processor = LogitsProcessorList([SQLLogitsProcessor(tokenizer)])

out = pretrained_model.generate(
    torch.tensor(tokenized_text["input_ids"]),
    max_new_tokens=20,
    logits_processor=logits_processor
)
print(tokenizer.convert_tokens_to_string(
    tokenizer.convert_ids_to_tokens(
        out[0], skip_special_tokens=True)))

Returning

 SELECT email , email , id , email FROM customers

The result is a little strange, but remember: we didn’t even train the model! We only enforced token generation based on specific rules. Notably, constraining generation doesn’t interfere with training: the constraints apply only during generation, after training. Thus, when correctly implemented, these constraints can only improve generation accuracy.
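The output above looks odd, but it does follow the rules. A small plain-Python validator makes that easy to check. This is a sketch that copies the rules dictionary, and the helper names `allowed_next` and `is_valid` are hypothetical, not part of the original code. Like the rules, it only looks one token back, which is why repeated columns pass.

```python
# Copy of the rules dictionary used by the logits processor
RULES = {'<s>': ['SELECT', 'DELETE'],
         'SELECT': ['name', 'email', 'id'],
         'DELETE': ['name', 'email', 'id'],
         'name': [',', 'FROM'],
         'email': [',', 'FROM'],
         'id': [',', 'FROM'],
         ',': ['name', 'email', 'id'],
         'FROM': ['customers', 'vendors'],
         'customers': ['</s>'],
         'vendors': ['</s>'],
         }

def allowed_next(prefix):
    """Tokens permitted after `prefix`; like the rules, only the last token matters."""
    return RULES.get(prefix[-1], [])

def is_valid(tokens):
    """True if `tokens` starts with <s>, ends with </s> and every step follows RULES."""
    if len(tokens) < 2 or tokens[0] != '<s>' or tokens[-1] != '</s>':
        return False
    return all(nxt in allowed_next(tokens[:i + 1]) for i, nxt in enumerate(tokens[1:]))

query = '<s> SELECT email , email , id , email FROM customers </s>'.split()
print(is_valid(query))  # True
```

Rejecting repeated columns would require looking further back than one token. This is the kind of extension the next paragraph has in mind.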

Our simplistic implementation falls short of covering all of SQL’s syntax. A real implementation must support more syntax, potentially by considering several previous tokens rather than just the last one, and must enable batch generation. Once these enhancements are in place, our trained model can reliably generate executable SQL queries, constrained to valid table and column names from the schema. A similar approach can enforce constraints when generating JSON, ensuring that keys are present and brackets are closed.
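As one example of looking at more than the last token, here is a sketch of a bracket-balancing filter for JSON-like output. It is an illustration under stated assumptions, not part of the original code:

- Tokens are plain strings, and each string literal is a single token.
- The prefix is assumed to be valid so far.
- The function names are hypothetical.

In a real logits processor, the surviving candidates would be mapped to token IDs, and everything else would be set to -inf.

```python
# Opening bracket -> matching closing bracket
PAIRS = {'{': '}', '[': ']'}
CLOSERS = set(PAIRS.values())

def open_brackets(prefix):
    """Stack of brackets still open after `prefix` (assumed balanced so far)."""
    stack = []
    for tok in prefix:
        if tok in PAIRS:
            stack.append(tok)
        elif tok in CLOSERS:
            stack.pop()
    return stack

def filter_brackets(prefix, candidates, eos='</s>'):
    """Keep only candidates that respect bracket balance given the whole prefix.

    This enforces bracket closure only; other grammar rules (e.g. where commas
    may appear) would be applied by separate filters.
    """
    stack = open_brackets(prefix)
    expected = PAIRS[stack[-1]] if stack else None
    kept = []
    for tok in candidates:
        if tok in CLOSERS and tok != expected:
            continue  # wrong closer, or nothing left to close
        if tok == eos and stack:
            continue  # cannot end while brackets remain open
        kept.append(tok)
    return kept

prefix = ['{', '"friends"', ':', '[', '[', '"Anna"', ',', '"Sara"']
print(filter_brackets(prefix, [']', '}', ',', '</s>']))  # [']', ',']
```

The whole prefix has to be scanned because the correct closing bracket depends on how many brackets were opened earlier, not just on the previous token.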

Be careful with tokenization

Tokenization is often overlooked, but correct tokenization is crucial when using generative AI for structured output, because under the hood it affects how your model trains. For example, you may fine-tune a model to translate text into JSON. As part of the fine-tuning process, you provide the model with examples of text-JSON pairs, which it tokenizes. What will this tokenization look like?

(Image by author)

While you read “[[” as two square brackets, the tokenizer converts them into a single ID. The token classifier then treats this ID as a class completely distinct from the single bracket. This makes the overall logic the model must learn more complicated (for example, remembering how many brackets to close). Similarly, adding a space before a word may change its tokenization and its class ID. For instance:

(Image by author)

Again, this complicates the logic the model will have to learn since the weights connected to each of these IDs will have to be learned separately, for slightly different cases.

For simpler learning, ensure that each concept and punctuation mark is consistently converted to the same token by adding spaces before words and characters.

Spaced-out words lead to more consistent tokenization (Image by author)
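To see why spacing helps, consider a toy greedy longest-match tokenizer over a made-up vocabulary. This is a simplification: real BPE tokenizers such as Bart’s apply learned merges rather than longest match, and the vocabulary below is hypothetical. Still, it reproduces the effect described above:

- “[[” becomes a single ID.
- “name” and “ name” get different IDs.
- Once every bracket and word is preceded by a space, each one maps to a single, consistent ID.

```python
# Hypothetical toy vocabulary (string -> token ID); not Bart's real vocabulary
VOCAB = {'[[': 0, '[': 1, ' [': 2, ']]': 3, ']': 4, ' ]': 5,
         'name': 6, ' name': 7, '"': 8, ' "': 9, ',': 10, ' ,': 11, ' ': 12}
LONGEST = max(map(len, VOCAB))

def toy_tokenize(text):
    """Greedy longest-match tokenization (a simplification of real BPE)."""
    ids, i = [], 0
    while i < len(text):
        for size in range(min(LONGEST, len(text) - i), 0, -1):
            piece = text[i:i + size]
            if piece in VOCAB:
                ids.append(VOCAB[piece])
                i += size
                break
        else:
            raise ValueError(f'no vocabulary entry matches {text[i]!r} at position {i}')
    return ids

print(toy_tokenize('[['))         # [0]: two brackets, one ID
print(toy_tokenize('[ ['))        # [1, 2]: the same bracket, two different IDs
print(toy_tokenize(' [ [ name'))  # [2, 2, 7]: consistent once every token is space-prefixed
```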

Inputting spaced examples during fine-tuning simplifies the patterns the model has to learn, enhancing model accuracy. During prediction, the model will output the JSON with spaces, which you can then remove before parsing.
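Here is a sketch of the spacing step for JSON targets, in plain Python. The helper names are hypothetical, and the code assumes the input is valid JSON. It puts a space between every structural character (`{ } [ ] : ,`), string literal, and scalar, but leaves the contents of string literals untouched, so values such as “John Doe” keep their spaces. JSON treats whitespace between tokens as insignificant. For JSON specifically, this means spaces added outside string literals don’t need to be removed: `json.loads` returns the same object for the spaced and the unspaced text.

```python
import json

STRUCTURAL = set('{}[]:,')

def json_lexemes(json_text):
    """Split valid JSON text into structural characters, string literals and scalars."""
    lexemes, buf, i = [], '', 0
    while i < len(json_text):
        ch = json_text[i]
        if ch == '"':
            if buf:
                lexemes.append(buf)
                buf = ''
            # copy the whole string literal, skipping over backslash escapes
            j = i + 1
            while json_text[j] != '"':
                j += 2 if json_text[j] == '\\' else 1
            lexemes.append(json_text[i:j + 1])
            i = j + 1
        elif ch in STRUCTURAL or ch.isspace():
            if buf:  # end of a number, true, false or null
                lexemes.append(buf)
                buf = ''
            if not ch.isspace():
                lexemes.append(ch)
            i += 1
        else:
            buf += ch
            i += 1
    if buf:
        lexemes.append(buf)
    return lexemes

def space_out(json_text):
    """Separate every JSON lexeme with one space; string contents stay unchanged."""
    return ' '.join(json_lexemes(json_text))

target = '{"name": "John Doe", "friends": {"name": [["Anna", "Sara"]]}}'
spaced = space_out(target)
print(spaced)  # { "name" : "John Doe" , "friends" : { "name" : [ [ "Anna" , "Sara" ] ] } }
```

The spaced string is what a fine-tuning target would look like, so “[[” reaches the tokenizer as two separate, space-prefixed brackets.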

Summary

Generative AI offers a valuable approach for translating into a formatted language. By leveraging knowledge of the output structure, we can constrain the generative process, eliminating a class of errors and ensuring that queries are executable and data structures are parsable.

Additionally, these formats may use punctuation and keywords to signify certain meanings. Making sure that the tokenization of these keywords is consistent can dramatically reduce the complexity of the patterns that the model has to learn, thus reducing the required size of the model and its training time, while increasing its accuracy.

Structured generative AI can effectively translate natural language into any structured format. These translations enable information extraction from text or query generation, which is a powerful tool for numerous applications.
