Write a document classifier in less than 30 minutes. | Towards Data Science


In one of my past interviews, I was asked to implement a model to classify paper abstracts. The objective was not to have a perfectly tuned model, but rather to see my ability to go through the entire process in a minimum amount of time. Here is what I did.

The Data

The data consisted of paper abstracts from the PubMed database. PubMed is a database of biomedical literature. NCBI, the agency managing PubMed, provides an API (the E-utilities) to download the papers. Many libraries already exist to interact with this API, in several languages. I used Python, and the easiest library I found is Biopython (imported as Bio), whose Entrez module gives access to the NCBI databases, including PubMed.

We import the module and set an email address, which NCBI requires so that it can contact you if your requests cause problems. Without an API key, NCBI allows up to 3 requests per second; you can request an API key to raise this limit to 10 requests per second.

from Bio import Entrez
Entrez.email = 'your.name@example.com'  # replace with your own email address
Entrez.api_key = "abcdefghijklmnopqrstuvwxyz42"  # optional: replace with your own NCBI API key

To get articles from PubMed, we first run a search query that returns the IDs of the matching documents. We then use these IDs to fetch the details (in my case, the abstracts).
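Under the hood, Entrez.esearch and Entrez.efetch call NCBI's E-utilities over HTTP. The sketch below only builds the two request URLs with the standard library (it sends nothing), to show which parameters each step uses. The helper names esearch_url and efetch_url are hypothetical, and Biopython additionally attaches your email, a tool name and your API key to each request.

```python
from urllib.parse import urlencode

# Base URL of NCBI's E-utilities
EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/"


def esearch_url(term, retmax=1000):
    """Step 1 (hypothetical helper): search PubMed, which returns matching IDs only."""
    params = {"db": "pubmed", "term": term, "retmax": retmax,
              "sort": "relevance", "retmode": "xml"}
    return EUTILS + "esearch.fcgi?" + urlencode(params)


def efetch_url(id_list):
    """Step 2 (hypothetical helper): fetch the full records for a list of IDs."""
    params = {"db": "pubmed", "id": ",".join(map(str, id_list)), "retmode": "xml"}
    return EUTILS + "efetch.fcgi?" + urlencode(params)


print(esearch_url("asthma AND inhaler", retmax=5))
print(efetch_url([123, 456]))  # made-up IDs
```

The split into two calls keeps the search cheap: the first response is just a list of IDs, and only the second one carries the (much larger) article records.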

def search(query, max_documents=1000):
    # Return the IDs of the max_documents most relevant PubMed matches
    handle = Entrez.esearch(db='pubmed',
                            sort='relevance',
                            retmax=max_documents,
                            retmode='xml',
                            term=query)
    results = Entrez.read(handle)
    handle.close()
    return results

The function runs the query passed as argument against the PubMed database, sorts the results by relevance, and limits the number of results to max_documents.

Queries are actually very simple: you can combine keywords with Boolean operators (AND, OR, NOT). The PubMed documentation explains how to build queries in detail.
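As an illustration of that syntax, the hypothetical helper below groups synonyms with OR, combines the groups with AND, and restricts each quoted term to titles and abstracts with PubMed's [Title/Abstract] field tag. This is a variant of what I did (the article simply joins keywords with AND, without field tags), and the example terms are made up.

```python
def build_query(concepts, field="Title/Abstract"):
    """Combine synonym groups into one PubMed query string (hypothetical helper).

    concepts: list of lists; terms inside a list are alternatives (OR),
    and every list must match (AND).
    """
    groups = []
    for synonyms in concepts:
        # Quote each term so multi-word phrases are searched as phrases
        terms = [f'"{term}"[{field}]' for term in synonyms]
        groups.append("(" + " OR ".join(terms) + ")")
    return " AND ".join(groups)


query = build_query([["asthma"], ["inhaler", "bronchodilator"]])
print(query)
```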

In the interview, I was asked to get documents for 4 classes (topics). We do that by specifying each class's associated keywords in the query.

The search does not return the documents themselves: its result contains the list of matching document IDs (in results["IdList"]), without their contents. We then use these IDs to get all the details of the documents.

def fetch_details(id_list):
    # Download the full XML records for all IDs in a single request
    handle = Entrez.efetch(db="pubmed", id=",".join(map(str, id_list)), retmode="xml")
    records = Entrez.read(handle)
    handle.close()

    abstracts = []
    for pubmed_article in records["PubmedArticle"]:
        article = pubmed_article["MedlineCitation"]["Article"]
        # Skip papers without an abstract; for structured abstracts, keep the first section
        if "Abstract" in article:
            abstracts.append(article["Abstract"]["AbstractText"][0])
    return abstracts
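Because Entrez.read returns nested dictionaries, lists and strings (Biopython subclasses of dict, list and str), the parsing step can be tried offline. The sketch below applies the same filtering logic as fetch_details to a small hand-made stand-in for the parsed XML. The names extract_abstracts and mock_records are hypothetical, and the join_sections option is an addition of mine: structured abstracts (Background, Methods, Results...) come as several AbstractText entries, and the code above keeps only the first one.

```python
def extract_abstracts(records, join_sections=False):
    """Collect abstracts from parsed PubMed records, skipping papers without one."""
    abstracts = []
    for pubmed_article in records["PubmedArticle"]:
        article = pubmed_article["MedlineCitation"]["Article"]
        if "Abstract" not in article:
            continue  # some records have no abstract
        sections = article["Abstract"]["AbstractText"]
        # Default behavior matches fetch_details: keep only the first section
        abstracts.append(" ".join(sections) if join_sections else sections[0])
    return abstracts


# Hypothetical stand-in mimicking the nesting of Entrez.read(...) output
mock_records = {"PubmedArticle": [
    {"MedlineCitation": {"Article": {"Abstract": {"AbstractText": ["Background text.", "Results text."]}}}},
    {"MedlineCitation": {"Article": {"ArticleTitle": "A paper without abstract"}}},
    {"MedlineCitation": {"Article": {"Abstract": {"AbstractText": ["Single-section abstract."]}}}},
]}

print(extract_abstracts(mock_records))
print(extract_abstracts(mock_records, join_sections=True))
```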

fetch_details takes a list of IDs and returns a list with the abstracts of all documents that have one. The complete function to get all abstracts for a specific class is:

def get_abstracts_for_class(ab_class):
    list_abstracts = []
    ## get keywords of the class
    query = " AND ".join(keywords[ab_class])
    res = search(query)
    list_abstracts = fetch_details(res["IdList"])
    return list_abstracts

I stored the keywords of each class in a dictionary, keywords, and used them to build the query.
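The article does not list the actual topics or keywords, so the dictionary below is purely hypothetical. It only shows the expected shape (class label mapped to a list of keywords) and the query string that get_abstracts_for_class would build for each class.

```python
# Hypothetical topics and keywords: the real ones are not given in the article
keywords = {
    1: ["diabetes", "insulin"],
    2: ["alzheimer", "dementia"],
    3: ["influenza", "vaccine"],
    4: ["asthma", "inhaler"],
}

# Same query construction as in get_abstracts_for_class
queries = {ab_class: " AND ".join(words) for ab_class, words in keywords.items()}
for ab_class, query in queries.items():
    print(ab_class, query)
```

Joining with AND is restrictive: a paper must mention every keyword of a class to be retrieved, which keeps the classes fairly distinct but can shrink the result set.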

We call the function for each class to get the abstracts of all classes. Finally, we reformat them into a usable pandas DataFrame.

import pandas as pd

# Assumes keywords is keyed by the class labels 1 to 4
list_all_classes = []
for ab_class in [1, 2, 3, 4]:
    list_all_classes += [{"abs": a, "class": ab_class} for a in get_abstracts_for_class(ab_class)]
abs_df = pd.DataFrame(list_all_classes)

Data Cleansing

Again, the goal here is not to perfectly clean the dataset, but some standard preprocessing is necessary. I personally use NLTK most of the time, but you can do the same thing with almost any NLP library.

import re
import string

import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# Download the NLTK resources before using them
# (recent NLTK versions need 'punkt_tab' for word_tokenize, older ones 'punkt')
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('stopwords')

## 1) Lower
abs_df["abs"] = abs_df["abs"].str.lower()
## 2) Remove tags
abs_df["abs"] = abs_df["abs"].apply(lambda text: re.sub(r"<[^>]*>", "", text))
## 3) Tokenize
abs_df["abs_proc"] = abs_df["abs"].apply(word_tokenize)
## 4) Remove punctuation
table = str.maketrans('', '', string.punctuation)
abs_df["abs_proc"] = abs_df["abs_proc"].apply(lambda tokens: [w.translate(table) for w in tokens])
## 5) Remove non-alpha tokens (numbers, and tokens emptied by step 4)
abs_df["abs_proc"] = abs_df["abs_proc"].apply(lambda tokens: [w for w in tokens if w.isalpha()])
## 6) Remove stop-words
stop_words = set(stopwords.words('english'))
abs_df["abs_proc"] = abs_df["abs_proc"].apply(lambda tokens: [w for w in tokens if w not in stop_words])
## 7) Reformat to have a single text
abs_df["abs_proc_res"] = abs_df["abs_proc"].apply(' '.join)

We use the pandas apply function to apply the same processing to the entire DataFrame:

  1. Lowercase all the text.
  2. I found that the text contains some tags, such as HTML-style markup for bold text. These tags probably carry some meaning, but handling them is too complicated for a one-hour exercise, so I decided to simply remove them with a regular expression.
  3. Tokenize the text, i.e. split it into a list of individual words.
  4. Remove all punctuation characters, such as question marks (?) or commas (,).
  5. Remove non-alphabetical tokens, such as numbers.
  6. Remove stop-words. We first retrieve the English stop-word list from NLTK, and then use it to filter out our tokens.
  7. Finally, join the processed tokens to get a single text for each abstract.
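To see what the seven steps do to a single abstract, here is a dependency-free sketch that applies them to one string. It swaps in two simplifications, both assumptions: a regular-expression tokenizer instead of NLTK's word_tokenize, and a tiny hand-picked stop-word set instead of NLTK's English list.

```python
import re
import string


def preprocess(text, stop_words):
    text = text.lower()                                   # 1) lower
    text = re.sub(r"<[^>]*>", "", text)                   # 2) remove tags
    tokens = re.findall(r"\w+|[^\w\s]", text)             # 3) tokenize (simplified)
    table = str.maketrans("", "", string.punctuation)
    tokens = [t.translate(table) for t in tokens]         # 4) remove punctuation
    tokens = [t for t in tokens if t.isalpha()]           # 5) keep alphabetic tokens
    tokens = [t for t in tokens if t not in stop_words]   # 6) remove stop words
    return " ".join(tokens)                               # 7) single text


stop_words = {"the", "is", "to", "of"}
print(preprocess("The <b>BRCA1</b> gene is linked to 45% of cases.", stop_words))
```

The result is "gene linked cases". Note that step 5 drops not only numbers such as 45 but also tokens that mix letters and digits, such as brca1, which can matter for biomedical text.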

Data Embedding

If you are familiar with NLP problems, then you know that the most important part when dealing with textual data is probably the vector representation, i.e. the embedding. A lot of progress has been made in this area, and powerful models have been proposed, such as Google's BERT or OpenAI's GPT. However, these models are tricky to tune and are definitely not suited for a one-hour exercise. Moreover, for many practical problems, a very simple embedding is enough to get a good vector representation of the data.

The simplest one is probably TF-IDF (Term Frequency-Inverse Document Frequency), which is the one that I used.
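As a sketch of what TfidfVectorizer computes with its default settings (as I understand scikit-learn's documentation: raw counts as term frequency, smoothed IDF, and L2-normalized rows), the weight of term t in document d is tf(t, d) × idf(t), with idf(t) = ln((1 + n) / (1 + df(t))) + 1, where n is the number of documents and df(t) the number of documents containing t. The plain-Python function below is an illustration, not scikit-learn's code; it splits on whitespace, whereas TfidfVectorizer's default tokenizer also ignores one-character tokens.

```python
import math
from collections import Counter


def tfidf(docs):
    """Return (vocabulary, rows): one L2-normalized TF-IDF row per document."""
    tokenized = [doc.split() for doc in docs]
    vocab = sorted({t for tokens in tokenized for t in tokens})
    n = len(docs)
    # Document frequency: number of documents containing each term
    df = Counter(t for tokens in tokenized for t in set(tokens))
    # Smoothed IDF: ln((1 + n) / (1 + df)) + 1
    idf = {t: math.log((1 + n) / (1 + df[t])) + 1 for t in vocab}
    rows = []
    for tokens in tokenized:
        counts = Counter(tokens)  # raw term frequency
        row = [counts[t] * idf[t] for t in vocab]
        norm = math.sqrt(sum(v * v for v in row))
        rows.append([v / norm if norm else 0.0 for v in row])
    return vocab, rows


vocab, rows = tfidf(["gene cancer", "gene protein", "gene cancer cancer"])
print(vocab)
for row in rows:
    print([round(v, 3) for v in row])
```

A term present in every document (gene here) gets the minimum IDF of 1, so rarer, more discriminative terms such as cancer weigh more in each row.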

scikit-learn already provides a TF-IDF implementation, TfidfVectorizer, which can be applied directly to a DataFrame column.

from sklearn.feature_extraction.text import TfidfVectorizer
vec = TfidfVectorizer()
x = vec.fit_transform(abs_df["abs_proc_res"])

At this point, we have a (sparse) matrix x with one row per vectorized abstract. However, looking at its shape, we notice something:

print(x.shape)
(25054, 60329)

We end up with a large number of columns (60329). This is expected, as this number corresponds to the size of the vocabulary of the entire corpus (i.e. the entire dataset). There are two problems with this number.

First, it makes training the model more complex and more expensive.

Second, even after all the preprocessing, most of the words in the vocabulary are not relevant to the classification, as they do not carry any useful information.

Fortunately, there are ways of reducing the number of columns while avoiding losing relevant information. The most common one is PCA (Principal Component Analysis), which projects the data onto a smaller set of uncorrelated directions (the principal components). Here we apply a truncated SVD (Singular Value Decomposition), which is closely related to PCA but does not center the data, so it can work directly on the sparse TF-IDF matrix (applied to TF-IDF, this is also known as latent semantic analysis, or LSA). Again, there is an sklearn module to do it easily.

from sklearn.decomposition import TruncatedSVD
svd = TruncatedSVD(n_components=100)
res = svd.fit_transform(x)
print(res.shape)
(25054, 100)

I chose to reduce our initial matrix to 100 components (i.e. features). This is a parameter to optimize: the closer we are to the initial dimension, the less information we lose in the reduction, while a small number reduces the complexity of model training.
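For reference, here is the truncated SVD written out, with n = 25054 documents, d = 60329 vocabulary terms and k = 100 components. Only the k largest singular values σ1 ≥ … ≥ σk are kept, and the reduced representation Z is what fit_transform returns (up to the approximation of the randomized solver, which is TruncatedSVD's default). The equality X V_k = U_k Σ_k holds because the columns of V are orthonormal.

```latex
X \approx U_k \Sigma_k V_k^{\top}, \qquad
X \in \mathbb{R}^{n \times d},\;
U_k \in \mathbb{R}^{n \times k},\;
\Sigma_k = \operatorname{diag}(\sigma_1, \dots, \sigma_k),\;
V_k \in \mathbb{R}^{d \times k}

Z = X V_k = U_k \Sigma_k \in \mathbb{R}^{n \times k}
```

By the Eckart–Young theorem, U_k Σ_k V_k^T is the best rank-k approximation of X in the Frobenius norm, which is why a larger k loses less information.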

We are now ready to train a classifier.

The model

There are a lot of classification models out there. One of the simplest to understand and implement is probably the SVM (Support Vector Machine). In a nutshell, it tries to draw a line (a hyperplane, in higher dimensions) that separates the points of each class with the largest possible margin.

We also use cross-validation, so that the metrics are more representative than those of a single train/test split.
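The idea behind repeated k-fold cross-validation, which RepeatedKFold implements in the code that follows, is to shuffle the data, cut it into k folds, use each fold once as the test set, and repeat the whole procedure with a new shuffle. The plain-Python sketch below illustrates that idea with a hypothetical repeated_kfold helper; it does not reproduce scikit-learn's exact splits.

```python
import random


def repeated_kfold(n_samples, n_splits=10, n_repeats=3, seed=1):
    """Return a list of (train_indices, test_indices) pairs (hypothetical helper)."""
    rng = random.Random(seed)
    splits = []
    for _ in range(n_repeats):
        indices = list(range(n_samples))
        rng.shuffle(indices)  # new random order for every repetition
        # Fold sizes differ by at most one when n_samples is not divisible by n_splits
        base, extra = divmod(n_samples, n_splits)
        start = 0
        for fold in range(n_splits):
            size = base + (1 if fold < extra else 0)
            test = indices[start:start + size]
            train = indices[:start] + indices[start + size:]
            splits.append((train, test))
            start += size
    return splits


splits = repeated_kfold(25, n_splits=10, n_repeats=3)
print(len(splits))  # 10 folds x 3 repeats = 30
```

With 10 splits and 3 repeats, 30 models are trained and evaluated, which is why cross_validate returns 30 scores per metric; the mean and standard deviation of those scores are what we print.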

from numpy import mean, std
from sklearn import svm
from sklearn.model_selection import RepeatedKFold, cross_validate

y = abs_df["class"].values  # class labels
X = res                     # 100-dimensional SVD features
# 10-fold cross-validation repeated 3 times with different random splits (30 fits)
cv = RepeatedKFold(n_splits=10, n_repeats=3, random_state=1)
# Linear SVM with regularization parameter C=1
model = svm.SVC(kernel='linear', C=1, decision_function_shape='ovo')

We use a linear kernel, i.e. the model tries to separate the data with a line (a hyperplane). Other kernels exist, such as the polynomial kernel, which allows a polynomial decision boundary that can better separate the points.
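Concretely, the kernel K defines the similarity used in the SVM decision function f, whose sign gives the predicted side of the boundary in the binary case. In scikit-learn's notation, γ, r and d correspond to the gamma, coef0 and degree parameters of SVC; with the linear kernel, f reduces to w^T x + b with w = Σ α_i y_i x_i.

```latex
K_{\text{linear}}(x, x') = \langle x, x' \rangle

K_{\text{poly}}(x, x') = \left(\gamma \langle x, x' \rangle + r\right)^{d}

f(x) = \sum_{i \in \mathrm{SV}} \alpha_i \, y_i \, K(x_i, x) + b
```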

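The decision function shape is set to ovo, i.e. one-versus-one: each binary classifier separates one pair of classes, ignoring the others. Note that for multi-class problems, SVC always trains these pairwise classifiers internally; decision_function_shape='ovo' only makes decision_function return the raw pairwise scores instead of the default one-vs-rest-shaped output.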
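In the one-versus-one scheme, K classes give K(K - 1)/2 binary classifiers (6 for our 4 classes), and a document is assigned to the class that wins the most pairwise votes. The sketch below only illustrates that voting logic with hypothetical helpers and made-up pairwise outcomes; its tie-breaking rule is an arbitrary choice of this sketch, not necessarily libsvm's.

```python
from collections import Counter
from itertools import combinations


def ovo_pairs(classes):
    """All unordered pairs of classes: one binary classifier per pair."""
    return list(combinations(sorted(classes), 2))


def ovo_vote(pairwise_winners):
    """Predict the class that wins the most pairwise duels.

    pairwise_winners maps each (class_a, class_b) pair to the class that the
    corresponding binary classifier predicted. Ties go to the smallest label.
    """
    votes = Counter(pairwise_winners.values())
    return max(sorted(votes), key=lambda c: votes[c])


pairs = ovo_pairs([1, 2, 3, 4])
print(len(pairs))  # 4 * 3 / 2 = 6 classifiers

# Hypothetical outcomes of the 6 binary classifiers for one document
winners = {(1, 2): 1, (1, 3): 3, (1, 4): 1, (2, 3): 3, (2, 4): 4, (3, 4): 3}
print(ovo_vote(winners))  # class 3 wins 3 of its duels
```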

Let's train!

metrics = cross_validate(model, X, y, scoring=['precision_macro', 'recall_macro'], cv=cv, n_jobs=-1)

# Mean and standard deviation over the 30 folds
print('Precision: %.3f (%.3f)' % (mean(metrics["test_precision_macro"]), std(metrics["test_precision_macro"])))
print('Recall: %.3f (%.3f)' % (mean(metrics["test_recall_macro"]), std(metrics["test_recall_macro"])))
-----------------------------------
Precision: 0.740 (0.021)
Recall: 0.637 (0.014)

Two metrics are of interest here: precision and recall.

The macro precision means that, averaged over the classes, 74% of the documents predicted as belonging to a class actually belong to it, which is not bad.

On the other hand, the macro recall means that, on average, we were able to catch about 64% of the documents of each class.
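To make the macro averaging explicit: precision and recall are computed for each class separately and then averaged with equal weight per class. The helper below is a plain-Python sketch of what the precision_macro and recall_macro scorers compute (a class that is never predicted gets a precision of 0 here, which matches scikit-learn's default zero_division behavior, apart from the warning it emits); the labels are made up.

```python
def macro_precision_recall(y_true, y_pred):
    """Unweighted mean of per-class precision and recall (sketch)."""
    classes = sorted(set(y_true) | set(y_pred))
    precisions, recalls = [], []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        predicted = sum(1 for p in y_pred if p == c)  # documents predicted as c
        actual = sum(1 for t in y_true if t == c)     # documents really in c
        precisions.append(tp / predicted if predicted else 0.0)
        recalls.append(tp / actual if actual else 0.0)
    return sum(precisions) / len(classes), sum(recalls) / len(classes)


# Made-up labels for 6 documents and 3 classes
y_true = [1, 1, 2, 2, 3, 3]
y_pred = [1, 2, 2, 2, 3, 1]
precision, recall = macro_precision_recall(y_true, y_pred)
print(round(precision, 3), round(recall, 3))  # 0.722 0.667
```

Here class 2 has a perfect recall but a precision of only 2/3, because one class-1 document was predicted as class 2: a gap between precision and recall, as in our results, points to classes that "absorb" documents from others.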

Conclusion and perspectives

As you can see, it is relatively easy to implement a quick classifier, using only the basics of machine learning. Of course it is not perfect, but when you don’t have anything, even a bad model is acceptable.

Clearly, a lot of improvements can be made. The preprocessing and the text representation are probably the parts with the most impact on the model. For instance, instead of using TF-IDF, we could try more sophisticated representations such as BERT. On the model side, we could also try other classifiers, or even stack several of them for better performance.

That being said, if your objective is to have a working model to classify your documents, this is a good starting point.

The next step is to put this into production! I will cover this part in another post.