Sometimes a single word isn’t enough. What if you could blend tokens or use non-tokens (aka nokens [1])?
Let’s explore these ideas and, along the way, learn more about transformers: both the architecture and the Hugging Face library.
Installing & Setup
I’m going to use Microsoft’s Phi-2 as it is smaller than Mistral or Llama 2. The ideas should transfer to other transformer models.
After `pip install transformers torch`, we can load our language model to experiment with using the Hugging Face library.
```python
from transformers import AutoTokenizer, PhiForCausalLM
import torch

torch.set_default_device('mps')  # set to 'cuda' for Nvidia GPUs

phi = PhiForCausalLM.from_pretrained('microsoft/phi-2')
tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2')
```
After the download completes, we can test the model by asking it to generate some text.
Generating Text the Old-Fashioned Way
Before exploring unknown latent spaces, let’s verify we can generate text from a prompt.
With transformers models, you can’t just send in the prompt: we need to manually tokenize (encode) the prompt and then decode the response.
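To make the encode, generate, decode round trip concrete, here is a toy sketch in plain Python. Everything in it is hypothetical: a whitespace "tokenizer", a six-word vocabulary, and a lookup table standing in for the model (Phi-2 really uses a subword BPE tokenizer and a neural network). It does mirror one real detail of Hugging Face's `generate`: `max_length` counts the prompt tokens as well as the new ones (`max_new_tokens` counts only the new ones).

```python
# Toy sketch (hypothetical names throughout): tokenize -> generate -> decode.
VOCAB = ["<eos>", "Whippet", "facts:", "fast", "dogs", "are"]
TOKEN_TO_ID = {token: i for i, token in enumerate(VOCAB)}
EOS_ID = 0

# Stand-in "model": each token id deterministically predicts the next id.
NEXT_ID = {1: 4, 2: 1, 3: EOS_ID, 4: 5, 5: 3}

def encode(text):
    """Text -> token ids (a real tokenizer splits text into subword pieces)."""
    return [TOKEN_TO_ID[token] for token in text.split()]

def decode(ids):
    """Token ids -> text, dropping the special end-of-sequence token."""
    return " ".join(VOCAB[i] for i in ids if i != EOS_ID)

def generate(input_ids, max_length):
    """Append predicted ids until EOS or until the total length hits max_length."""
    ids = list(input_ids)
    while len(ids) < max_length:  # prompt tokens count towards max_length
        next_id = NEXT_ID.get(ids[-1], EOS_ID)
        ids.append(next_id)
        if next_id == EOS_ID:
            break
    return ids

def toy_llm(prompt, max_length=30):
    return decode(generate(encode(prompt), max_length))

print(toy_llm("Whippet facts:"))  # Whippet facts: Whippet dogs are fast
```

Like the real `llm` function, the decoded output starts with the prompt, because generation appends to the input ids rather than replacing them.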
```python
def llm(prompt, max_length=30):
    inputs = tokenizer(prompt, return_tensors='pt')
    generated_ids = phi.generate(inputs.input_ids,
                                 attention_mask=inputs.attention_mask,
                                 max_length=max_length)
    return tokenizer.batch_decode(generated_ids,
                                  clean_up_tokenization_spaces=False,
                                  skip_special_tokens=True)[0]
```
We can verify the snippet works by asking Phi about whippets.

```python
llm("Whippet facts:")
```

```
Whippet facts:
- The Whippet is a medium-sized dog breed.
- They have a slender and elegant body structure.
```
In order to use nokens, we need to learn how to generate text from token embeddings.
Token Embeddings
At the bottom of the architecture diagram, we see that the first step is embedding the tokens: each token ID is converted into a vector, its embedding. If you need a refresher on tokens, check out Simon Willison’s interactive GPT Tokenizer Notebook.
Each token becomes a 2560-dimensional embedding. We can see this by running the following code:

```python
prompt = "Hello world :)"
inputs = tokenizer(prompt, return_tensors="pt")
for token_id in inputs.input_ids[0]:
    print(tokenizer.decode(token_id), token_id.item(), phi.model.embed_tokens(token_id))
```
Token | Id | Token Embedding (2560 dimensions) |
---|---|---|
"Hello" | 15496 | [-0.0298, -0.0281, 0.0052, …, 0.0336, 0.0179, -0.0231] |
" world" | 995 | [ 0.0178, 0.0300, 0.0014, …, -0.0061, 0.0106, -0.0367] |
" :)" | 14373 | [ 0.0083, -0.0308, 0.0373, …, -0.0142, -0.0107, -0.0084] |
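Under the hood, `phi.model.embed_tokens` is a PyTorch `nn.Embedding`: a lookup table with one row per vocabulary entry, so `phi.model.embed_tokens.weight[token_id]` is the same vector as `phi.model.embed_tokens(token_id)`. A plain-Python sketch of that lookup, assuming a hypothetical 8-token vocabulary and 4 dimensions instead of Phi-2's 2560:

```python
import random

VOCAB_SIZE = 8  # Phi-2's real vocabulary is far larger
DIM = 4         # Phi-2 uses 2560 dimensions; 4 keeps the printout readable

# Hypothetical embedding matrix: row `token_id` is that token's embedding.
_rng = random.Random(0)
WEIGHT = [[_rng.gauss(0.0, 0.02) for _ in range(DIM)] for _ in range(VOCAB_SIZE)]

def embed_tokens(input_ids):
    """Look up each id: shape [batch][seq] -> [batch][seq][DIM].

    Rows are copied, so editing a returned embedding never changes WEIGHT.
    """
    return [[list(WEIGHT[token_id]) for token_id in seq] for seq in input_ids]

embeds = embed_tokens([[3, 1, 5]])
print(len(embeds), len(embeds[0]), len(embeds[0][0]))  # 1 3 4
```

Returning copies matters for the experiments that follow: overwriting one position of the looked-up embeddings should change only this prompt, not the model's table.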
We can generate text from these token embeddings by sending `inputs_embeds` instead of `input_ids`.
```python
def llm_embed(prompt, max_length=30):
    inputs = tokenizer(prompt, return_tensors="pt")
    # do the embedding step ourselves instead of letting the model do it
    embeds = phi.model.embed_tokens(inputs.input_ids)
    generated_ids = phi.generate(inputs_embeds=embeds,
                                 attention_mask=inputs.attention_mask,
                                 max_length=max_length)
    return tokenizer.batch_decode(generated_ids,
                                  skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)[0]
```
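Why does swapping `input_ids` for `inputs_embeds` work? The token ids are only used by the very first layer; everything after it sees vectors. A toy sketch in plain Python, with a hypothetical 3-token embedding table and a `layers` function standing in for the transformer blocks (it just sums the vectors), shows that both entry points reach the same computation. The "exactly one of the two" check mirrors the one many Hugging Face models perform.

```python
# Hypothetical toy "model": an embedding table plus a stand-in for the transformer layers.
WEIGHT = [
    [0.1, 0.2],   # token 0
    [0.3, -0.1],  # token 1
    [-0.2, 0.4],  # token 2
]

def embed_tokens(input_ids):
    return [[list(WEIGHT[i]) for i in seq] for seq in input_ids]

def layers(inputs_embeds):
    """Stand-in for the transformer blocks: sums each sequence's vectors."""
    return [[sum(col) for col in zip(*seq)] for seq in inputs_embeds]

def forward(input_ids=None, inputs_embeds=None):
    # Accept exactly one kind of input.
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("pass exactly one of input_ids or inputs_embeds")
    if inputs_embeds is None:
        inputs_embeds = embed_tokens(input_ids)  # step 1: ids -> vectors
    return layers(inputs_embeds)                 # the rest only ever sees vectors

ids = [[0, 2, 1]]
print(forward(input_ids=ids) == forward(inputs_embeds=embed_tokens(ids)))  # True
```

Because the layers never see the ids, nothing stops us from handing them a vector that no token maps to: a noken.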
Check that everything is working by comparing the generated text with our original `llm` function [2].
```python
llm('Random number?') == llm_embed('Random number?')
```
Now we have all the pieces we need to use nokens instead of tokens.
Generating With Random Nokens
Our first experiment is replacing a token with seeded randomness.
By using the token `Red` as the first word in our prompt (`loc=0`), we ensure it is the token we replace with a noken: its embedding is overwritten with random numbers of the same size/shape as the original token embedding.
```python
def llm_rand(prompt, seed, loc=0):
    inputs = tokenizer(prompt, return_tensors='pt')

    # overwrite embedding of token at `loc` with seeded randomness
    torch.manual_seed(seed)
    embeds = phi.model.embed_tokens(inputs.input_ids)
    embeds[0, loc, :] = torch.randn_like(embeds[0, loc, :])

    # generate new tokens
    generate_ids = phi.generate(inputs_embeds=embeds,
                                attention_mask=inputs.attention_mask,
                                max_length=30)

    # return decoded tokens
    return tokenizer.batch_decode(generate_ids,
                                  skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)[0]

for seed in range(10):
    print(llm_rand("Red is a color that is often associated with", seed))
```
Seed | Response |
---|---|
0 | luxury and elegance. It is a deep, rich shade of brown that is often used in high-end |
1 | luxury and elegance. It is a deep, rich shade that is often used in high-end fashion and |
2 | the sun, fire, and energy. It is a bright and vibrant color that can evoke feelings of warmth |
3 | the color wheel. It is a combination of red and yellow, and it is a color that is often |
4 | the ocean and the sky. It is a deep, rich blue that can evoke feelings of calmness and |
5 | the ocean, the sky, and the sun. It is a bright and cheerful color that can evoke feelings |
6 | the ocean and the sky. It is a calming and soothing color that can help to create a sense of |
7 | the natural world. It is the color of the sky, the grass, and the leaves on trees. |
8 | luxury, elegance, and sophistication. It is a color that is often used in high-end fashion, |
9 | luxury and elegance. It is a color that is often used in high-end fashion and home decor. |
It worked! The LLM decided that different nokens have different color-ish meanings.
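Seeding is what makes each row of the table reproducible: `torch.manual_seed(seed)` resets PyTorch's global random generator, so `torch.randn_like` draws the same noken for the same seed. The replacement step is sketched below in plain Python, with a local `random.Random(seed)` standing in for PyTorch's generator and hypothetical 3-dimensional embeddings. Like `torch.randn_like`, it samples from a standard normal distribution, so a noken's entries are much larger in magnitude than the values of a few hundredths in the embedding table earlier.

```python
import random

def replace_with_noken(embeds, loc, seed):
    """Copy one sequence of embeddings and overwrite position `loc` with a seeded noken.

    The noken is drawn from a standard normal distribution and has the
    same length as the embedding it replaces.
    """
    rng = random.Random(seed)                       # local RNG: same seed -> same noken
    noken = [rng.gauss(0.0, 1.0) for _ in embeds[loc]]
    out = [list(vec) for vec in embeds]             # leave the caller's embeddings alone
    out[loc] = noken
    return out

sentence = [[0.03, -0.02, 0.01], [0.02, 0.00, -0.03]]  # hypothetical 3-dim embeddings
a = replace_with_noken(sentence, loc=0, seed=1)
b = replace_with_noken(sentence, loc=0, seed=1)
c = replace_with_noken(sentence, loc=0, seed=2)
print(a == b, a[0] == c[0], a[1] == sentence[1])  # True False True
```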
Generating With LERP’d Nokens
For our next experiment, we will use LERP [3] to blend two different tokens.
To simplify, let’s just LERP between two prompts that have the same number of tokens. The first word of both prompts is a single token (`Black`/`White`).
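LERP computes (1 − p)·a + p·b element by element, which is exactly what the `embeds = (1-p) * embeds_1 + p * embeds_2` line in the code below does on whole tensors: p = 0 returns the first prompt's embeddings and p = 1 the second's. A plain-Python sketch, with hypothetical 2-dimensional stand-ins for the `White` and `Black` embeddings:

```python
def lerp(a, b, p):
    """Linear interpolation between vectors a and b: p=0 gives a, p=1 gives b."""
    if len(a) != len(b):
        raise ValueError("vectors must be the same length")
    return [(1 - p) * x + p * y for x, y in zip(a, b)]

white = [0.0, 2.0]  # hypothetical stand-in for the "White" embedding
black = [2.0, 4.0]  # hypothetical stand-in for the "Black" embedding

ps = [i / 10 for i in range(11)]  # 0.0, 0.1, ..., 1.0 (true division, not //)
blends = [lerp(white, black, p) for p in ps]
print(blends[0], blends[5], blends[-1])  # [0.0, 2.0] [1.0, 3.0] [2.0, 4.0]
```

Every intermediate blend lies on the straight line between the two token embeddings, so most of them are nokens too.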
```python
def llm_lerp(prompt_1: str, prompt_2: str, p: float) -> str:
    inputs_1 = tokenizer(prompt_1, return_tensors='pt')
    inputs_2 = tokenizer(prompt_2, return_tensors='pt')
    embeds_1 = phi.model.embed_tokens(inputs_1.input_ids)
    embeds_2 = phi.model.embed_tokens(inputs_2.input_ids)
    if embeds_1.shape != embeds_2.shape:
        raise ValueError('Embeddings must be the same shape')
    # blend: p=0 gives prompt_1, p=1 gives prompt_2
    embeds = (1-p) * embeds_1 + p * embeds_2
    generate_ids = phi.generate(inputs_embeds=embeds,
                                attention_mask=inputs_1.attention_mask,
                                max_length=30)
    return tokenizer.batch_decode(generate_ids,
                                  skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)[0]

for i in range(11):
    p = i / 10  # 0.0, 0.1, ..., 1.0 (`i // 10` would give 0 for every i < 10)
    print(llm_lerp('White whippets are', 'Black whippets are', p))
```
p | Response |
---|---|
0.0 | also known as the “dancing dog” because of their graceful movements. They are very active and love to run and play |
0.1 | also known as “snow dogs” because of their ability to pull sleds in snowy conditions. They have a thick, |
0.2 | also known as “snow dogs” because of their ability to pull sleds in snowy conditions. They have a thick, |
0.3 | also known as “snow whippets” because of their white coats. They are a popular breed for families with children |
0.4 | also known as “snow whippets” because of their white coats. They are a popular breed for families with children |
0.5 | also known as “snow whippets” because of their white coats. They are a popular breed for families with children |
0.6 | also known for their intelligence and trainability. They are eager to please their owners and can be easily trained to perform various tasks |
0.7 | also known for their intelligence and trainability. They are eager to please their owners and can be easily trained to perform various tasks |
0.8 | also known for their intelligence and trainability. They are eager to please their owners and can be easily trained to perform various tasks |
0.9 | also known for their intelligence and trainability. They are eager to please their owners and can be trained to do various tricks and |
1.0 | also known for their intelligence and trainability. They are eager to please their owners and can be trained to do various tricks and |
Success!
More Control - Generate one token at a time
Perhaps you want to mess with the embeddings of tokens that are not in the prompt, trigger on specific tokens, or change things interactively… I’ve got you covered. This code snippet is complex, but it exposes the inner workings of the model by running the generation loop one token at a time.
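The `# helpers` section of the code that follows chains three Hugging Face logits processors: `MinLengthLogitsProcessor`, `TopKLogitsWarper` and `TemperatureLogitsWarper`. Here is a plain-Python sketch of what each does to one vector of next-token scores before softmax and sampling. The function names are hypothetical and the real classes work on batched tensors. With `k=1` only the single best token survives, so sampling is effectively greedy and the seed cannot change the pick.

```python
import math
import random

def min_length(scores, cur_len, min_len, eos_id):
    """Forbid end-of-sequence until `min_len` tokens exist."""
    if cur_len < min_len:
        scores = list(scores)
        scores[eos_id] = -math.inf
    return scores

def top_k(scores, k):
    """Keep the k highest scores (ties included) and mask out the rest."""
    cutoff = sorted(scores, reverse=True)[k - 1]
    return [s if s >= cutoff else -math.inf for s in scores]

def temperature(scores, t):
    """Divide scores by t: t < 1 sharpens the distribution, t > 1 flattens it."""
    return [s / t for s in scores]

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

def sample(probs, rng):
    """Draw one index with probability probs[i]."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

logits = [2.0, 0.5, 1.0, -1.0]  # hypothetical scores; id 0 plays the EOS token
scores = min_length(logits, cur_len=3, min_len=15, eos_id=0)  # EOS masked despite scoring highest
scores = temperature(top_k(scores, k=1), t=0.01)
picks = {sample(softmax(scores), random.Random(seed)) for seed in range(10)}
print(picks)  # {2}
```

So in `embsampler`, the seeded noken is the only source of variation between runs.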
```python
import torch.nn.functional as F
from transformers import (LogitsProcessorList, MinLengthLogitsProcessor,
                          TemperatureLogitsWarper, TopKLogitsWarper)

# helpers
logits_processor = LogitsProcessorList([
    MinLengthLogitsProcessor(15, eos_token_id=phi.generation_config.eos_token_id),
])
logits_warper = LogitsProcessorList([
    TopKLogitsWarper(1),
    TemperatureLogitsWarper(0.01),
])

@torch.no_grad()  # inference only, so skip building autograd graphs
def embsampler(prompt="Red is the first", seed=0, tokens=10):
    print("prompt:", prompt.replace("Red", "*"), " | ", end="")
    past_key_values = None

    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    inputs_embeds = phi.model.embed_tokens(input_ids)

    # replace the first token with a random seeded noken
    torch.manual_seed(seed)
    inputs_embeds[:, 0, :] = torch.randn_like(inputs_embeds[:, 0, :])

    for _ in range(tokens):
        # run the core model
        outputs = phi(
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            use_cache=True,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        # outputs returns logits we need to process to determine the next token
        next_token_logits = outputs.logits[:, -1, :]
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)
        probs = F.softmax(next_token_scores, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        print(tokenizer.decode(next_token), end="")

        # we use kvcache to speed things up, only feed new tokens
        input_ids = next_token[:, None]
        inputs_embeds = phi.model.embed_tokens(input_ids)

        # this is the cache!
        past_key_values = outputs.past_key_values

        # attention_mask.append(1) .. but in torch :(
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
            dim=-1,
        )
    print()

for seed in range(0, 11):
    embsampler(seed=seed, tokens=20)
```
Seed | Response |
---|---|
0 | to offer a comprehensive, integrated approach to the treatment of chronic pain. The program is designed to |
1 | of its kind in the world. The building is also home to the National Gallery of Victoria |
2 | to be built in the United States. Exercise: What is the purpose of the new |
3 | to be released in the series. Question: What is the plot of the game? |
4 | to offer a comprehensive, integrated solution for the management of the entire life cycle of a product, from |
5 | to admit that the idea of a “smart” city is still in its infancy. “We |
6 | step in the process of making a delicious and refreshing drink. The process of making a da |
7 | step in the process of creating a new drug. Exercise: What is the purpose of |
8 | to admit that the idea of a “perfect” marriage is a myth. “There is no such |
9 | album by the band The Black Crowes. It was released in 1988 and was produced by the band |
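Why can the loop feed only the newest token's embedding after the first step? Because `past_key_values` already holds the attention keys and values for every earlier position. A toy sketch in plain Python makes this concrete: one layer, one attention head, and identity projections (query = key = value = embedding) are simplifying assumptions, and `KVCache` is a hypothetical name. Attending with the cache gives the same result as recomputing over the whole prefix, while the attention mask grows by one entry per step, which is the job of the `torch.cat` line in `embsampler`.

```python
import math

def attend(query, keys, values):
    """Single-head dot-product attention of one query over the given keys/values."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]

def full_recompute(embeds):
    """Without a cache: attend from the last position over the whole prefix again."""
    return attend(embeds[-1], embeds, embeds)

class KVCache:
    """Hypothetical cache: remembers keys/values so each step processes only the new token."""

    def __init__(self):
        self.keys, self.values = [], []

    def step(self, new_embed):
        self.keys.append(new_embed)
        self.values.append(new_embed)
        return attend(new_embed, self.keys, self.values)

sequence = [[0.1, 0.3], [0.5, -0.2], [-0.4, 0.2], [0.3, 0.3]]  # hypothetical embeddings
cache = KVCache()
attention_mask = []
for t, embed in enumerate(sequence):
    cached_out = cache.step(embed)   # only the newest embedding goes in
    attention_mask.append(1)         # the mask grows by one per step
    assert cached_out == full_recompute(sequence[: t + 1])
```

A real transformer caches keys and values for every layer, which is what `outputs.past_key_values` carries from one iteration to the next.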
Success?
It “works” - but is it useful?
Unclear.
It seems interesting enough to continue. As transformers are used in more contexts and can do more, having a richer set of tools to interact with the model seems like a good idea.
Next, I’m going to switch to Mistral, or perhaps a coding model. Or perhaps a MusicGen model: fuzz the music!
Footnotes
1. In Mapping the semantic void: Strange goings-on in GPT embedding spaces, the author explores GPT-J’s token space; the post introduced me to the concept of non-tokens (nokens).
2. Generation is deterministic by default: Hugging Face's `generate` uses greedy decoding unless sampling is enabled. To add chaos, set `do_sample=True` and turn up the `temperature` or use `top_k`/`top_p` sampling.
3. Linear intERPolation: a way to smoothly blend between two values.