Yuan Gao
Posted on January 14, 2021
Running AI these days is increasingly simple, thanks to the hard work of open-source contributors producing top-notch libraries, and research groups opening up their work so others can build on it. One key library doing that is HuggingFace's Transformers. HuggingFace are a startup building, amongst other NLP-related products, a library and model ecosystem that allows almost anyone to quickly and easily set up AI-powered chat bots that can consume or produce natural language.
In this post, I'll demonstrate how I used this library to produce a Twitter bot that tweets only made-up (and slightly quirky) good news.
AIs and GPT2
This blog post isn't meant to explain any theory, but for those who aren't familiar, the easiest way to explain this kind of AI is that it's a sophisticated pattern-recognition system. Feed it enough data and it builds up an ability to recognize the patterns in the English language, to the extent that if you ask it to repeat the pattern, not only will it generate mostly correct English grammar, it might also, from time to time, generate a coherent sentence!
It takes a huge amount of data and processing power for it to figure out the patterns in language, as well as some of the topics we typically talk about. GPT2, developed by OpenAI, is one such AI, and the model (the output of all this "learning") that we often use with it was trained on millions and millions of pages from the internet (a lot of it from Reddit). This has given it an uncanny ability to reproduce English: give it a starting letter or word, and it will, based on the patterns it picked up from its training data, produce the next few words.
That model, with all of these learned "patterns" embedded within it, can be downloaded and used directly. HuggingFace have also produced their own much smaller model, DistilGPT2, which has been optimized to reduce the resources needed to use it, so that it can run on your average desktop computer. This is the one we'll be using, as I don't have a good enough PC to run the larger GPT2 models (which are also available for download on their site).
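If you want to see this pattern-completion in action before any fine-tuning, here's a minimal sketch using the Transformers pipeline API (the prompt text is just an arbitrary example):

from transformers import pipeline, set_seed

# Load the small DistilGPT2 model (downloaded automatically on first run)
generator = pipeline("text-generation", model="distilgpt2")
set_seed(42)  # make the sampled output repeatable

# Give it a few starting words and let it continue the pattern
for result in generator("The weather today is", max_length=20, num_return_sequences=3):
    print(result["generated_text"])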
Sourcing the data
The model by itself is not too useful for our use-case. It can produce English, but it has no specific topic of focus. Its source material was diverse, so it could end up producing sentences about a wide variety of subjects.
What we want to do is make it produce text in the style of news headlines, with a particular focus on good news or quirky news. So we want to take this base model, and "fine tune" it so it continues to learn on new data and figure out what new patterns we are interested in it producing.
So we must first acquire a large amount of examples of the kind of thing we want it to generate - good/quirky news headlines.
I'm not going to talk in too much detail about how or where, because, uh...the amount of data we need to acquire is quite large, and acquiring it involves scraping in volumes that most sites probably don't want you to, since a lot of people doing this would put a large burden on their servers. So it's something you'd want to do responsibly, and legally.
Possible source of data: RSS feeds
Where better to collect news headlines than RSS feeds! RSS is a standardised format for apps to consume headlines from news sources, so that apps like news readers can pull headlines from lots of different news sites and have them all formatted correctly in the reader. RSS is big in the podcast world, because it's perfect for that use-case: each podcast has an RSS feed, which allows your podcast app to subscribe to the feed and fetch the latest episodes when they're available.
Almost all news sites produce an RSS feed, and you can programmatically fetch and parse these feeds easily. For example, in Python there's an easy-to-use library called feedparser (pip install feedparser); the code to do this is simply:

import feedparser

feed = feedparser.parse("http://example.com/rss")
for entry in feed.entries:
    print(entry.title)
The upside of this is that it's super simple and easy to write. The downside is that most news sites will only give you a day or two's worth of headlines, with no way to fetch historical news. Fortunately, news sites are often crawled by archive tools intended to preserve historical snapshots of sites, and as a result RSS feeds are often included in archives, meaning you could, if you were to figure out how and where, fetch ten or more years of historical RSS feeds HINT HINT MASSIVE HINT.
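To make that hint a little more concrete: one such archive (an example, not necessarily the source I used) is the Internet Archive's Wayback Machine, whose CDX API lists historical snapshots of a URL. A rough sketch, assuming a placeholder feed URL and fetching politely:

import time
import requests
import feedparser

# Example only: list archived snapshots of a feed via the Wayback Machine's
# CDX API, then parse each snapshot. Do this responsibly and within the
# archive's terms of use.
FEED_URL = "example.com/rss"  # placeholder feed
resp = requests.get(
    "http://web.archive.org/cdx/search/cdx",
    params={"url": FEED_URL, "output": "json", "filter": "statuscode:200"},
)
rows = resp.json()
for row in rows[1:]:  # first row is the column header
    timestamp = row[1]  # e.g. "20120101000000"
    snapshot = f"http://web.archive.org/web/{timestamp}/{FEED_URL}"
    for entry in feedparser.parse(snapshot).entries:
        print(entry.title)
    time.sleep(1)  # rate-limit ourselves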
Possible source of data: Reddit
Just as GPT2 was trained on Reddit stories (actually, only using links from them), you can train on Reddit too. Reddit offers an API that you can use to build your own programmatic access; you can find more information on Reddit's developer page.
If using Python, there's a package called PRAW (pip install praw) which provides an easy-to-use interface for interacting with the Reddit API. You still need your own API keys, which can be acquired from Reddit's developer pages. The code is simply:

import praw

reddit = praw.Reddit(
    client_id="YOUR_TOKEN_ID",
    client_secret="YOUR_TOKEN_SECRET",
    user_agent="YOUR_APP_NAME_OR_SOMETHING",
)
subreddit = reddit.subreddit("example_subreddit")
for submission in subreddit.top(limit=None):
    print(submission.title)
Again, it's very easy to use it this way, but the downside is that Reddit will give you at most 1,000 stories per subreddit. That's quite good already, though we usually need much more; with many subreddits available, you can probably collect enough data this way (a sketch of that idea follows below).
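Continuing from the PRAW snippet above (the subreddit names here are only examples), you could pool the per-subreddit cap across several subreddits and de-duplicate:

# Pool the ~1000-submission cap across several subreddits
# (subreddit names here are only examples)
titles_list = []
for name in ["UpliftingNews", "nottheonion", "offbeat"]:
    for submission in reddit.subreddit(name).top(limit=None):
        titles_list.append(submission.title)
titles_list = list(dict.fromkeys(titles_list))  # de-duplicate, keep order
print(f"collected {len(titles_list)} titles")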
Another possible source of Reddit data is PushShift, a service operated by a researcher who needed access to a large amount of Reddit data; it appears to be a one-person, community-supported operation. PushShift's APIs allow access to a huge amount of Reddit data, with history far beyond the 1,000 submissions the official API will provide. PushShift is a very powerful resource, and there's more information on their subreddit and website.
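As a sketch (the endpoint and parameters here reflect PushShift's public API as I understand it, and may change), you can page backwards through a subreddit's history using the before parameter:

import time
import requests

# Page backwards through a subreddit's submission history via PushShift
API = "https://api.pushshift.io/reddit/search/submission/"
titles, before = [], None
while len(titles) < 5000:
    params = {"subreddit": "example_subreddit", "size": 100, "sort": "desc"}
    if before is not None:
        params["before"] = before
    data = requests.get(API, params=params).json()["data"]
    if not data:
        break  # reached the end of the available history
    titles.extend(item["title"] for item in data)
    before = data[-1]["created_utc"]  # continue from the oldest post seen
    time.sleep(1)  # be kind to a community-funded service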
Other possible sources of data
There are several other sources of data, ranging from plain web crawling and scraping, to user-generated content, to paying (or otherwise persuading) people to browse the internet and find you the data. However you do it, you need maybe tens of thousands of examples to get it to work well.
Sentiment analysis
If you collect a lot of news stories from news sites, much of the data won't be positive; in fact, I found that a large majority (over two thirds) of news articles were negative. Since this is supposed to be a positive AI bot, we need to filter the data somewhat.
Fortunately, sentiment analysis has been a mainstay of NLP/AI research for years, so the methods to do it are well developed. HuggingFace's Transformers library already has a sentiment-analysis pipeline (backed by a fine-tuned DistilBERT model by default), which makes this task a breeze. After installing the dependencies (pip install transformers and pip install torch), the code is simply the following, using a little list comprehension to get our filtered list out:

from transformers import pipeline

sentiment_analyzer = pipeline("sentiment-analysis")
sentiments_list = sentiment_analyzer(titles_list)  # titles_list holds our collected headlines

THRESH = 0.99  # keep only confidently-positive titles
positive = [
    title
    for title, sentiment in zip(titles_list, sentiments_list)
    if sentiment["label"] == "POSITIVE" and sentiment["score"] > THRESH
]
Note: this takes a lot of RAM; I actually ran out of my 32GB when processing 100k titles in one go, and the script crashed. In many cases the work needs to be batched into smaller chunks.
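One way around this (a sketch, not necessarily how I did it, continuing from the pipeline code above) is to feed the pipeline fixed-size chunks so memory stays bounded:

# Run the sentiment pipeline over fixed-size chunks to bound memory use
BATCH_SIZE = 256  # tune to your RAM
sentiments_list = []
for i in range(0, len(titles_list), BATCH_SIZE):
    sentiments_list.extend(sentiment_analyzer(titles_list[i:i + BATCH_SIZE]))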
Preparing the data
Once the filtered list is collected, it needs to be prepared in a specific way for the fine-tuning. We need a flat text file containing all the data we want to use, with a special token "<|endoftext|>" separating each headline.
In addition, we want to randomize the order of the titles to avoid the AI learning any association between consecutive headlines, so we write the data in 4 times, shuffled differently each time:
import random

ENDTOKEN = "<|endoftext|>"
EPOCHS = 4

with open("data_train.txt", "w") as fp:
    fp.write(ENDTOKEN)
    for _ in range(EPOCHS):
        random.shuffle(positive)
        fp.write(ENDTOKEN.join(positive) + ENDTOKEN)
This produces a big flat text file, data_train.txt, that we can now feed into the fine-tuning.
Fine Tuning
To fine-tune a model, we start with the HuggingFace distilgpt2 model, though others could be used (have a browse!). The code looks like this:
import random
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TextDataset, DataCollatorForLanguageModeling,
    Trainer, TrainingArguments
)

# Start from the pretrained DistilGPT2 model and its tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# Our flat text file of headlines, chunked into model-sized blocks
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="data_train.txt",
    block_size=tokenizer.model_max_length,
)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="./output",
    overwrite_output_dir=True,
    num_train_epochs=4,
    per_device_train_batch_size=1,
    prediction_loss_only=True,
    logging_steps=100,
    save_steps=0,
    seed=random.randint(0, 2**32 - 1),
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

trainer.train()
model.save_pretrained("./model")
This kicks off the training process. Hopefully you have a GPU with 6GB or more of VRAM, or you're going to be there for weeks! On a GPU, this process takes an hour or two.
Using the model
Once training is done and the model is saved, we can use it!
import string

# Prompt: the end-of-text token plus a random uppercase letter or digit
start_str = ENDTOKEN + random.choice(string.ascii_uppercase + string.digits)
encoded_prompt = tokenizer(start_str, add_special_tokens=False, return_tensors="pt").input_ids
encoded_prompt = encoded_prompt.to(model.device)

output_sequences = model.generate(
    input_ids=encoded_prompt,
    max_length=50,
    min_length=10,
    temperature=1.6,
    top_k=100,
    top_p=0.90,
    do_sample=True,
    num_return_sequences=5,
    pad_token_id=tokenizer.eos_token_id,  # gets rid of a warning
)

for generated_sequence in output_sequences:
    text = tokenizer.decode(
        generated_sequence.tolist(),
        clean_up_tokenization_spaces=True,
        skip_special_tokens=True,
    )
    print(text.strip())
Here, we give the model a random starting character to use as the prompt. If you don't, the text can start as if it were halfway through a sentence. (You can actually do even better by selecting a starting token at random from the tokenizer's full vocabulary rather than a single letter; see the sketch below.)
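A minimal sketch of that idea, picking a random token id from the vocabulary and decoding it (this reuses the tokenizer, ENDTOKEN and random import from earlier):

# Sketch: start from a random token out of the whole vocabulary instead.
# (You may want to skip special tokens such as the end-of-text token.)
token_id = random.randrange(tokenizer.vocab_size)
start_str = ENDTOKEN + tokenizer.decode([token_id])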
The temperature, top_k and top_p arguments to the generator take some adjusting. The values above make the output very random, but that's sort of what we want here, as we want the output to be a little zany. The HuggingFace blog has a nice article on these parameters.
The output looks something like this:
"HELP: The measles vaccine wins Malaysia competition",
"Sarajevo aims to change Bolivia's infamous Montecito archipelago, abolishing devastating 20 mile wilderness area which includes most of",
"0-year-old 'wonder' statue torn down in Richmond, helps replace statue in front of House Building burned to fire",
"Elderly man finds treasure of \u00a33.7 million - goes under the tree's cheek - and RETURNS it to Wildlife Charities",
The settings as I currently have them yield some crazy stuff, but that's what I'm looking for. A more reasonable set of values might be temperature 0.5, top_k 50, and top_p 0.95, which gives fairly conservative (unimaginative) results.
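For comparison, plugging those tamer values into the same generate call from above would look like this:

# Tamer sampling settings for more conservative headlines
output_sequences = model.generate(
    input_ids=encoded_prompt,
    max_length=50,
    min_length=10,
    temperature=0.5,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    num_return_sequences=5,
    pad_token_id=tokenizer.eos_token_id,
)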
Manual curation
I'm looking to tweet out things people will laugh at, so I don't want to just generate these headlines en masse and queue them all up to be tweeted; that would quickly flood people's feeds with random stuff that doesn't make any sense. Instead, I want to curate them: hand-select the few good ones out of the randomness that is in there.
Fortunately, this is where CurateBot, my previous project, comes in. CurateBot lets me load literally thousands of these tweets, go through them by hand quickly deleting the bad ones, and send the good ones to a queue to be eventually tweeted out.
Here's a peek at some of the queued tweets for @goodnews_ai:
And the resultant tweets posted by @goodnews_ai:
Hopefully this post gives you an idea of how easy it is to make AI tweet bots these days! My intention here is not to give an in-depth explanation or guide, but to point you towards the different resources and search terms you can use to find out more. If you're going to try this out, good luck!