Document Classification with InceptionV3
Training and Tuning with PyTorch and Ray Tune
With the increasing digitization of processes in industries that deal with vast archives, there is growing interest in Document Image Processing (DIP). This technology can lead to significant cost savings for companies and improves the user experience, as customers no longer have to enter data manually.
An integral part of this technology is the automated assignment of a document to a specific document class, so that subsequent document-specific text extraction can be performed. This document classification task can be approached either context-based (NLP, e.g., BERT) or structure-based (CV, e.g., CNNs). In the following, we take the latter approach and train a CNN to classify the documents.
InceptionV3 is one of the best-known and most widely used CNN architectures. Its weights, pre-trained on ImageNet, can be conveniently imported with PyTorch (via torchvision) and fine-tuned on a custom dataset.
Tobacco3482 Dataset
Since we want to investigate the classification of documents, we choose the openly accessible dataset Tobacco3482 consisting of 3482 single-page scans, divided into 10 different classes of business documents that originate from the tobacco industry:
In the plot, we can see that the classes comprise different document types, including printed documents and handwritten manuscripts. The printed documents range from instances with many graphical elements and little text (e.g., the class Advertisement) to predominantly text-based documents (e.g., the class Scientific). In contrast, the class Form consists mainly of handwritten text.
The number of documents differs across the document categories, i.e., the dataset is imbalanced. The largest classes are Memo, Email and Letter, each with roughly 570 to 620 documents. In contrast, the three most underrepresented classes are Resume, News and Note, each with roughly 120 to 200 documents:
PyTorch Dataset & DataLoader
Note: The implementation of the following procedure can be found here.
First, we collect the image paths and their labels in a data frame, from which we then build a custom PyTorch dataset. Here, image_dir is the directory in which we saved the dataset.
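A minimal sketch of this step, assuming that image_dir contains one subfolder per document class (e.g. ADVE, Email) holding the image files. The function name build_dataframe and the column names path, class and label are hypothetical and not necessarily those of the original repository:

```python
import os

import pandas as pd

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff")


def build_dataframe(image_dir):
    """Collect image paths and integer labels, assuming one subfolder per class."""
    classes = sorted(
        d for d in os.listdir(image_dir) if os.path.isdir(os.path.join(image_dir, d))
    )
    # Map each class folder name to an integer label (alphabetical order).
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    rows = []
    for name in classes:
        class_dir = os.path.join(image_dir, name)
        for fname in sorted(os.listdir(class_dir)):
            if fname.lower().endswith(IMAGE_EXTENSIONS):  # skip non-image files
                rows.append({
                    "path": os.path.join(class_dir, fname),
                    "class": name,
                    "label": class_to_idx[name],
                })
    return pd.DataFrame(rows, columns=["path", "class", "label"]), class_to_idx
```

With such a frame, df["class"].value_counts() yields the number of documents per class, which is how the class imbalance described above can be checked.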
Before defining our PyTorch Dataset class, we specify the image transformations. We have to distinguish between the transformations of the training data and those of the validation and test data: while we apply additional image augmentation to the training data to reduce overfitting, the validation and test data are only resized to the 299×299 input size of InceptionV3 and normalized.
Next, we create the custom dataset class. It inherits from PyTorch's Dataset class and must implement the three methods __init__(), __len__() and __getitem__().
We create two instances of this class: one with the transformations for the training data and one with the transformations for the validation and test data.
Based on these, we create our training, validation and test splits with the help of PyTorch's SubsetRandomSampler.
Training and Hyperparameter Tuning
Note: The implementation of the following procedure can be found here.
During training, the parameters that control the training process (hyperparameters) have a considerable influence on the quality of the model. That is why we want to integrate hyperparameter tuning into the usual PyTorch training workflow. This works very well with Ray Tune, an industry-standard tool for distributed hyperparameter tuning. Furthermore, Ray Tune includes state-of-the-art hyperparameter search algorithms and integrates with TensorBoard and other analysis libraries.
To use Ray Tune for our training and tuning, we need to define a training function. This function requires at least a config parameter, through which the hyperparameters sampled from the search space are passed in later. Consequently, every component that depends on a hyperparameter we want to vary must be created inside this training function.
We limit our hyperparameter search to the learning rate, batch size, momentum, weight decay, and the number of epochs (feel free to integrate further hyperparameters in your tuning).
The batch size is a parameter of the PyTorch DataLoader. Since tuning only involves the training and validation splits, we create the DataLoaders for these two splits at the beginning of the training function and bundle them in a joint dictionary.
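A minimal sketch of this first part of the training function. The helper name build_dataloaders and the config key batch_size are hypothetical; the datasets and samplers are assumed to be the ones created in the previous section:

```python
from torch.utils.data import DataLoader


def build_dataloaders(config, train_dataset, eval_dataset, train_sampler, val_sampler):
    """Create the loaders whose batch size is a tuned hyperparameter."""
    batch_size = int(config["batch_size"])
    return {
        # Training split: dataset instance with augmentation transforms.
        "train": DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler),
        # Validation split: dataset instance with resize/normalize only.
        "val": DataLoader(eval_dataset, batch_size=batch_size, sampler=val_sampler),
    }
```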
As a next step inside the training function, we initialize InceptionV3 with our custom initialize_model() function, move it to the available device (CPU or GPU) and, if more than one GPU is available, parallelize training across them.
The initialize_model() function imports InceptionV3 from torchvision. A special characteristic of fine-tuning InceptionV3, in contrast to many other architectures, is that the network has two output layers during training: the primary output, a linear layer at the end of the network, and an auxiliary output, which belongs to the AuxLogits branch of the network. When testing the model, only the primary output is considered. For training, however, we need to adapt both layers to the number of document classes.
The training function also needs the criterion (loss function) to be minimized and the optimizer itself. The optimizer, in our case Stochastic Gradient Descent (SGD), holds the hyperparameters we want to tune: learning rate, momentum and weight decay.
After that, the training workflow follows. We loop over the number of epochs for which we want to train the model. Within each epoch, the training data is fed through the network batch by batch via the DataLoader. The predictions are compared with the true labels to compute the loss; for InceptionV3, this means two losses, one for the primary and one for the auxiliary output. The optimizer (SGD) then uses this loss to update the weights of InceptionV3. At the end of each epoch, we apply InceptionV3 batch-wise to the validation data in order to report the validation loss and accuracy.
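A sketch assuming cross-entropy as the criterion (the usual choice for multi-class classification; the article does not name it) and the hypothetical config keys lr, momentum and weight_decay; build_criterion_and_optimizer is a hypothetical helper:

```python
import torch.nn as nn
import torch.optim as optim


def build_criterion_and_optimizer(model, config):
    """Cross-entropy loss and SGD whose settings come from the Ray Tune config."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=config["lr"],                      # tuned learning rate
        momentum=config["momentum"],          # tuned momentum
        weight_decay=config["weight_decay"],  # tuned L2 penalty
    )
    return criterion, optimizer
```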
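The body of a single epoch might look like the sketch below. Weighting the auxiliary loss by 0.4 follows the PyTorch fine-tuning tutorial and is an assumption here. The per-epoch report call to Ray Tune is left out because its name and signature differ between Ray versions (e.g. tune.report vs. ray.train.report); train_one_epoch and evaluate are hypothetical helper names:

```python
import torch


def train_one_epoch(model, loader, criterion, optimizer, device, aux_weight=0.4):
    """One training pass that combines the main and auxiliary InceptionV3 losses."""
    model.train()
    total_loss, seen = 0.0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        if isinstance(outputs, tuple):  # InceptionOutputs(logits, aux_logits) in train mode
            logits, aux_logits = outputs
            loss = criterion(logits, labels) + aux_weight * criterion(aux_logits, labels)
        else:
            loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
        seen += labels.size(0)
    return total_loss / seen


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Validation loss and accuracy; in eval mode only the primary output is returned."""
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        logits = model(inputs)
        total_loss += criterion(logits, labels).item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return total_loss / seen, correct / seen
```

In the training function, both helpers would be called once per epoch, and the returned validation loss and accuracy passed to Ray Tune.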
Hyperparameter Tuning with Ray Tune
As already mentioned, we use Ray for hyperparameter optimization, more specifically Ray's Tune module and its tune.run() function. Besides our training function, it takes the following components as arguments:
- The config
- The scheduler
- The reporter
In the config we define the hyperparameter grid we want to sample combinations from in each run:
The scheduler controls the trials while they run and can stop unpromising ones early, while the reporter prints the reported metrics to the console. In addition, the scheduler's max_t parameter caps the number of training iterations (here: epochs) a single trial may run.
Besides these components, we specify the number of samples (num_samples) to draw from the hyperparameter grid, i.e., the number of trials.
After all sampled hyperparameter combinations have been trained, we choose the model that performed best on the validation set and print its hyperparameter configuration as well as its validation loss and accuracy.
In the last step of our training procedure, we need to verify that this model performs equally well on a hold-out set. For this purpose, we set aside a test set during the initial split that the model never saw during training and that we did not use for model selection. We define another function that applies the selected model to the test data and calculates its accuracy:
As input for our test_accuracy() function, we need to load our best model. To do this, we call the previously discussed initialize_model() function, move the returned model to the available device and load the weights of our best model.
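Assuming tune.run() returned an ExperimentAnalysis object named analysis and the trials report loss and accuracy, the best trial can be selected with its get_best_trial() method. Selecting by the last reported validation loss is an assumption, and summarize_best_trial is a hypothetical helper:

```python
def summarize_best_trial(analysis, metric="loss", mode="min"):
    """Pick the best trial by its last reported metric and print its results."""
    best_trial = analysis.get_best_trial(metric, mode, "last")
    summary = {
        "config": best_trial.config,
        "val_loss": best_trial.last_result["loss"],
        "val_accuracy": best_trial.last_result["accuracy"],
    }
    print(f"Best trial config: {summary['config']}")
    print(f"Best trial final validation loss: {summary['val_loss']}")
    print(f"Best trial final validation accuracy: {summary['val_accuracy']}")
    return best_trial, summary
```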
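A sketch of the loading step, assuming the training function saved the model's state_dict with torch.save to a checkpoint file. The helper name load_best_model and the checkpoint path are hypothetical, since where Ray Tune places checkpoints depends on the Ray version and on how the training function saves them:

```python
import torch


def load_best_model(model, checkpoint_path, device="cpu"):
    """Load saved weights into a freshly initialized model and move it to device."""
    state_dict = torch.load(checkpoint_path, map_location=device)
    # Weights saved from an nn.DataParallel wrapper carry a "module." prefix.
    state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    return model.to(device)
```

A call could then look like best_model = load_best_model(initialize_model(num_classes=10, use_pretrained=False), checkpoint_path, device), where pre-trained weights are skipped because they are overwritten anyway.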
We can now apply the loaded model to the test data by calling our test_accuracy function.
All these building blocks together make up the final main() function of our training pipeline:
That's it! Now we can watch the models train over the epochs. Ray Tune stores the checkpoints saved by the training function in every epoch and lets us identify the best trial at the end of the training process.
Results
Hyperparameter tuning with 15 samples from the specified hyperparameter grid resulted in a learning rate (lr) of 0.0043, a batch size of 8, a momentum (mom) of 0.6, a weight decay of 0.008 and 12 epochs. This combination achieved the highest validation accuracy, 84.6%.
To additionally evaluate the selected model on a hold-out set, we apply it to the test set that was set aside at the beginning of our analysis. On this data set, the model even achieves an accuracy of 85.5%.
Further Improvement
Further improvements in model quality could be achieved by adding further hyperparameters to the optimization (e.g., a dropout rate). An alternative would be a more exhaustive search, such as a full grid search, instead of randomly sampling only 15 hyperparameter combinations. Also, the model could first be pre-trained on a larger document dataset such as RVL-CDIP and then fine-tuned on Tobacco3482. Another option would be to choose a more complex network architecture, such as its successor InceptionV4.
Have fun training your own model, and don't hesitate to try architectures other than InceptionV3 and new datasets to fine-tune your model on!
You can find the repository to this project here: https://github.com/jopagel/Document-Classification-with-InceptionV3
References:
- https://www.kaggle.com/datasets/patrickaudriaz/tobacco3482jpg
- https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
- https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
- https://pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html
- https://docs.ray.io/en/latest/tune/examples/tune_analyze_results.html
- https://docs.ray.io/en/latest/tune/api_docs/schedulers.html
- https://docs.ray.io/en/latest/tune/api_docs/reporters.html
- https://arxiv.org/pdf/2004.07922v1.pdf