Content Ranges

Content ranges are a key concept used within the pipeline for distillation of instruct samples.

Why are they needed? Let's take a look at the problem:

Say your teacher model uses this prompt format:

### Human:
{Some text}
### Assistant:
{Some AI slop}

But you want your student to be trained on this prompt format:

User: {Some rambling}
AI: {Generic response filled with gpt-isms}

After tokenization, even though both come from the exact same sample, the two sequences will of course differ because of the prompt formatting. That means you cannot directly compare the student's predictions against the teacher's collected data: the sample's content sits at different indexes in each sequence, i.e. it is misaligned. You might, for example, end up comparing how the teacher predicted one of its prompt-format tokens against how the student predicts the start of the conversation.

So, what to do? Answer: Content ranges!

Using them, we can crop out all the misaligned tokens, so that when it comes time to calculate the loss, the content of the conversation forms directly comparable tensors.
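
To make the idea concrete, here is a minimal sketch of what that cropping buys you. This is not the pipeline's actual code: the shapes, the ranges and the KL-divergence loss are all just illustrative assumptions.

```python
import torch
import torch.nn.functional as F

# Logits collected for the same conversation under two different prompt formats
# (made-up sequence lengths and vocab size)
teacher_logits = torch.randn(14, 32000)  # teacher's prompt format
student_logits = torch.randn(12, 32000)  # student's prompt format

# Content ranges computed separately for each prompt format
teacher_ranges = [(4, 7), (10, 13)]
student_ranges = [(2, 5), (8, 11)]

# Crop out everything that isn't content; both results now have the same
# length, so they line up position by position
teacher_content = torch.cat([teacher_logits[s:e] for s, e in teacher_ranges])
student_content = torch.cat([student_logits[s:e] for s, e in student_ranges])

loss = F.kl_div(
    F.log_softmax(student_content, dim=-1),
    F.softmax(teacher_content, dim=-1),
    reduction="batchmean",
)
```

Without the ranges, the two tensors wouldn't even have matching shapes, let alone matching positions.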

The pipeline identifies the start and end indexes of each turn's content, following Python slicing conventions: [startID, endID) (the start index is inclusive, the end index is exclusive). It does that by tokenizing each turn delimiter separately (USER_START, SYS_END, etc...), and tokenizing each sample's text without any prompt formatting, so that prompt-format tokens and content tokens never get fused into one (see the sketch below).
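
For example, you can see that fusion effect with a quick check like this (a rough sketch assuming a Hugging Face tokenizer; "some-model" is just a placeholder for whatever tokenizer you actually use):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("some-model")

delimiter = "### Assistant:\n"
content = "Hi there"

# Encoding the concatenated string can merge characters across the
# delimiter/content boundary into a single token...
fused = tokenizer.encode(delimiter + content, add_special_tokens=False)

# ...while encoding the pieces separately keeps the boundary clean, so the
# pipeline always knows exactly which tokens belong to the content
separate = tokenizer.encode(delimiter, add_special_tokens=False) \
         + tokenizer.encode(content, add_special_tokens=False)

print(fused)
print(separate)  # not guaranteed to be identical to `fused`
```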

During tokenization, the pipeline keeps meticulous track of how many tokens are already in the tokenized sample, and creates the content ranges roughly like in this simplified example:

```python
start_idx = -1  # -1 because the model predicts one token ahead of the current one
conversation_content_ranges = []
pf = prompt_format
# turn delimiters were already tokenized

for turn in sample:
    turn_speaker = turn.get("from")
    assistant = turn_speaker == "gpt"

    turn_start_tokenized = pf['ASSISTANT_START'] if assistant else pf['USER_START']
    turn_end_tokenized = pf['ASSISTANT_END'] if assistant else pf['USER_END']

    turn_text = turn.get("value")

    # the content is tokenized on its own, without any prompt formatting
    turn_tokenized = tokenizer.encode(turn_text)
    num_content_tokens = len(turn_tokenized)
    full_turn_len = len(turn_start_tokenized) + num_content_tokens + len(turn_end_tokenized)

    # the range starts at the last delimiter token of the turn start
    # (its prediction is the first content token) and ends one past
    # the last content token
    content_start_index = len(turn_start_tokenized) + start_idx
    content_end_index = content_start_index + num_content_tokens + 1
    conversation_content_ranges.append((content_start_index, content_end_index))

    start_idx += full_turn_len

    # ...(prompt-formatted turn gets added to the list, etc...)
```

So, to really get the hang of it, let's work out the content ranges and the content tokens for this text with its prompt format:

"User", ":", "\n", "Hi", "there", "\n", "Assistant", ":", "\n", "Hi", "!", "\n" (12 tokens in total) Tokenized prompt format:

user_start = ["User", ":", "\n"]
user_end = ["\n"]
assistant_start = ["Assistant", ":", "\n"]
assistant_end = ["\n"]

Content ranges:
We see that the user's turn start has 3 tokens, and since the models predict text one token ahead of the current one, we need to subtract 1 from the start of the user's range:
content_ranges = [(2, ?)] (reminder that Python indexing is 0-based, so index 2 points to the last "\n" of the user's turn start, whose prediction is the first content token "Hi")

Now we know where the model starts predicting the content text, but where does it stop? Let's figure it out:
The content itself is 2 tokens long; adding that to the start points us at index 4, which is exactly where the last content token "there" sits. But Python's slicing makes the last index exclusive, so slicing the text with the range (2, 4) would cover the position of only 1 content token, that being "Hi", and would miss "there". So we need to add 1 to the end index to account for that:
content_ranges = [(2, 5)]

Do the same thing for the assistant's turn: its turn start also has 3 tokens, the last of which sits at index 8, and its content is again 2 tokens long, so we get:
content_ranges = [(2, 5), (8, 11)]
And just to see what they represent, let's use them to slice the text and get the content tokens:
content_tokens = [["\n", "Hi", "there"], ["\n", "Hi", "!"]] (reminder, models predict one token ahead, so the distribution at each of these positions predicts the next token: the content itself, and then the end of the turn)
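
If you want to verify the arithmetic, here is a small self-contained sketch that reproduces the example above, using the hand-written token lists as a stand-in for a real tokenizer:

```python
pf = {
    "USER_START": ["User", ":", "\n"],
    "USER_END": ["\n"],
    "ASSISTANT_START": ["Assistant", ":", "\n"],
    "ASSISTANT_END": ["\n"],
}
sample = [
    {"from": "human", "value": ["Hi", "there"]},  # content already "tokenized"
    {"from": "gpt", "value": ["Hi", "!"]},
]

tokens = []
content_ranges = []
start_idx = -1

for turn in sample:
    assistant = turn["from"] == "gpt"
    turn_start = pf["ASSISTANT_START"] if assistant else pf["USER_START"]
    turn_end = pf["ASSISTANT_END"] if assistant else pf["USER_END"]
    content = turn["value"]

    content_start = len(turn_start) + start_idx
    content_end = content_start + len(content) + 1
    content_ranges.append((content_start, content_end))

    tokens += turn_start + content + turn_end
    start_idx += len(turn_start) + len(content) + len(turn_end)

print(content_ranges)                            # [(2, 5), (8, 11)]
print([tokens[s:e] for s, e in content_ranges])  # [['\n', 'Hi', 'there'], ['\n', 'Hi', '!']]
```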

Hooray! We now have the content ranges that can be applied to the logits collected for the full sample to get the Content Logits, which are cross-comparable with Content Logits collected on this same sample using any other prompt format!