Day and night: an image classifier with scikit-learn

How hard is it to write a program that can recognize if a photo of a city was taken during the day or during the night?

Turns out that’s pretty simple. The process can be broken down into three steps:

  1. defining a feature vector;
  2. training a classifier;
  3. testing the classifier on unknown images.

Defining a feature vector

A feature vector is a vector that summarizes the features of the object we want to classify.

Seattle by day by gaensler@flickr. Sydney by night by NickiMM@flickr.

Obviously, photos taken at night have more dark pixels than photos taken during the day. So we can use the number of dark/midrange/light pixels as a feature vector to classify photos.

However, simply counting the number of pixels for each color would make the feature vector dependent on the size of each photo. So, it makes more sense to put in the feature vector the ratio of the pixels of each color to the total number of pixels.

An additional problem is that so far the feature vector would be larger than necessary: the RGB color space has \(\left(2^8\right)^3 = 16{,}777{,}216\) elements, so a city with a dark sky would be substantially different from a city whose sky has a slightly different shade of dark blue.

We can reduce the feature vector size by mapping each of the 256 possible values of each color channel to a smaller set of values, for example 4.

RGB color cube. Courtesy of SharkD@wikipedia. Original: https://en.wikipedia.org/wiki/RGB_color_space#mediaviewer/File:RGB_Cube_Show_lowgamma_cutout_b.png

The final problem is that most machine learning libraries assume the feature vector to be a 1-dimensional vector, while the RGB color space is 3-dimensional. For this reason we can simply map the \(4^3 = 64\) cells of the RGB color space to a 1-dimensional vector with 64 slots.
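For example, with 4 blocks per channel each block spans \(256/4 = 64\) values, so a pixel with RGB values \((200, 30, 100)\) falls into cell \((3, 0, 1)\) and therefore into slot \(3 + 0\cdot4 + 1\cdot4^2 = 19\) of the feature vector; this is exactly the mapping implemented by process_image() below.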

Let’s use Pillow to read images and start writing some Python code that given an image file path or an image URL calculates its feature vector:

from __future__ import division
from __future__ import print_function
from PIL import Image
from StringIO import StringIO
import urllib2
from urlparse import urlparse
import sys
import os


def process_directory(directory):
    '''Returns an array of feature vectors for all the image files in a
    directory (and all its subdirectories). Symbolic links are ignored.

    Args:
      directory (str): directory to process.

    Returns:
      list of list of float: a list of feature vectors.
    '''
    training = []
    for root, _, files in os.walk(directory):
        for file_name in files:
            file_path = os.path.join(root, file_name)
            img_feature = process_image_file(file_path)
            if img_feature:
                training.append(img_feature)
    return training


def process_image_file(image_path):
    '''Given an image path it returns its feature vector.

    Args:
      image_path (str): path of the image file to process.

    Returns:
      list of float: feature vector on success, None otherwise.
    '''
    image_fp = StringIO(open(image_path, 'rb').read())
    try:
        image = Image.open(image_fp)
        return process_image(image)
    except IOError:
        return None


def process_image_url(image_url):
    '''Given an image URL it returns its feature vector

    Args:
      image_url (str): url of the image to process.

    Returns:
      list of float: feature vector.

    Raises:
      Any exception raised by urllib2 requests.

      IOError: if the URL does not point to a valid file.
    '''
    parsed_url = urlparse(image_url)
    request = urllib2.Request(image_url)
    # set a User-Agent and Referer to work around servers that block
    # atypical user agents and hotlinking. Sorry, it's for science!
    request.add_header('User-Agent', 'Mozilla/5.0 (X11; Ubuntu; Linux ' \
            'x86_64; rv:31.0) Gecko/20100101 Firefox/31.0')
    request.add_header('Referer', parsed_url.netloc)
    # Wrap network data in StringIO so that it looks like a file
    net_data = StringIO(urllib2.build_opener().open(request).read())
    image = Image.open(net_data)
    return process_image(image)


def process_image(image, blocks=4):
    '''Given a PIL Image object it returns its feature vector.

    Args:
      image (PIL.Image): image to process.
      blocks (int, optional): number of blocks to subdivide the RGB space into.

    Returns:
      list of float: feature vector if successful. None if the image is not
      RGB.
    '''
    if not image.mode == 'RGB':
        return None
    feature = [0] * blocks * blocks * blocks
    pixel_count = 0
    for pixel in image.getdata():
        ridx = int(pixel[0]/(256/blocks))
        gidx = int(pixel[1]/(256/blocks))
        bidx = int(pixel[2]/(256/blocks))
        idx = ridx + gidx * blocks + bidx * blocks * blocks
        feature[idx] += 1
        pixel_count += 1
    return [x/pixel_count for x in feature]

Just to have an idea of what we get, this is the feature vector plot of a city by day:

Feature vector city by day

And this is the feature vector of a city by night:

Feature vector city by night

Exactly as expected: night photos have plenty of dark pixels.
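If you want to reproduce these plots, here is a minimal sketch (not part of the original listing) that feeds the output of process_image_file() to matplotlib; the image path is just a placeholder:

import matplotlib.pyplot as plt

# 'photos/day/seattle.jpg' is a hypothetical path: use any RGB photo
feature = process_image_file('photos/day/seattle.jpg')
plt.bar(range(len(feature)), feature)
plt.xlabel('RGB cell index')
plt.ylabel('Fraction of pixels')
plt.show()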

Another good quality of this approach is that the feature values are already normalized, which makes most classifiers work better.

Training a classifier

I chose scikit-learn as the machine learning library, but you are free to choose the one that excites you the most.

First, what is a classifier? It is a “thing” that, given a feature vector, returns its class. In our example, it should return “1” given the feature vector of a picture of a city by day, and “0” for a city by night. The procedure that teaches the classifier which feature vectors belong to which classes is called training.

So go collect pictures of cities by day and by night, I’ll wait. Once you have them, put them in two separate folders. This will be our training set, which we’ll use to train a classifier.

Assume that we have no clue which machine learning algorithm we should use. scikit-learn provides a useful cheat sheet to guide us: http://scikit-learn.org/stable/tutorial/machine_learning_map/. In our case, it suggests C-Support Vector Classification (also called a Support Vector Machine, or SVM).

SVMs are all-around awesome and one of those algorithms that are pretty much always worth trying because they work well in a wide range of settings. Go read about them.

In very simple terms, an SVM puts all the objects in the training set in an n-dimensional space (n can even be infinite!), and then looks for the hyperplane that best separates objects of type A from objects of type B.

So far, this would work only if the objects are linearly separable. However, with one weird trick (mathematicians hate it) SVMs work even on classes that are not linearly separable (it’s actually called the kernel trick).
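To get a feel for what this means in practice, here is a toy sketch (not from the original code) that trains scikit-learn’s SVC on the classic XOR data set, which no straight line can separate:

from sklearn import svm

# XOR: no linear boundary separates class 0 from class 1
X = [[0, 0], [0, 1], [1, 0], [1, 1]]
y = [0, 1, 1, 0]

linear_clf = svm.SVC(kernel='linear').fit(X, y)
rbf_clf = svm.SVC(kernel='rbf', gamma=2, C=10).fit(X, y)

print(linear_clf.predict(X))  # cannot get all four points right
print(rbf_clf.predict(X))     # should recover [0 1 1 0]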

The SVM implementation of scikit-learn is available in the sklearn.svm module. As the cheat sheet suggests, we are going to use SVC.

SVC has plenty of parameters; the most important are C (the penalty parameter of the error term), kernel (the type of kernel to use), and gamma (the kernel coefficient).

Even if you have a good knowledge of SVMs, it is not straightforward to choose these parameters. The simplest way out of this dilemma is to try all possible combinations of these parameters and pick the classifier that works best. scikit-learn automates this with the GridSearchCV class.

Putting together the pieces of the puzzle we have to:

  1. gather training data using the code shown in the previous section;
  2. define the parameter search space to find a good classifier;
  3. return the classifier.

Here’s the code that does it:

def train(training_path_a, training_path_b, print_metrics=True):
    '''Trains a classifier. training_path_a and training_path_b should be
    directory paths and each of them should not be a subdirectory of the other
    one. training_path_a and training_path_b are processed by
    process_directory().

    Args:
      training_path_a (str): directory containing sample images of class A.
      training_path_b (str): directory containing sample images of class B.
      print_metrics  (boolean, optional): if True, print statistics about
        classifier performance.

    Returns:
      A classifier (sklearn.svm.SVC).
    '''
    if not os.path.isdir(training_path_a):
        raise IOError('%s is not a directory' % training_path_a)
    if not os.path.isdir(training_path_b):
        raise IOError('%s is not a directory' % training_path_b)
    training_a = process_directory(training_path_a)
    training_b = process_directory(training_path_b)
    # data contains all the training data (a list of feature vectors)
    data = training_a + training_b
    # target is the list of target classes for each feature vector: a '1' for
    # class A and '0' for class B
    target = [1] * len(training_a) + [0] * len(training_b)
    # split the training data into a train set and a test set. The test set
    # will contain 20% of the total
    x_train, x_test, y_train, y_test = cross_validation.train_test_split(data,
            target, test_size=0.20)
    # define the parameter search space
    parameters = {'kernel': ['linear', 'rbf'], 'C': [1, 10, 100, 1000],
            'gamma': [0.01, 0.001, 0.0001]}
    # search for the best classifier within the search space and return it
    clf = grid_search.GridSearchCV(svm.SVC(), parameters).fit(x_train, y_train)
    classifier = clf.best_estimator_
    if print_metrics:
        print()
        print('Parameters:', clf.best_params_)
        print()
        print('Best classifier score')
        print(metrics.classification_report(y_test,
            classifier.predict(x_test)))
    return classifier

There are several other techniques to properly train a classifier, such as cross-validation. Read about them on the official scikit-learn documentation.
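For example, a quick way to estimate how well a given set of parameters generalizes is k-fold cross-validation. Here is a minimal sketch (not part of the original code) that rebuilds the data and target lists the same way train() does; the directory paths are placeholders:

from sklearn import cross_validation, svm

# 'photos/day' and 'photos/night' are hypothetical training directories
training_a = process_directory('photos/day')
training_b = process_directory('photos/night')
data = training_a + training_b
target = [1] * len(training_a) + [0] * len(training_b)

# 5-fold cross-validation: train on 4/5 of the data, test on the
# remaining 1/5, repeated five times over different splits
scores = cross_validation.cross_val_score(
        svm.SVC(kernel='linear', C=10), data, target, cv=5)
print('Accuracy: %0.2f (+/- %0.2f)' % (scores.mean(), scores.std() * 2))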

Test the classifier on unknown images

We’ve got the training data and the classifier; we only need to test it:

def main(training_path_a, training_path_b):
    '''Main function. Trains a classifier and lets the user run it on images
    downloaded from the Internet.

    Args:
      training_path_a (str): directory containing sample images of class A.
      training_path_b (str): directory containing sample images of class B.
    '''
    print('Training classifier...')
    classifier = train(training_path_a, training_path_b)
    while True:
        try:
            print("Input an image url (enter to exit): "),
            image_url = raw_input()
            if not image_url:
                break
            features = process_image_url(image_url)
            print(classifier.predict(features))
        except (KeyboardInterrupt, EOFError):
            break
        except:
            exception = sys.exc_info()[0]
            print(exception)

And this is an example of how it works:

Training classifier...

Parameters: {'kernel': 'linear', 'C': 10, 'gamma': 0.01}

Best classifier score
             precision    recall  f1-score   support

          0       1.00      1.00      1.00         3
          1       1.00      1.00      1.00         5

avg / total       1.00      1.00      1.00         8


Input an image url (enter to exit): 

https://upload.wikimedia.org/wikipedia/commons/9/99/Qu%C3%A9bec-City-Skyline.jpg

[1]
Input an image url (enter to exit): 

http://upload.wikimedia.org/wikipedia/commons/d/d4/New_York_City_at_night_HDR_edit1.jpg

[0]

Yay!

The classifier actually returned 1 for a photo taken by day and 0 for a picture taken by night!

Here is the full source code:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''Images binary classifier based on scikit-learn SVM classifier.
It uses the RGB color space as feature vector.
'''

from __future__ import division
from __future__ import print_function
from PIL import Image
from sklearn import cross_validation
from sklearn import grid_search
from sklearn import svm
from sklearn import metrics
from StringIO import StringIO
from urlparse import urlparse
import urllib2
import sys
import os


def process_directory(directory):
    '''Returns an array of feature vectors for all the image files in a
    directory (and all its subdirectories). Symbolic links are ignored.

    Args:
      directory (str): directory to process.

    Returns:
      list of list of float: a list of feature vectors.
    '''
    training = []
    for root, _, files in os.walk(directory):
        for file_name in files:
            file_path = os.path.join(root, file_name)
            img_feature = process_image_file(file_path)
            if img_feature:
                training.append(img_feature)
    return training


def process_image_file(image_path):
    '''Given an image path it returns its feature vector.

    Args:
      image_path (str): path of the image file to process.

    Returns:
      list of float: feature vector on success, None otherwise.
    '''
    image_fp = StringIO(open(image_path, 'rb').read())
    try:
        image = Image.open(image_fp)
        return process_image(image)
    except IOError:
        return None


def process_image_url(image_url):
    '''Given an image URL it returns its feature vector

    Args:
      image_url (str): url of the image to process.

    Returns:
      list of float: feature vector.

    Raises:
      Any exception raised by urllib2 requests.

      IOError: if the URL does not point to a valid file.
    '''
    parsed_url = urlparse(image_url)
    request = urllib2.Request(image_url)
    # set a User-Agent and Referer to work around servers that block
    # atypical user agents and hotlinking. Sorry, it's for science!
    request.add_header('User-Agent', 'Mozilla/5.0 (X11; Ubuntu; Linux ' \
            'x86_64; rv:31.0) Gecko/20100101 Firefox/31.0')
    request.add_header('Referer', parsed_url.netloc)
    # Wrap network data in StringIO so that it looks like a file
    net_data = StringIO(urllib2.build_opener().open(request).read())
    image = Image.open(net_data)
    return process_image(image)


def process_image(image, blocks=4):
    '''Given a PIL Image object it returns its feature vector.

    Args:
      image (PIL.Image): image to process.
      blocks (int, optional): number of blocks to subdivide the RGB space into.

    Returns:
      list of float: feature vector if successful. None if the image is not
      RGB.
    '''
    if not image.mode == 'RGB':
        return None
    feature = [0] * blocks * blocks * blocks
    pixel_count = 0
    for pixel in image.getdata():
        ridx = int(pixel[0]/(256/blocks))
        gidx = int(pixel[1]/(256/blocks))
        bidx = int(pixel[2]/(256/blocks))
        idx = ridx + gidx * blocks + bidx * blocks * blocks
        feature[idx] += 1
        pixel_count += 1
    return [x/pixel_count for x in feature]


def show_usage():
    '''Prints how to use this program
    '''
    print("Usage: %s [class A images directory] [class B images directory]" %
            sys.argv[0])
    sys.exit(1)


def train(training_path_a, training_path_b, print_metrics=True):
    '''Trains a classifier. training_path_a and training_path_b should be
    directory paths and each of them should not be a subdirectory of the other
    one. training_path_a and training_path_b are processed by
    process_directory().

    Args:
      training_path_a (str): directory containing sample images of class A.
      training_path_b (str): directory containing sample images of class B.
      print_metrics  (boolean, optional): if True, print statistics about
        classifier performance.

    Returns:
      A classifier (sklearn.svm.SVC).
    '''
    if not os.path.isdir(training_path_a):
        raise IOError('%s is not a directory' % training_path_a)
    if not os.path.isdir(training_path_b):
        raise IOError('%s is not a directory' % training_path_b)
    training_a = process_directory(training_path_a)
    training_b = process_directory(training_path_b)
    # data contains all the training data (a list of feature vectors)
    data = training_a + training_b
    # target is the list of target classes for each feature vector: a '1' for
    # class A and '0' for class B
    target = [1] * len(training_a) + [0] * len(training_b)
    # split the training data into a train set and a test set. The test set
    # will contain 20% of the total
    x_train, x_test, y_train, y_test = cross_validation.train_test_split(data,
            target, test_size=0.20)
    # define the parameter search space
    parameters = {'kernel': ['linear', 'rbf'], 'C': [1, 10, 100, 1000],
            'gamma': [0.01, 0.001, 0.0001]}
    # search for the best classifier within the search space and return it
    clf = grid_search.GridSearchCV(svm.SVC(), parameters).fit(x_train, y_train)
    classifier = clf.best_estimator_
    if print_metrics:
        print()
        print('Parameters:', clf.best_params_)
        print()
        print('Best classifier score')
        print(metrics.classification_report(y_test,
            classifier.predict(x_test)))
    return classifier


def main(training_path_a, training_path_b):
    '''Main function. Trains a classifier and lets the user run it on images
    downloaded from the Internet.

    Args:
      training_path_a (str): directory containing sample images of class A.
      training_path_b (str): directory containing sample images of class B.
    '''
    print('Training classifier...')
    classifier = train(training_path_a, training_path_b)
    while True:
        try:
            print("Input an image url (enter to exit): "),
            image_url = raw_input()
            if not image_url:
                break
            features = process_image_url(image_url)
            print(classifier.predict(features))
        except (KeyboardInterrupt, EOFError):
            break
        except:
            exception = sys.exc_info()[0]
            print(exception)


if __name__ == '__main__':
    if len(sys.argv) != 3:
        show_usage()
    main(sys.argv[1], sys.argv[2])

Wrap up

This binary classifier works quite well if you feed it enough training data. Although as an example I chose daytime vs. nighttime photos, it works for any pair of image classes with reasonably different color distributions, e.g. photos of tigers vs. elephants, landscapes vs. portraits, sea vs. meadow, and so on.

Moreover, it is quite easy to modify it so that it works with multiple classes. Of course, this is left as an exercise for the reader.

Evaluating mean confidence interval in Java and Python

A confidence interval gives an estimated range of values which is likely to include an unknown population parameter, the estimated range being calculated from a given set of sample data

This is the definition of confidence interval given in the Statistics Glossary v1.1 by Valerie J. Easton & John H. McColl. The “unknown population parameter” is usually the population mean, so in the following I will just assume that the “unknown population parameter” is indeed the mean.

Thus we are dealing with a sample of a population, and we want to measure how close we can get to the population mean using only data from the sample.

If independent samples are taken from the same population and a confidence interval is evaluated for each sample, then a certain percentage (called the confidence level) of the intervals will include the population mean. The confidence level is usually 95%, but we can use 99%, 90%, or any other percentage we fancy.

I’m always a bit let down when I read a paper and the authors do not report the confidence interval of their experimental results. It means that, whatever measure they are reporting, you have to guess whether it is significant or not.

I think it’s good to make a habit of including the confidence interval for any measurement you are reporting.

In most practical settings, we don’t actually know what the population distribution is, and we just assume that it is normally distributed. For samples from other population distributions, what I am going to describe is approximately correct by the Central Limit Theorem.

For a population with unknown mean \(\mu\) and unknown standard deviation \(\sigma\), a confidence interval for the population mean, based on a random sample of size \(n\), is \(\overline{x}\pm t^*\frac{s}{\sqrt{n}}\), where:

  • \(\overline{x}\) is the sample mean;
  • \(n\) is the sample size;
  • \(s\) is the sample standard deviation (so \(\frac{s}{\sqrt{n}}\) is the standard error of the mean);
  • \(t^*\) is the upper \(\frac{1-C}{2}\) critical value for the Student’s t-distribution with \(n-1\) degrees of freedom.

The most difficult part is evaluating \(t^*\).

Assume that we are given the height in cm of 30 one year old toddlers: 63.5, 81.3, 88.9, 63.5, 76.2, 67.3, 66.0, 64.8, 74.9, 81.3, 76.2, 72.4, 76.2, 81.3, 71.1, 80.0, 73.7, 74.9, 76.2, 86.4, 73.7, 81.3, 68.6, 71.1, 83.8, 71.1, 68.6, 81.3, 73.7, 74.9.

The average height is 74.8 cm. What is the 95% confidence interval of this mean?
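As a preview of what the code below computes: with \(n = 30\), a sample standard deviation of roughly \(s \approx 6.64\), and \(t^* \approx 2.045\) (the upper 2.5% critical value of the t-distribution with 29 degrees of freedom), the interval is \(74.81 \pm 2.045\cdot\frac{6.64}{\sqrt{30}} \approx 74.81 \pm 2.48\), i.e. roughly from 72.3 cm to 77.3 cm.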

Mean Confidence Interval in Java

The Apache Commons Math 3 library can compute critical values of the Student’s t-distribution, so download it or pull it in with your dependency manager. Here is the code that calculates the 95% confidence interval:

import org.apache.commons.math3.distribution.TDistribution;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;

public class ConfidenceIntervalApp {

    public static void main(String args[]) {
        // data we want to evaluate: average height of 30 one year old male and female toddlers
        // interestingly, at this age height is not bimodal yet
        double data[] = new double[] { 63.5, 81.3, 88.9, 63.5, 76.2, 67.3, 66.0, 64.8, 74.9, 81.3, 76.2, 72.4, 76.2, 81.3, 71.1, 80.0, 73.7, 74.9, 76.2, 86.4, 73.7, 81.3, 68.6, 71.1, 83.8, 71.1, 68.6, 81.3, 73.7, 74.9 };
        // Build summary statistics of the dataset "data"
        SummaryStatistics stats = new SummaryStatistics();
        for (double val : data) {
            stats.addValue(val);
        }

        // Calculate 95% confidence interval
        double ci = calcMeanCI(stats, 0.95);
        System.out.println(String.format("Mean: %f", stats.getMean()));
        double lower = stats.getMean() - ci;
        double upper = stats.getMean() + ci;
        System.out.println(String.format("Confidence Interval 95%%: %f, %f", lower, upper));
    }

    private static double calcMeanCI(SummaryStatistics stats, double level) {
        try {
            // Create T Distribution with N-1 degrees of freedom
            TDistribution tDist = new TDistribution(stats.getN() - 1);
            // Calculate critical value
            double critVal = tDist.inverseCumulativeProbability(1.0 - (1 - level) / 2);
            // Calculate confidence interval
            return critVal * stats.getStandardDeviation() / Math.sqrt(stats.getN());
        } catch (MathIllegalArgumentException e) {
            return Double.NaN;
        }
    }

}

The output of this program is:

Mean: 74.806667
Confidence Interval 95%: 72.328860, 77.284474

Mean Confidence Interval in Python

The code to do the same calculation in Python is very similar. We will use numpy and scipy:

#!/usr/bin/env python

from scipy.stats import t
from numpy import average, std
from math import sqrt

if __name__ == '__main__':
    # data we want to evaluate: average height of 30 one year old male and
    # female toddlers. Interestingly, at this age height is not bimodal yet
    data = [63.5, 81.3, 88.9, 63.5, 76.2, 67.3, 66.0, 64.8, 74.9, 81.3, 76.2,
            72.4, 76.2, 81.3, 71.1, 80.0, 73.7, 74.9, 76.2, 86.4, 73.7, 81.3,
            68.6, 71.1, 83.8, 71.1, 68.6, 81.3, 73.7, 74.9]
    mean = average(data)
    # evaluate the sample standard deviation by setting delta degrees of
    # freedom (ddof) to 1: the divisor used in the calculation is N - ddof
    stddev = std(data, ddof=1)
    # Get the endpoints of the range that contains 95% of the distribution
    t_bounds = t.interval(0.95, len(data) - 1)
    # shift the scaled critical values by the mean to get the interval endpoints
    ci = [mean + critval * stddev / sqrt(len(data)) for critval in t_bounds]
    print "Mean: %f" % mean
    print "Confidence Interval 95%%: %f, %f" % (ci[0], ci[1])

The output is exactly the same as that of the Java version.
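As a side note, scipy can also apply the location and scale for us, so the last two steps can be collapsed into a single call (a sketch using the same variables as above; loc is the sample mean and scale is the standard error of the mean):

ci_low, ci_high = t.interval(0.95, len(data) - 1,
                             loc=mean, scale=stddev / sqrt(len(data)))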

Trie: T9, firewalls and DNA

A trie is a tree data structure that stores a set or an associative array (i.e., a map). Since a picture is worth a thousand words, at 24 fps for 89 seconds this short video that explains tries is worth 2.1 million words:

Tries are a type of associative array, like hash tables. However, compared to hash tables, tries have several advantages:

  • looking up a word takes O(word length) in the worst case for tries, whereas an imperfect hash table can take up to O(number of words);
  • tries have no collisions;
  • it is possible to walk through all keys in order;
  • it is not necessary to design a hash function.

Tries do have disadvantages:

  • the naïve implementation of tries uses pointers between nodes, which reduces their cache efficiency;
  • if tries are naïvely stored on a storage device they perform much worse than hash tables;
  • a trie may require more memory than a hash table.

So, given these pros and cons, here is a simple trie implementation:

class Node:

    def __init__(self, key, value):
        self.children = {}
        self.key = key
        self.value = value


def append_word(node, word, completeword):
    if not word:
        return
    key = word[0]
    try:
        child = node.children[key]
    except KeyError:
        child = Node(key, None)
        node.children[key] = child
    if len(word) == 1:
        child.value = completeword
    else:
        append_word(child, word[1:], completeword)


def main():
    root = Node(None, None)
    with open('wlist.txt') as wlist:
        map(lambda l: append_word(root, l.strip(), l.strip()), wlist)


if __name__ == '__main__':
    main()
trie.py

This code assumes that the file wlist.txt is accessible and that it contains a dictionary of words, one per line.
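The listing above only builds the trie; checking whether a word is stored is just a walk down the children dictionaries. A minimal lookup sketch (not part of the original listing):

def contains(node, word):
    '''Returns True if word was inserted in the trie rooted at node.'''
    for char in word:
        try:
            node = node.children[char]
        except KeyError:
            # no child for this character: the word was never inserted
            return False
    # a node marks the end of a word only if it holds a value
    return node.value is not None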

For example, this is what the trie that stores all the English words starting with “archi” looks like (click for a larger version):

Trie of English words starting with "archi".

As a side note, it is possible to “collapse” long chains of prefixes. For the same set of words we would get something like this:

Trie of English words starting with "archi". Common prefixes have been merged.

Let’s now use a trie to solve a common interview question problem.

T9

T9 is a predictive text technology for mobile phones with a 3×4 keypad. Its core idea is simple: users press each key only once for each letter of the word they want to type. For example, to write “arching” a user would tap “2724464”. Mapping that sequence back to candidate words is exactly the kind of prefix lookup a trie excels at! Compared with the previous implementation, we need to change two things:

  • nodes hold numbers rather than letters;
  • a node is associated with a list of values, because the same sequence of numbers can generate different words.

Here’s the code:

#!/usr/bin/env python

from string import maketrans

PHONE_LETTERS = 'abcdefghijklmnopqrstuvwxyz'
PHONE_NUMBERS = '22233344455566677778889999'
# mapping to translate a-z characters to phone digits
PHONE_TRANS = maketrans(PHONE_LETTERS, PHONE_NUMBERS)

class Node:

    def __init__(self, key):
        self.children = {}
        self.key = key
        self.values = []


def append_word(node, sequence, completeword):
    if not sequence:
        return
    key = sequence[0]
    try:
        child = node.children[key]
    except KeyError:
        child = Node(key)
        node.children[key] = child
    if len(sequence) == 1:
        child.values.append(completeword)
    else:
        append_word(child, sequence[1:], completeword)


def main():
    root = Node(None)
    with open('wlist.txt') as wlist:
        # magic! str.translate uses PHONE_TRANS to translate mappings
        map(lambda l: append_word(root, l.strip().translate(PHONE_TRANS), l.strip()), wlist)


if __name__ == '__main__':
    main()

The most prominent changes are that we are using str.translate to map a-z characters to phone digits and that each node has a list of associated values (node.values).

Here is what this “trie” looks like:

Trie of words starting with "archi", with phone keypad numbers in place of letters.

The algorithm to implement T9 is now straightforward: given a sequence of numbers, iterate over them and use each one to select the next trie node. When we run out of numbers, it means we have found the longest valid prefix in the trie. We then run a depth-first search from the last node we reached, collecting words from each node we explore.

Here is the code that does it:

#!/usr/bin/env python

import fileinput
import string

PHONE_LETTERS = 'abcdefghijklmnopqrstuvwxyz'
PHONE_NUMBERS = '22233344455566677778889999'
PHONE_TRANS = string.maketrans(PHONE_LETTERS, PHONE_NUMBERS)

class Node:

    def __init__(self, key):
        self.children = {}
        self.key = key
        self.values = []


def append_word(node, sequence, completeword):
    if not sequence:
        return
    # discard words that do not match [a-z]+
    if not all(map(lambda x: x in string.ascii_lowercase, completeword)):
        print "Discarding %s: only a-z characters allowed" % completeword
        return
    key = sequence[0]
    try:
        child = node.children[key]
    except KeyError:
        child = Node(key)
        node.children[key] = child
    if len(sequence) == 1:
        child.values.append(completeword)
    else:
        append_word(child, sequence[1:], completeword)


def lookup(node, sequence=None):
    if sequence:
        # there are still numbers in the sequence: follow them in the trie
        try:
            child = node.children[sequence[0]]
            return lookup(child, sequence[1:])
        except KeyError:
            return []
    else:
        # the sequence is empty: explore the trie using a DFS
        result = node.values[:]
        for child in node.children.values():
            result.extend(lookup(child))
        return result


def main():
    root = Node(None)
    with open('wlist.txt') as wlist:
        map(lambda l: append_word(root, l.strip().translate(PHONE_TRANS), l.strip()), wlist)
    sequence = ""
    while True:
        print "Phone keys:"
        try:
            sequence = raw_input()
        except EOFError:
            break
        if not sequence:
            break
        words = sorted(lookup(root, sequence))
        print "Words: %s" % words


if __name__ == '__main__':
    main()

This code assumes that there is a wlist.txt file available. Several English wordlists can be downloaded here: http://www.keithv.com/software/wlist/.

A sample output of this simple application is:

Phone keys:
53763
Words: ['jerod', 'jeroen', 'kernel', 'kernels', 'kernen', 'kerner', 'kernersville', 'kernes', 'kerney', 'kesner', 'lerner']

Other applications

Tries are ubiquitous in all problems where prefix matching is useful. A few examples are:

  • Google autocomplete: tries are augmented with word popularity;
  • spell checkers: each reasonable misspelling of a word (due to the insertion, deletion, or substitution of one or more characters) is linked to the correct spelling (see http://norvig.com/spell-correct.html for an overview of how simple spell correction works);
  • firewalls often store IP ranges associated with a policy (e.g., drop packet, forward packet, accept packet, etc.): IP ranges can be efficiently stored and searched using a trie (for example, see this paper);
  • tries are also used in bioinformatics for overlap detection in fragment assembly (ref.);
  • burstsort, one of the fastest algorithms for sorting large collections of strings, is based on tries (ref.)

Whenever you are dealing with a problem where prefixes are important, tries might be the right tool.

Visualizing the dependencies of installed packages on Debian

Packages installed on a Linux system can depend on other packages; for example, an application depends on the shared libraries it uses. This means that the “depends on” relationship defines a graph over packages.

Visualizing the “depends on” relationship lets us understand at a glance which packages are core and why each package was installed, so let’s explore the dependencies of the packages installed on a Debian-based Linux system.

Dependencies of a package can be listed using apt-cache depends <packagename>. For example, the dependencies of vim are:

vim
  Depends: vim-common
  Depends: vim-runtime
  Depends: libacl1
  Depends: libc6
  Depends: libgpm2
  Depends: libpython2.7
  Depends: libselinux1
  Depends: libtinfo5
  Suggests: <ctags>
    exuberant-ctags:i386
    exuberant-ctags
  Suggests: vim-doc
  Suggests: vim-scripts
  Conflicts: vim:i386

So what we want to do is:

  • list all installed packages;
  • list dependencies for each installed package;
  • put dependencies in some format that allows easy visualization.

Listing installed packages

Listing installed packages is easily done with dpkg:

dpkg --get-selections | grep -v 'deinstall' | cut -f1 | cut -d: -f1

Let’s explain this command:

  1. dpkg --get-selections: list all packages known to dpkg, together with their selection state;
  2. grep -v 'deinstall': remove all lines that contain the word “deinstall”, so that only packages whose selection state is “install” are kept;
  3. cut -f1 | cut -d: -f1: keep only the first (tab-separated) field, i.e. the package name, and strip any architecture qualifier such as :i386.

Now we want to get the list of dependencies for each of these packages.

Listing dependencies

apt-cache makes it easy to list the dependencies of a package, given its name. When parsing its output, we should keep in mind that alternative dependencies start with a pipe (“|”) and virtual packages are shown within angle brackets (<>).

Assuming that the variable $pkgname contains the name of a valid package, we can use the following command to list the packages it depends on:

apt-cache depends "$pkgname" | grep 'Depends: ' | grep -v '<' | cut -d':' -f2- | tr -d ' '

What it does is:

  1. apt-cache depends "$pkgname": list all dependencies for the $pkgname package;
  2. grep 'Depends: ': keep only lines containing the word "Depends" (this matches both "Depends" and "PreDepends");
  3. grep -v "<": remove references to virtual packages;
  4. cut -d':' -f2-: keep only package names;
  5. tr -d ' ': trim possible whitespaces.

Let's put it all together in a shell script:

#!/bin/sh

tempf=$(tempfile)
dpkg --get-selections | grep -v 'deinstall' | cut -f1 | cut -d: -f1 > "$tempf"
while read pkgname
do
    apt-cache depends "$pkgname" | grep 'Depends: ' | grep -v '<' | cut -d':' -f2- | tr -d ' ' | ( while read dependsf; do
        echo "$pkgname depends on $dependsf"
    done)
done < "$tempf"
rm "$tempf"

Visualize dependencies

I chose Gephi to visualize the dependency graph.

A possible alternative would be dot from the GraphViz suite. However, on my desktop machine there are about 2,500 installed packages and more than 13,000 dependencies: this graph is big enough to make it very hard for dot to find a reasonable layout for all the nodes and output an image that can actually be viewed.

Gephi supports several graph definition formats, including GEXF. GEXF is an XML-based format and it is very simple to adapt our script so that it outputs dependencies in the GEXF format:

#!/bin/sh

tempf=$(tempfile)
edgesf=$(tempfile)
nodesf=$(tempfile)

# write list of installed packages in $tempf file
dpkg --get-selections | grep -v 'deinstall' | cut -f1 | cut -d: -f1 > "$tempf"

# output GEXF header 
echo '<?xml version="1.0" encoding="UTF-8"?>'
echo '<gexf xmlns="http://www.gexf.net/1.2draft" version="1.2">'
echo '    <graph mode="static" defaultedgetype="directed">'

# read all packages from $tempf
while read pkgname
do
    # create a node declaration and write it to $nodesf
    echo "            <node id=\"$pkgname\" label=\"$pkgname\"/>" >> "$nodesf"

    # read all dependencies, create node definitions for each package and add an edge
    apt-cache depends "$pkgname" | grep 'Depends: ' | grep -v '<' | cut -d':' -f2- | tr -d ' ' | ( while read dependsfrom ; do
        echo "            <node id=\"$dependsfrom\" label=\"$dependsfrom\"/>" >> "$nodesf"
        echo "            <edge id=\"$dependsfrom-$pkgname\" source=\"$dependsfrom\" target=\"$pkgname\" />" >> "$edgesf"
    done)

done < "$tempf"
# create the <nodes> XML node. Use uniq to remove duplicated nodes
echo "        <nodes>"
sort "$nodesf" | uniq
echo "        </nodes>"

# create the <edges> XML node. Use uniq to remove duplicated edges
echo "        <edges>"
sort "$edgesf" | uniq
echo "        </edges>"
echo "    </graph>"
echo "</gexf>"

# remove temporary files
rm "$tempf"
rm "$edgesf"
rm "$nodesf"

Example

I ran this script on a server running a barebones installation of Debian Wheezy. I loaded the resulting GEXF file in Gephi and tweaked the controls a bit:

Screenshot of Gephi showing package dependencies of a Debian wheezy server.

In particular, I ran the "ForceAtlas 2" layout algorithm to neatly position all the nodes (bottom left corner in the screenshot), I filtered out the nodes with degree 0, such as font packages (bottom right corner), and I ranked nodes by their outdegree: nodes with a higher outdegree are redder.

A few more tweaks in the Preview section of Gephi produce something like this:

Graph of Debian wheezy base installation package dependencies.

The red node at the center of the graph is libc6, which, unsurprisingly, many packages depend on. The other two red nodes on the top left are debconf and dpkg: due to the very small installation size, they are relatively strongly connected to the other packages. The two nodes just below libc6 are zlib1g and multiarch-support. Finally, the node on the bottom right is python.

Here is the same graph for a Kubuntu installation:

Package dependency graph of a laptop running Kubuntu.

Gephi also provides metrics about the graphs it loads. For example, the average degree of this graph is 5.3, its diameter is 12 (i.e., the longest shortest path between any two connected nodes has 12 edges), the average path length is 3.7, and so on.

The Gephi team has collected several datasets to explore their software, available here.