Friday, October 24, 2008

Phrase Spelling Corrector using Word Collocation Probabilities

Spelling correction is one of those things that people don't notice when it works well. In web-based search applications, it usually shows up as a small "Did you mean: xxx?" prompt that appears only when the application does not recognize the term being queried for. But even though it appears only occasionally, users do notice when the suggestion is wrong.

There are various approaches to spelling correction. One popular approach is to use a Lucene index of character n-grams for the terms in the index; I have written previously about my implementation of this approach.

Another popular approach is to compute edit costs and return the dictionary words that are within a predefined edit cost of the misspelt word. This is the approach used by GNU Aspell and its Java cousin Jazzy, which we use here.

Both these approaches work very well for single words, so they are well suited to applications such as word processors, where you need to flag misspelt words and suggest alternatives. In a typical search page, however, a user can type in a multi-word phrase with one or more words misspelt. The job of the spelling corrector in this case is to tie the best suggestions together so that the corrected phrase makes sense as a whole. That is a much harder problem, as you will no doubt agree.

Various approaches to this problem have been suggested and tried. I came across one such suggestion almost by accident on the Aspell TODO list, which set me thinking about the whole thing again.

Thinking about this suggestion a bit, I realized that a much simpler way would be to compute conditional probabilities between consecutive words in the phrase, and then consider the "best" suggestion to be the one that connects the words via the most probable path, i.e. the path with the highest sum of conditional probabilities. This effectively boils down to a graph theory problem: computing the shortest path in a weighted directed graph. This post describes an implementation of this idea.

Consider Knud Sorensen's example from the Aspell TODO list. Two misspelt phrases and their corrected forms are shown below. As you can see, the correct form of the misspelt word 'fone' depends on the other words in the phrase.

    a fone number => a phone number
    a fone dress  => a fine dress

The list below shows the suggestions returned by Jazzy for the misspelt word 'fone', ordered by cost, i.e. the first suggestion is the one with the lowest edit cost from the misspelt word. Notice that neither 'phone' nor 'fine' is the first result. The Java code for the CLI that I built for quickly looking up suggestions is shown later in this post.

sujit@sirocco:~$ mvn -o exec:java \
  -Dexec.mainClass=com.mycompany.myapp.Shell \
  -Dexec.args=true
jazzy> fone
foe, one, fine, bone, zone, fore, lone, fond, font, hone, cone, gone, 
none, done, tone, fen, foes, on, fan, fin, fun, money, phone, son, fee, 
for, fog, fox, fined, found, fount, fence, fines, finer, honey, non, 
don, ton, ion, yon, vane, vine, June, gene, mane, mine, mono, bane, 
bony, pane, pine, pony, sane, sine, fade, fate, food, foot, vote, face, 
foci, fuse, fare, fire, four, free, fame, foam, fume, file, flee, foil, 
fool, foul, fowl, fake, lane, line, fife, five, fogs, find, fund, fans, 
fins, fang, cane, nine, dine, dune, tune, wane, wine
jazzy> \q

The approach I propose is to construct a graph for our input phrase ('a fone book'), with a vertex for each original word and for each of its spelling suggestions, as shown below. The edge weights represent the conditional probability of the edge target B following the edge source A, written P(B|A). The numbers are all cooked up for this example, but I describe a way to compute them further down. I still need to populate my database tables with real data; I will describe this in a subsequent post.

What you will immediately notice is that we cannot prune the graph as we encounter each word in the phrase, i.e. we cannot simply pick the most likely suggestion for each word as we parse it, since the "best path" is the most probable path through the whole graph, from the start vertex to the finish vertex.

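Since we are going to use Dijkstra's shortest path algorithm to find the shortest path (aka graph geodesic) through the graph, we need to convert the edge probabilities into edge weights w_{A,B}, like so:

  w_{A,B} = 1 - P(B|A)
  where:
    w_{A,B} = cost to get from vertex A to vertex B
    P(B|A) = probability of the occurrence of B given A

These weights lie between 0 and 1, which matters because Dijkstra's algorithm requires non-negative edge weights. And since every path from START to END passes through exactly one vertex per word, all paths have the same number of edges, so the path with the lowest total weight is also the path with the highest sum of conditional probabilities.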

The probability P(B|A) can be computed as the number of times A and B co-occur in our dataset divided by the number of times word A occurs in the dataset, as shown below:

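  By the definition of conditional probability:
    P(B ∩ A) = P(B|A) * P(A)
  so:
    P(B|A) = P(B ∩ A) / P(A)
           = N(B ∩ A) / N(A)
  where N(X) is the number of occurrences of X in the dataset (both
  probabilities are taken over the same total count, which cancels out).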
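To make the weight computation concrete, here is a minimal sketch that computes P(B|A) and w_{A,B} from in-memory count maps standing in for the occur_a and occur_ab tables. The class and method names (EdgeWeights, setCount, setPairCount) are hypothetical, and the "a|b" pair key is an assumption chosen to mirror the alphabetized occur_ab lookup in the SpellingCorrector code later in this post.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Sketch (hypothetical names): edge weights w(A,B) = 1 - N(A and B) / N(A),
 * computed from in-memory maps instead of the occur_a / occur_ab tables.
 */
public class EdgeWeights {
  // N(A): word -> number of occurrences
  private final Map<String, Integer> occurA = new HashMap<>();
  // N(A and B): alphabetized "a|b" key -> number of co-occurrences
  private final Map<String, Integer> occurAB = new HashMap<>();

  public void setCount(String a, int n) {
    occurA.put(a, n);
  }

  public void setPairCount(String a, String b, int n) {
    occurAB.put(pairKey(a, b), n);
  }

  /** Alphabetize the pair so (phone, book) and (book, phone) share a key. */
  static String pairKey(String a, String b) {
    return a.compareTo(b) <= 0 ? a + "|" + b : b + "|" + a;
  }

  /** P(B|A) = N(A and B) / N(A); 0 when A has never been seen. */
  public double conditionalProbability(String a, String b) {
    int na = occurA.getOrDefault(a, 0);
    if (na == 0) {
      return 0.0;
    }
    return occurAB.getOrDefault(pairKey(a, b), 0) / (double) na;
  }

  /** w(A,B) = 1 - P(B|A); in [0, 1] as long as N(A and B) <= N(A). */
  public double weight(String a, String b) {
    return 1.0 - conditionalProbability(a, b);
  }

  public static void main(String[] args) {
    EdgeWeights w = new EdgeWeights();
    w.setCount("a", 100);
    w.setCount("phone", 18);
    w.setPairCount("a", "phone", 13);
    w.setPairCount("book", "phone", 12);
    System.out.println(w.weight("a", "phone"));    // 1 - 13/100
    System.out.println(w.weight("phone", "book")); // 1 - 12/18
    System.out.println(w.weight("fone", "book"));  // unseen A, so weight 1
  }
}
```

Treating an unseen A as weight 1 (rather than dividing by zero) matches how the corrector code below guards the N(A) == 0 case.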

To get experimental values for N(A) and N(B ∩ A), we will need to extract the search terms that users actually typed in from our Apache access logs, and use them to populate the following tables:

mysql> desc occur_a;
+---------+---------------+------+-----+---------+-------+
| Field   | Type          | Null | Key | Default | Extra |
+---------+---------------+------+-----+---------+-------+
| word    | varchar(32)   | NO   | PRI | NULL    |       | 
| n_words | mediumint(11) | NO   |     | NULL    |       | 
+---------+---------------+------+-----+---------+-------+

mysql> desc occur_ab;
+---------+---------------+------+-----+---------+-------+
| Field   | Type          | Null | Key | Default | Extra |
+---------+---------------+------+-----+---------+-------+
| word_a  | varchar(32)   | NO   | PRI | NULL    |       | 
| word_b  | varchar(32)   | NO   | PRI | NULL    |       | 
| n_words | mediumint(11) | NO   |     | NULL    |       | 
+---------+---------------+------+-----+---------+-------+
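The counts in these tables could be derived along the following lines. This is only a sketch under stated assumptions: it takes query strings that have already been pulled out of the logs (that extraction is not shown), counts every word towards N(A), and counts each pair of consecutive words towards N(B ∩ A), storing the pair in alphabetical order so that it matches the word_a/word_b convention the lookup code relies on. CollocationCounter and its fields are hypothetical names.

```java
import java.util.HashMap;
import java.util.Map;

/** Sketch (hypothetical): collect occur_a / occur_ab counts from queries. */
public class CollocationCounter {
  // N(A): word -> number of occurrences (rows for occur_a)
  final Map<String, Integer> occurA = new HashMap<>();
  // N(A and B): alphabetized "a|b" -> times a and b were adjacent (occur_ab)
  final Map<String, Integer> occurAB = new HashMap<>();

  static String pairKey(String a, String b) {
    return a.compareTo(b) <= 0 ? a + "|" + b : b + "|" + a;
  }

  public void addQuery(String query) {
    // same kind of tokenization as the corrector: whitespace or hyphens
    String[] words = query.toLowerCase().trim().split("[\\s-]+");
    for (int i = 0; i < words.length; i++) {
      if (words[i].isEmpty()) {
        continue; // e.g. an empty query or a leading hyphen
      }
      occurA.merge(words[i], 1, Integer::sum);
      if (i > 0 && !words[i - 1].isEmpty()) {
        occurAB.merge(pairKey(words[i - 1], words[i]), 1, Integer::sum);
      }
    }
  }

  public static void main(String[] args) {
    CollocationCounter counter = new CollocationCounter();
    for (String q : new String[] {"a phone book", "phone book", "a fine dress"}) {
      counter.addQuery(q);
    }
    System.out.println(counter.occurA);  // candidate rows for occur_a
    System.out.println(counter.occurAB); // candidate rows for occur_ab
  }
}
```

Only adjacent words are counted here, since the graph only ever connects consecutive words of the phrase.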

Without any data in the database tables, the code degrades gracefully: it just returns what we typed in, as you can see below. With no data, every edge gets a weight of 1, and since we always insert the original word at the first position of the suggestion list returned by Jazzy, the "best" among these equal-cost paths is the one that comes first (this relies on how JGraphT's Dijkstra implementation breaks ties). As before, the Java code for this CLI is provided later in the article.

sujit@sirocco:~$ mvn -o exec:java \
  -Dexec.mainClass=com.mycompany.myapp.Shell
spell-check> a fone book
a fone book
spell-check> a fone dress
a fone dress

Once some (still cooked up) occurrence data is entered manually into these tables...

mysql> select * from occur_a;
+-------+---------+
| word  | n_words |
+-------+---------+
| a     |     100 | 
| book  |      43 | 
| dress |      10 | 
| fine  |      12 | 
| phone |      18 | 
+-------+---------+
5 rows in set (0.00 sec)

mysql> select * from occur_ab;
+--------+--------+---------+
| word_a | word_b | n_words |
+--------+--------+---------+
| a      | fine   |       8 | 
| a      | phone  |      13 | 
| book   | phone  |      12 | 
| dress  | fine   |       7 | 
+--------+--------+---------+
4 rows in set (0.00 sec)

...our spelling corrector behaves more intelligently. The beauty of this approach is that its intelligence can be tailored to your industry. For example, if you were in the clothing business, your search terms are more likely to include "fine dress" than "phone book", and therefore P(dress|fine) would be higher than P(dress|phone).

sujit@sirocco:~$ mvn -o exec:java \
  -Dexec.mainClass=com.mycompany.myapp.Shell
spell-check> a fone book
a phone book
spell-check> a fone dress
a fine dress
spell-check> fone book
phone book
spell-check> fone dress
fine dress
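To see why the cooked-up numbers above produce these corrections, here is a self-contained sketch that runs the same computation without MySQL, Jazzy or JGraphT. Assumptions: the counts live in in-memory maps, the candidate list for 'fone' is a hand-picked subset of Jazzy's suggestions, and instead of Dijkstra it uses a left-to-right dynamic programming pass (Viterbi-style). Because every edge in this graph goes from one word position to the next, the graph is a layered DAG and this pass finds the same minimum total weight that Dijkstra would, although tie-breaking may differ. LatticeCorrector is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Sketch (hypothetical): min-weight path through a layered word lattice. */
public class LatticeCorrector {
  private final Map<String, Integer> occurA;  // N(A)
  private final Map<String, Integer> occurAB; // N(A and B), "a|b" keys
  private final long total;                   // sum of N(A)

  LatticeCorrector(Map<String, Integer> occurA, Map<String, Integer> occurAB) {
    this.occurA = occurA;
    this.occurAB = occurAB;
    long sum = 0;
    for (int n : occurA.values()) {
      sum += n;
    }
    this.total = Math.max(sum, 1); // same divide-by-zero guard as the original
  }

  private static String key(String a, String b) {
    return a.compareTo(b) <= 0 ? a + "|" + b : b + "|" + a;
  }

  /** START -> first word: 1 - P(B). */
  double startWeight(String b) {
    return 1.0 - occurA.getOrDefault(b, 0) / (double) total;
  }

  /** A -> B: 1 - N(A and B) / N(A), or 1.0 if A was never seen. */
  double weight(String a, String b) {
    int na = occurA.getOrDefault(a, 0);
    if (na == 0) {
      return 1.0;
    }
    return 1.0 - occurAB.getOrDefault(key(a, b), 0) / (double) na;
  }

  /** candidates.get(i): original word i first, then its suggestions (non-empty). */
  List<String> correct(List<List<String>> candidates) {
    int n = candidates.size();
    double[][] cost = new double[n][];
    int[][] back = new int[n][];
    for (int i = 0; i < n; i++) {
      List<String> layer = candidates.get(i);
      cost[i] = new double[layer.size()];
      back[i] = new int[layer.size()];
      for (int j = 0; j < layer.size(); j++) {
        if (i == 0) {
          cost[i][j] = startWeight(layer.get(j));
          continue;
        }
        cost[i][j] = Double.POSITIVE_INFINITY;
        List<String> prev = candidates.get(i - 1);
        for (int k = 0; k < prev.size(); k++) {
          double c = cost[i - 1][k] + weight(prev.get(k), layer.get(j));
          if (c < cost[i][j]) { // strict '<' keeps the earliest candidate on ties
            cost[i][j] = c;
            back[i][j] = k;
          }
        }
      }
    }
    // every last-layer vertex reaches END with weight 1.0, so take the cheapest
    int best = 0;
    for (int j = 1; j < cost[n - 1].length; j++) {
      if (cost[n - 1][j] < cost[n - 1][best]) {
        best = j;
      }
    }
    List<String> result = new ArrayList<>();
    for (int i = n - 1; i >= 0; i--) {
      result.add(candidates.get(i).get(best));
      best = back[i][best];
    }
    Collections.reverse(result);
    return result;
  }

  public static void main(String[] args) {
    Map<String, Integer> a = new HashMap<>();
    a.put("a", 100); a.put("book", 43); a.put("dress", 10);
    a.put("fine", 12); a.put("phone", 18);
    Map<String, Integer> ab = new HashMap<>();
    ab.put(key("a", "fine"), 8); ab.put(key("a", "phone"), 13);
    ab.put(key("book", "phone"), 12); ab.put(key("dress", "fine"), 7);
    LatticeCorrector lc = new LatticeCorrector(a, ab);
    List<String> fone = Arrays.asList("fone", "phone", "fine");
    System.out.println(lc.correct(Arrays.asList(
        Arrays.asList("a"), fone, Arrays.asList("book"))));  // [a, phone, book]
    System.out.println(lc.correct(Arrays.asList(
        Arrays.asList("a"), fone, Arrays.asList("dress")))); // [a, fine, dress]
  }
}
```

With empty count maps every edge weighs 1.0, so the strict `<` keeps the first candidate at each position and the sketch returns the input unchanged, which mirrors the graceful degradation described earlier.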

Here is the code for the actual Spelling corrector. It uses Jazzy for its word suggestions, and JGraphT to construct a graph and run Dijkstra's shortest path algorithm (included in the JGraphT library) to find the most likely path based on word co-occurrence probabilities.

// Source: src/main/java/com/mycompany/myapp/SpellingCorrector.java
package com.mycompany.myapp;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import javax.sql.DataSource;

import org.apache.commons.lang.StringUtils;
import org.jgrapht.alg.DijkstraShortestPath;
import org.jgrapht.graph.ClassBasedEdgeFactory;
import org.jgrapht.graph.DefaultWeightedEdge;
import org.jgrapht.graph.SimpleDirectedWeightedGraph;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.datasource.DriverManagerDataSource;

import com.swabunga.spell.engine.SpellDictionary;
import com.swabunga.spell.engine.SpellDictionaryHashMap;
import com.swabunga.spell.engine.Word;

/**
 * Uses probability of word-collocations to determine best phrases to be
 * returned from a SpellingCorrector for multi-word misspelt queries.
 */
public class SpellingCorrector {

  private static final int SCORE_THRESHOLD = 200;
  private static final String DICTIONARY_FILENAME = 
    "src/main/resources/english.0";
  
  private long occurASumWords = 1L;
  private JdbcTemplate jdbcTemplate;
  
  @SuppressWarnings("unchecked")
  public String getSuggestion(String input) throws Exception {
    // initialize Jazzy spelling dictionary
    SpellDictionary dictionary = new SpellDictionaryHashMap(
      new File(DICTIONARY_FILENAME));
    // initialize database connection
    DataSource dataSource = new DriverManagerDataSource(
      "com.mysql.jdbc.Driver", "jdbc:mysql://localhost:3306/spelldb", 
      "foo", "secret");
    jdbcTemplate = new JdbcTemplate(dataSource);
    occurASumWords = jdbcTemplate.queryForLong(
      "select sum(n_words) from occur_a");
    if (occurASumWords == 0L) {
      // just a hack to prevent divide by zero for empty db
      occurASumWords = 1L;
    }
    // set up graph and create root vertex
    final SimpleDirectedWeightedGraph<SuggestedWord,DefaultWeightedEdge> g = 
      new SimpleDirectedWeightedGraph<SuggestedWord,DefaultWeightedEdge>(
      new ClassBasedEdgeFactory<SuggestedWord,DefaultWeightedEdge>(
      DefaultWeightedEdge.class));
    SuggestedWord startVertex = new SuggestedWord("START", 0);
    g.addVertex(startVertex);
    // set up variables to hold results of previous iteration
    List<SuggestedWord> prevVertices = 
      new ArrayList<SuggestedWord>();
    List<SuggestedWord> currentVertices = 
      new ArrayList<SuggestedWord>();
    int tokenId = 1;
    prevVertices.add(startVertex);
    // parse the string
    // split on runs of whitespace or hyphens, so extra spaces don't
    // produce empty tokens
    String[] tokens = input.toLowerCase().trim().split("[\\s-]+");
    for (String token : tokens) {
      // build up spelling suggestions for individual word
      List<String> possibleTokens = new ArrayList<String>();
      if (token.trim().length() <= 2) {
        // people rarely misspell words of 2 characters or fewer,
        // so just pass it back unchanged
        possibleTokens.add(token);
      } else if (dictionary.isCorrect(token)) {
        // no need to find suggestions, token is recognized as valid spelling
        possibleTokens.add(token);
      } else {
        possibleTokens.add(token);
        List<Word> words = 
          dictionary.getSuggestions(token, SCORE_THRESHOLD);
        for (Word word : words) {
          possibleTokens.add(word.getWord());
        }
      }
      // populate the graph with these values
      for (String possibleToken : possibleTokens) {
        SuggestedWord currentVertex = 
          new SuggestedWord(possibleToken, tokenId); 
        g.addVertex(currentVertex);
        currentVertices.add(currentVertex);
        for (SuggestedWord prevVertex : prevVertices) {
          DefaultWeightedEdge edge = new DefaultWeightedEdge();
          double weight = computeEdgeWeight(
            prevVertex.token, currentVertex.token);
          g.setEdgeWeight(edge, weight);
          g.addEdge(prevVertex, currentVertex, edge);
        }
      }
      prevVertices.clear();
      prevVertices.addAll(currentVertices);
      currentVertices.clear();
      tokenId++;
    } // for token : tokens
    // finally set the end vertex
    SuggestedWord endVertex = new SuggestedWord("END", tokenId);
    g.addVertex(endVertex);
    for (SuggestedWord prevVertex : prevVertices) {
      // add the edge to the graph first, then set its weight
      DefaultWeightedEdge edge = g.addEdge(prevVertex, endVertex);
      g.setEdgeWeight(edge, 1.0D);
    }
    // find shortest path between START and END
    DijkstraShortestPath<SuggestedWord,DefaultWeightedEdge> dijkstra =
      new DijkstraShortestPath<SuggestedWord, DefaultWeightedEdge>(
      g, startVertex, endVertex);
    List<DefaultWeightedEdge> edges = dijkstra.getPathEdgeList();
    List<String> bestMatch = new ArrayList<String>();
    for (DefaultWeightedEdge edge : edges) {
      if (startVertex.equals(g.getEdgeSource(edge))) {
        // skip the START vertex
        continue;
      }
      bestMatch.add(g.getEdgeSource(edge).token);
    }
    return StringUtils.join(bestMatch.iterator(), " ");
  }

  private Double computeEdgeWeight(String prevToken, String currentToken) {
    if (prevToken.equals("START")) {
      // this is the first word, return 1-P(B)
      try {
        double nb = (Double) jdbcTemplate.queryForObject(
          "select n_words/? from occur_a where word = ?", 
          new Object[] {occurASumWords, currentToken}, Double.class);
        return 1.0D - nb;
      } catch (IncorrectResultSizeDataAccessException e) {
        // in case there is no match, then we should return weight of 1
        return 1.0D;
      }
    }
    double na = 0.0D;
    try {
      na = (Double) jdbcTemplate.queryForObject(
        "select n_words from occur_a where word = ?", 
        new String[] {prevToken}, Double.class);
    } catch (IncorrectResultSizeDataAccessException e) {
      // no match, should be 0
      na = 0.0D;
    }
    if (na == 0.0D) {
      // if N(A) == 0, A does not exist, and hence N(A ^ B) == 0 too,
      // so we guard against a DivideByZero and an additional useless
      // computation.
      return 1.0D;
    }
    // for the A^B lookup, alphabetize so A is lexically ahead of B
    // since that is the way we store it in the database
    String[] tokens = new String[] {prevToken, currentToken};
    Arrays.sort(tokens); // alphabetize before lookup
    double nba = 0.0D;
    try {
      nba = (Double) jdbcTemplate.queryForObject(
        "select n_words from occur_ab where word_a = ? and word_b = ?",
        tokens, Double.class);
    } catch (IncorrectResultSizeDataAccessException e) {
      // no result found so N(B^A) = 0
      nba = 0.0D;
    }
    return 1.0D - (nba / na);
  }

  /**
   * Holder for the graph vertex information.
   */
  private class SuggestedWord {
    public String token;
    public int id;
    
    public SuggestedWord(String token, int id) {
      this.token = token;
      this.id = id;
    }
    
    @Override
    public int hashCode() {
      return toString().hashCode();
    }
    
    @Override
    public boolean equals(Object obj) {
      if (!(obj instanceof SuggestedWord)) {
        return false;
      }
      SuggestedWord that = (SuggestedWord) obj;
      return (this.id == that.id && 
        this.token.equals(that.token));
    }
    
    @Override
    public String toString() {
      return id + ":" + token;
    }
  }
}

The CLI proved to be very useful for checking out assumptions quickly when I was developing the algorithm. It's quite simple; it just wraps the functionality within a JLine ConsoleReader. I included it here for completeness and to illustrate how easy it is to build. Depending on whether a command line argument is passed, it works either as an interface to the Jazzy dictionary or to the phrase spelling corrector described in this post.

// Source: src/main/java/com/mycompany/myapp/Shell.java
package com.mycompany.myapp;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import jline.ConsoleReader;

import org.apache.commons.lang.StringUtils;

import com.swabunga.spell.engine.SpellDictionary;
import com.swabunga.spell.engine.SpellDictionaryHashMap;
import com.swabunga.spell.engine.Word;

public class Shell {

  private static final int SPELL_CHECK_THRESHOLD = 250;

  public Shell() throws Exception {
    ConsoleReader reader = new ConsoleReader(
      System.in, new PrintWriter(System.out));
    SpellingCorrector spellingCorrector = new SpellingCorrector();
    for (;;) {
      String line = reader.readLine("spell-check> ");
      // readLine() returns null at end of input (e.g. Ctrl-D)
      if (line == null || "\\q".equals(line)) {
        break;
      }
      System.out.println(spellingCorrector.getSuggestion(line));
    }
  }

  // === this is really for exploratory testing purposes ===
  
  /**
   * Wrapper over Jazzy native spell checking functionality.
   * @param b always true (only used to distinguish this from the no-arg ctor).
   * @throws Exception if one is thrown.
   */
  public Shell(boolean b) throws Exception {
    ConsoleReader reader = new ConsoleReader(
      System.in, new PrintWriter(System.out));
    SpellDictionary dictionary = new SpellDictionaryHashMap(
      new File("src/main/resources/english.0"));
    for (;;) {
      String line = reader.readLine("jazzy> ");
      // readLine() returns null at end of input (e.g. Ctrl-D)
      if (line == null || "\\q".equals(line)) {
        break;
      }
      String suggestions = suggest(dictionary, line);
      System.out.println(suggestions);
    }
  }
  
  /**
   * Looks up single words from Jazzy's English dictionary.
   * @param dictionary the dictionary object to look up.
   * @param incorrect the suspected misspelt word.
   * @return the word itself if Jazzy's dictionary considers it correct,
   * otherwise a comma-separated list of possible corrections ordered by
   * edit cost, or the string "(no suggestions)" if no possible
   * corrections were found.
   */
  @SuppressWarnings("unchecked")
  private String suggest(SpellDictionary dictionary, String incorrect) {
    if (dictionary.isCorrect(incorrect)) {
      // return the entered word
      return incorrect;
    }
    List<Word> words = dictionary.getSuggestions(
      incorrect, SPELL_CHECK_THRESHOLD);
    List<String> suggestions = new ArrayList<String>();
    final Map<String,Integer> costs = 
      new HashMap<String,Integer>();
    for (Word word : words) {
      costs.put(word.getWord(), word.getCost());
      suggestions.add(word.getWord());
    }
    if (suggestions.size() == 0) {
      return "(no suggestions)";
    }
    Collections.sort(suggestions, new Comparator<String>() {
      public int compare(String s1, String s2) {
        Integer cost1 = costs.get(s1);
        Integer cost2 = costs.get(s2);
        return cost1.compareTo(cost2);
      }
    });
    return StringUtils.join(suggestions.iterator(), ", ");
  }
  
  public static void main(String[] args) throws Exception {
    if (args.length == 1) {
      // word mode
      new Shell(true);
    } else {
      new Shell();
    }
  }
}

Note that I still don't know whether this works well for a large set of misspelt phrases; I need to put it through a lot more real data before I can say that with any degree of certainty. It is also fairly slow in my development testing. I have a few ideas about how to improve that, but I will try them once I have some real data to play with. As always, any suggestions/corrections are much appreciated.

Update 2009-04-26: In recent posts, I have been building on code written and described in previous posts, so there were (rightly so) quite a few requests for the code. So I've created a project on SourceForge to host it. You will find the complete source code built so far in the project's SVN repository.

Saturday, October 18, 2008

IR Math in Java : Experiments in Clustering

As I mentioned last week, I have been trying to teach myself clustering algorithms. Having used the Carrot API for some clustering work about 6 months ago, I have been curious about the clustering algorithms themselves. Carrot offers a choice of several built-in clustering algorithms, so you just pick one depending on your needs. Obviously, this presupposes that you know enough about the algorithms to make that decision (which, unfortunately, wasn't the case for me). And what better way to learn than to implement the algorithms in code? So this post covers some popular clustering algorithms implemented in Java.

The code in this post is based on algorithms from various sources. Sources are mentioned in the individual sections as well as listed in the references section below. I describe my implementations and test results for the following algorithms:

  1. K-Means Clustering
  2. Quality Threshold (QT) Clustering
  3. Simulated Annealing Clustering
  4. Nearest Neighbor Clustering
  5. Genetic Algorithm Clustering

K-Means Algorithm

For K-Means clustering, one seeds k clusters with randomly chosen documents from the collection. An estimate of k, the number of clusters to use for a collection of N documents, is given by the heuristic below:

  k = floor(sqrt(N))
  where:
    N = number of documents in collection.
    k = the estimate of the number of clusters.

The algorithm starts by seeding the clusters with one random document each from the collection. It computes the centroid μ for each cluster using the following formula:

  μ = sqrt(Σ_i x_i^2) / N
  where:
    μ = centroid of a cluster.
    N = number of documents in the cluster.
    x_i = the i-th document vector in the cluster.

For each document, we compute its similarity to the centroid of each cluster, and assign the document to the cluster whose centroid is most similar. The similarity measure used is cosine similarity; I use code from my article on similarity metrics.

At the end of this step, we have a list of clusters fully populated with the documents from the collection. We then recompute the centroids based on the documents in these clusters, and repeat the above step until the assignment of documents to clusters no longer changes.

Although common sense would suggest using the same measure, i.e. Euclidean distance, for both the similarity and the centroid calculations, this did not work in practice for me: there was no improvement in the clusters after the initial population. So I use a different measure (cosine similarity) to calculate similarity.
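The loop described above can be sketched in a self-contained way on plain double[] vectors, independently of the Cluster and DocumentCollection classes used in the code below. Two assumptions to note: the centroid here is the common component-wise mean of the member vectors rather than the formula given earlier, and the loop stops when no document changes cluster, with an iteration cap as a safety net. KMeansSketch and the sample vectors are hypothetical.

```java
import java.util.Arrays;

/** Sketch (hypothetical): k-means with cosine similarity and mean centroids. */
public class KMeansSketch {

  static double cosine(double[] x, double[] y) {
    double dot = 0, nx = 0, ny = 0;
    for (int i = 0; i < x.length; i++) {
      dot += x[i] * y[i];
      nx += x[i] * x[i];
      ny += y[i] * y[i];
    }
    return (nx == 0 || ny == 0) ? 0.0 : dot / (Math.sqrt(nx) * Math.sqrt(ny));
  }

  /** Returns the cluster index of each document; seeds are document indices. */
  static int[] cluster(double[][] docs, int[] seeds, int maxIterations) {
    int k = seeds.length;
    int dim = docs[0].length;
    double[][] centroids = new double[k][];
    for (int c = 0; c < k; c++) {
      centroids[c] = docs[seeds[c]].clone();
    }
    int[] assignments = null;
    for (int iter = 0; iter < maxIterations; iter++) {
      // assign each document to its most similar centroid (ties: lowest index)
      int[] next = new int[docs.length];
      for (int d = 0; d < docs.length; d++) {
        double best = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < k; c++) {
          double sim = cosine(docs[d], centroids[c]);
          if (sim > best) {
            best = sim;
            next[d] = c;
          }
        }
      }
      // terminate when no document changed cluster
      if (Arrays.equals(next, assignments)) {
        break;
      }
      assignments = next;
      // recompute each centroid as the mean of its member documents
      double[][] sums = new double[k][dim];
      int[] counts = new int[k];
      for (int d = 0; d < docs.length; d++) {
        counts[assignments[d]]++;
        for (int i = 0; i < dim; i++) {
          sums[assignments[d]][i] += docs[d][i];
        }
      }
      for (int c = 0; c < k; c++) {
        if (counts[c] == 0) {
          continue; // keep the old centroid for an empty cluster
        }
        for (int i = 0; i < dim; i++) {
          sums[c][i] /= counts[c];
        }
        centroids[c] = sums[c];
      }
    }
    return assignments;
  }

  public static void main(String[] args) {
    double[][] docs = { {1, 0}, {0.9, 0.1}, {0, 1}, {0.1, 0.9}, {0.8, 0.3} };
    int[] seeds = {0, 2}; // k = floor(sqrt(5)) = 2 hand-picked seed documents
    System.out.println(Arrays.toString(cluster(docs, seeds, 100))); // [0, 0, 1, 1, 0]
  }
}
```

Comparing a snapshot of the assignments (rather than the cluster objects themselves) keeps the termination check meaningful even when clusters are updated in place.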

Here is the code for my K-Means clusterer. The algorithm used is the one in the TMAP book[1].

// Source: src/main/java/com/mycompany/myapp/clustering/KMeansClusterer.java
package com.mycompany.myapp.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import Jama.Matrix;

public class KMeansClusterer {

  private final Log log = LogFactory.getLog(getClass());
  
  private String[] initialClusterAssignments = null;
  
  public void setInitialClusterAssignments(String[] documentNames) {
    this.initialClusterAssignments = documentNames;
  }
  
  public List<Cluster> cluster(DocumentCollection collection) {
    int numDocs = collection.size();
    int numClusters = 0;
    if (initialClusterAssignments == null) {
      // compute initial cluster assignments
      IdGenerator idGenerator = new IdGenerator(numDocs);
      numClusters = (int) Math.floor(Math.sqrt(numDocs));
      initialClusterAssignments = new String[numClusters];
      for (int i = 0; i < numClusters; i++) {
        int docId = idGenerator.getNextId();
        initialClusterAssignments[i] = collection.getDocumentNameAt(docId);
      }
    } else {
      numClusters = initialClusterAssignments.length;
    }

    // build initial clusters
    List<Cluster> clusters = new ArrayList<Cluster>();
    for (int i = 0; i < numClusters; i++) {
      Cluster cluster = new Cluster("C" + i);
      cluster.addDocument(initialClusterAssignments[i], 
        collection.getDocument(initialClusterAssignments[i]));
      clusters.add(cluster);
    }
    log.debug("..Initial clusters:" + clusters.toString());

    // cluster index assigned to each document in the previous pass
    int[] prevAssignments = null;

    // Repeat until termination conditions are satisfied
    for (;;) {
      // For every cluster i, (re-)compute the centroid based on the
      // current member documents. (We have moved 2.2 above 2.1 because
      // this needs to be done before every iteration).
      Matrix[] centroids = new Matrix[numClusters];
      for (int i = 0; i < numClusters; i++) {
        Matrix centroid = clusters.get(i).getCentroid();
        centroids[i] = centroid;
      }
      // For every document d, find the cluster i whose centroid is 
      // most similar, assign d to cluster i. (If a document is 
      // equally similar to all centroids, then just dump it into 
      // cluster 0).
      int[] assignments = new int[numDocs];
      for (int i = 0; i < numDocs; i++) {
        int bestCluster = 0;
        // start below any possible similarity (Double.MIN_VALUE is the
        // smallest positive double, not the most negative one)
        double maxSimilarity = Double.NEGATIVE_INFINITY;
        Matrix document = collection.getDocumentAt(i);
        String docName = collection.getDocumentNameAt(i);
        for (int j = 0; j < numClusters; j++) {
          double similarity = clusters.get(j).getSimilarity(document);
          if (similarity > maxSimilarity) {
            bestCluster = j;
            maxSimilarity = similarity;
          }
        }
        for (Cluster cluster : clusters) {
          if (cluster.getDocument(docName) != null) {
            cluster.removeDocument(docName);
          }
        }
        clusters.get(bestCluster).addDocument(docName, document);
        assignments[i] = bestCluster;
      }
      log.debug("..Intermediate clusters: " + clusters.toString());

      // Check for termination: no document changed clusters in this pass.
      // We compare snapshots of the assignments, because a list holding
      // the same Cluster objects (which are modified in place) would
      // always compare equal to the current clusters.
      if (Arrays.equals(assignments, prevAssignments)) {
        break;
      }
      prevAssignments = assignments;
    }
    // Return list of clusters
    log.debug("..Final clusters: " + clusters.toString());
    return clusters;
  }
}

The K-Means algorithm seems to be reasonably fast. However, its results are very sensitive to the initial seeding of the clusters. Here are some results I got using a random number generator to pick the seeds.

==== Results from K-Means clustering ==== (seeds: [D5,D2])
C0:[D1, D3, D4, D5, D6]
C1:[D2, D7]

==== Results from K-Means clustering ==== (seeds: [D3,D2])
C0:[D1, D3, D4, D5, D6]
C1:[D2, D7]

However, if I seed the clusters manually, using the results from my cluster visualization article last week, then I get results which look reasonable:

==== Results from K-Means clustering ==== (seeds: [D1,D3])
C0:[D1, D2, D5, D6, D7]
C1:[D3, D4]

Quality Threshold (QT) Algorithm

The Quality Threshold (QT) algorithm clusters documents using a maximum cluster diameter set by the user. The first cluster is started with the first document in the collection, and subsequent documents are added to it as long as they are close enough to stay within the specified diameter. Once all documents have been read, the documents that were added to the cluster are set aside and the algorithm is repeated recursively on the rest of the collection. The program stops when there are no more documents left. The number of levels the program recurses down to is the number of clusters formed.

The distance between a document and a cluster is computed using complete linkage distance, i.e. the distance between the document and the farthest document in the cluster.
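A minimal sketch of this greedy procedure on plain vectors, using Euclidean distance and complete linkage, looks like this. It is written iteratively rather than recursively, and it stops if a pass adds no document (with an empty cluster at distance 0, that only happens when the diameter is not positive), so it cannot loop forever. GreedyQtSketch and the sample points are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

/** Sketch (hypothetical): greedy QT-style clustering with complete linkage. */
public class GreedyQtSketch {

  static double distance(double[] x, double[] y) {
    double s = 0;
    for (int i = 0; i < x.length; i++) {
      s += (x[i] - y[i]) * (x[i] - y[i]);
    }
    return Math.sqrt(s);
  }

  /** Complete linkage: distance to the farthest member (0 for an empty cluster). */
  static double completeLinkage(List<Integer> cluster, double[][] docs, int doc) {
    double max = 0;
    for (int m : cluster) {
      max = Math.max(max, distance(docs[m], docs[doc]));
    }
    return max;
  }

  /** Each pass starts a new cluster and sweeps the unclustered docs in order. */
  static List<List<Integer>> cluster(double[][] docs, double maxDiameter) {
    boolean[] used = new boolean[docs.length];
    List<List<Integer>> clusters = new ArrayList<>();
    int remaining = docs.length;
    while (remaining > 0) {
      List<Integer> current = new ArrayList<>();
      for (int d = 0; d < docs.length; d++) {
        if (!used[d] && completeLinkage(current, docs, d) < maxDiameter) {
          current.add(d);
          used[d] = true;
          remaining--;
        }
      }
      if (current.isEmpty()) {
        break; // nothing fits: the diameter must be <= 0
      }
      clusters.add(current);
    }
    return clusters;
  }

  public static void main(String[] args) {
    double[][] docs = { {0.0}, {0.1}, {1.0}, {0.15}, {1.2} };
    System.out.println(cluster(docs, 0.4)); // [[0, 1, 3], [2, 4]]
  }
}
```

This greedy version depends on document order, since each pass grows a cluster from the first unclustered document. The QT algorithm as described on Wikipedia instead builds a candidate cluster around every remaining document and keeps the largest one, which removes the order dependence at a higher computational cost.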

Here is the code for my QT Clustering program. The algorithm used was from this Wikipedia article[2].

// Source: src/main/java/com/mycompany/myapp/clustering/QtClusterer.java
package com.mycompany.myapp.clustering;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import Jama.Matrix;

public class QtClusterer {

  private final Log log = LogFactory.getLog(getClass());
  
  private double maxDiameter;
  private boolean randomizeDocuments;
  
  public void setMaxRadius(double maxRadius) {
    this.maxDiameter = maxRadius * 2.0D;
  }
  
  public void setRandomizeDocuments(boolean randomizeDocuments) {
    this.randomizeDocuments = randomizeDocuments;
  }
  
  public List<Cluster> cluster(DocumentCollection collection) {
    if (randomizeDocuments) {
      collection.shuffle();
    }
    List<Cluster> clusters = new ArrayList<Cluster>();
    Set<String> clusteredDocNames = new HashSet<String>();
    cluster_r(collection, clusters, clusteredDocNames, 0);
    return clusters;
  }

  private void cluster_r(DocumentCollection collection, 
      List<Cluster> clusters, 
      Set<String> clusteredDocNames, int level) {
    int numDocs = collection.size();
    int numClustered = clusteredDocNames.size();
    if (numDocs == numClustered) {
      return;
    }
    Cluster cluster = new Cluster("C" + level);
    for (int i = 0; i < numDocs; i++) {
      Matrix document = collection.getDocumentAt(i);
      String docName = collection.getDocumentNameAt(i);
      if (clusteredDocNames.contains(docName)) {
        continue;
      }
      double distance = cluster.getCompleteLinkageDistance(document);
      log.debug("max dist=" + distance);
      if (distance < maxDiameter) {
        cluster.addDocument(docName, document);
        clusteredDocNames.add(docName);
      }
    }
    if (cluster.size() == 0) {
      // nothing was added, so recursing again would never terminate
      log.warn("No clusters added at level " + level + ", check diameter");
      return;
    }
    clusters.add(cluster);
    cluster_r(collection, clusters, clusteredDocNames, level + 1);
  }
}

The algorithm is easy to understand, and since no random seeding is involved, it always returns the same clusters for a given document order (i.e. when randomizeDocuments is off). Using a diameter threshold of 0.4, I got the two clusters shown below:

==== Results from Qt Clustering ==== (diameter: 0.4D)
C0:[D6, D7]
C1:[D2, D1, D4, D5, D3]

Simulated Annealing Algorithm

The Simulated Annealing clustering algorithm is based on the annealing process in metallurgy. In annealing, a metal is heated and then cooled slowly, in steps, so that its atoms can settle into a stable, low-energy structure.

The algorithm starts by setting an initial "temperature" and building an initial set of clusters. The initial clusters are populated by mod-based partitioning. Some randomness can be added by shuffling the collection first.

At each temperature setting, we exchange two random documents between two random clusters a specified number of times. After each exchange, we check whether the solution improved or degraded, based on the average radius of the clusters (a smaller average radius is better). If the solution degraded, we use the current temperature to compute the probability of accepting the degraded solution (aka a downhill move). The probability is given by:

$$P = \exp\left(\frac{S_{i-1} - S_i}{T}\right)$$

where $S_{i-1}$ is the value of the solution at loop $i-1$, $S_i$ is the value of the solution at loop $i$, and $T$ is the current temperature setting.

If the probability is higher than a specified threshold, we accept the downhill move; otherwise we go back to the previous solution. After the specified number of exchanges, we decrease the temperature (the code below divides it by 10) and go back to exchanging random documents between random clusters. We keep doing this until the temperature is no longer above a specified cutoff point.

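The probability equation above shows that, for a given increase in average radius, the acceptance probability shrinks as the temperature drops. The algorithm therefore allows more downhill moves early on, while the temperature is high, and fewer towards its end, as it cools. This allows more exploration of the solution space than the K-Means or QT clustering methods.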
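The small, self-contained Java sketch below makes the effect of cooling concrete. It evaluates the acceptance probability for the same degradation at temperatures 100, 10 and 1, and compares each result with the 0.7 cutoff used in the runs below. The class name and the radius values (1.0 rising to 1.5) are hypothetical and chosen only for illustration.

```java
import java.util.Locale;

/**
 * Hypothetical demo: how the probability of accepting a downhill move,
 * P = exp((S[i-1] - S[i]) / T), changes as the temperature T drops.
 */
public class AcceptanceProbabilityDemo {

  /**
   * Probability of accepting a move from prevRadius to newRadius. A lower
   * average radius is better, so a move that does not increase the radius
   * is always accepted.
   */
  static double acceptanceProbability(double prevRadius, double newRadius,
      double temperature) {
    if (newRadius <= prevRadius) {
      return 1.0D;
    }
    return Math.exp((prevRadius - newRadius) / temperature);
  }

  public static void main(String[] args) {
    double cutoff = 0.7D;      // downhill probability cutoff
    double prevRadius = 1.0D;  // hypothetical average radius before the move
    double newRadius = 1.5D;   // hypothetical average radius after the move
    double[] temperatures = {100.0D, 10.0D, 1.0D};
    for (double t : temperatures) {
      double p = acceptanceProbability(prevRadius, newRadius, t);
      // the clusterer rejects a downhill move when p < cutoff
      System.out.printf(Locale.US, "T=%5.1f  P=%.4f  %s%n",
          t, p, p >= cutoff ? "accept" : "reject");
    }
  }
}
```

With these values the move is accepted at T = 100 (P ≈ 0.995) and at T = 10 (P ≈ 0.951). It is rejected at T = 1 (P ≈ 0.607), because that falls below the cutoff. The clusterer's loop condition (temperature > finalTemperature) combined with a divide-by-10 schedule has a further consequence: a run from 100 down to 1 only visits T = 100 and T = 10.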

The code for my Simulated Annealing Clusterer is shown below. The algorithm comes from the TMAP book[1].

// Source: src/main/java/com/mycompany/myapp/clustering/SimulatedAnnealingClusterer.java
package com.mycompany.myapp.clustering;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class SimulatedAnnealingClusterer {

  private final Log log = LogFactory.getLog(getClass());

  private boolean randomizeDocs;
  private double initialTemperature;
  private double finalTemperature;
  private double downhillProbabilityCutoff;
  private int numberOfLoops;
  
  public void setRandomizeDocs(boolean randomizeDocs) {
    this.randomizeDocs = randomizeDocs;
  }
  
  public void setInitialTemperature(double initialTemperature) {
    this.initialTemperature = initialTemperature;
  }
  
  public void setFinalTemperature(double finalTemperature) {
    this.finalTemperature = finalTemperature;
  }
  
  public void setDownhillProbabilityCutoff(
      double downhillProbabilityCutoff) {
    this.downhillProbabilityCutoff = downhillProbabilityCutoff;
  }
  
  public void setNumberOfLoops(int numberOfLoops) {
    this.numberOfLoops = numberOfLoops;
  }
  
  public List<Cluster> cluster(DocumentCollection collection) {
    // Get initial set of clusters... 
    int numDocs = collection.size();
    int numClusters = (int) Math.floor(Math.sqrt(numDocs));
    List<Cluster> clusters = new ArrayList<Cluster>();
    for (int i = 0; i < numClusters; i++) {
      clusters.add(new Cluster("C" + i));
    }
    // ...and set initial temperature parameter T.
    double temperature = initialTemperature;
    // Randomly assign documents to the k clusters.
    if (randomizeDocs) {
      collection.shuffle();
    }
    for (int i = 0; i < numDocs; i++) {
      int targetCluster = i % numClusters;
      clusters.get(targetCluster).addDocument(
        collection.getDocumentNameAt(i),
        collection.getDocument(collection.getDocumentNameAt(i)));
    }
    log.debug("..Initial clusters: " + clusters.toString());
    // Repeat until temperature is reduced to the minimum. The solution
    // value is the average cluster radius, so lower is better.
    double previousAverageRadius = getAverageRadius(clusters);
    while (temperature > finalTemperature) {
      // Run the inner loop numberOfLoops times at this temperature.
      for (int loop = 0; loop < numberOfLoops; loop++) {
        // Find a new set of clusters by altering the membership of some
        // documents. Start by picking two clusters at random
        List<Integer> randomClusterIds = getRandomClusterIds(clusters);
        // pick two documents out of the clusters at random
        List<String> randomDocumentNames = 
          getRandomDocumentNames(collection, randomClusterIds, clusters);
        Cluster cluster1 = clusters.get(randomClusterIds.get(0));
        Cluster cluster2 = clusters.get(randomClusterIds.get(1));
        String docName1 = randomDocumentNames.get(0);
        String docName2 = randomDocumentNames.get(1);
        // exchange the two random documents among the random clusters.
        cluster1.removeDocument(docName1);
        cluster1.addDocument(docName2, collection.getDocument(docName2));
        cluster2.removeDocument(docName2);
        cluster2.addDocument(docName1, collection.getDocument(docName1));
        log.debug("..Intermediate clusters: " + clusters.toString());
        // Compare the values of the new and old set of clusters. If there
        // is an improvement, accept the new set of clusters, otherwise
        // accept it only if the downhill probability reaches the cutoff.
        double averageRadius = getAverageRadius(clusters);
        if (averageRadius > previousAverageRadius) {
          // possible downhill move, calculate the probability of it being 
          // accepted
          double probability = 
            Math.exp((previousAverageRadius - averageRadius) / temperature);
          if (probability < downhillProbabilityCutoff) {
            // Reject the downhill move by swapping the documents back. The
            // Cluster objects are modified in place, so swapping back is
            // the simplest way to restore the previous state.
            cluster1.removeDocument(docName2);
            cluster1.addDocument(docName1, collection.getDocument(docName1));
            cluster2.removeDocument(docName1);
            cluster2.addDocument(docName2, collection.getDocument(docName2));
            continue;
          }
        }
        previousAverageRadius = averageRadius;
      }
      // Reduce the temperature based on the cooling schedule.
      temperature = temperature / 10;
    }
    // Return the final set of clusters.
    return clusters;
  }

  private List<Integer> getRandomClusterIds(
      List<Cluster> clusters) {
    IdGenerator clusterIdGenerator = new IdGenerator(clusters.size());
    List<Integer> randomClusterIds = new ArrayList<Integer>();
    for (int i = 0; i < 2; i++) {
      randomClusterIds.add(clusterIdGenerator.getNextId());
    }
    return randomClusterIds;
  }

  private List<String> getRandomDocumentNames(
      DocumentCollection collection, 
      List<Integer> randomClusterIds, 
      List<Cluster> clusters) {
    List<String> randomDocumentNames = new ArrayList<String>();
    for (Integer randomClusterId : randomClusterIds) {
      Cluster randomCluster = clusters.get(randomClusterId);
      IdGenerator documentIdGenerator = 
        new IdGenerator(randomCluster.size());
      randomDocumentNames.add(
        randomCluster.getDocumentName(documentIdGenerator.getNextId()));
    }
    return randomDocumentNames;
  }

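  private double getAverageRadius(List<Cluster> clusters) {
    double score = 0.0D;
    for (Cluster cluster : clusters) {
      // recompute the centroid first: getRadius() uses the cached centroid,
      // which is stale (or null) after documents have been exchanged
      cluster.getCentroid();
      score += cluster.getRadius();
    }
    return (score / clusters.size());
  }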
}

Simulated Annealing results vary from run to run. This is expected, since the algorithm is essentially a Monte Carlo simulation. Results from three runs are shown below, with the initial and final temperatures set to 100 and 1 and the downhill probability threshold set to 0.7. One way to arrive at a good final set of clusters may be to aggregate the results of multiple runs into a single one.

==== Results from Simulated Annealing Clustering ====
C0:[D1, D3, D2, D4]
C1:[D5, D6, D7]

==== Results from Simulated Annealing Clustering ====
C0:[D3, D4, D2, D5]
C1:[D7, D1, D6]

==== Results from Simulated Annealing Clustering ====
C0:[D1, D7, D5, D6]
C1:[D4, D2, D3]
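One possible way to aggregate several runs is a co-association (consensus) count, which is not something described in the TMAP book. For every pair of documents, count how many runs placed the pair in the same cluster. Then link the pairs that were together in a majority of runs, and take the connected groups as the final clusters. The sketch below applies this idea to the three runs above. The class name and the majority threshold of 2 out of 3 runs are assumptions.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.List;

/** Hypothetical co-association consensus over several clustering runs. */
public class CoAssociationConsensus {

  static final List<String> DOCS =
      Arrays.asList("D1", "D2", "D3", "D4", "D5", "D6", "D7");

  // The three Simulated Annealing runs shown above.
  static final List<List<List<String>>> RUNS = Arrays.asList(
      Arrays.asList(Arrays.asList("D1", "D3", "D2", "D4"),
                    Arrays.asList("D5", "D6", "D7")),
      Arrays.asList(Arrays.asList("D3", "D4", "D2", "D5"),
                    Arrays.asList("D7", "D1", "D6")),
      Arrays.asList(Arrays.asList("D1", "D7", "D5", "D6"),
                    Arrays.asList("D4", "D2", "D3")));

  /** counts[i][j] = number of runs in which docs i and j share a cluster. */
  static int[][] coAssociation(List<List<List<String>>> runs,
      List<String> docs) {
    int n = docs.size();
    int[][] counts = new int[n][n];
    for (List<List<String>> run : runs) {
      for (List<String> cluster : run) {
        for (String a : cluster) {
          for (String b : cluster) {
            counts[docs.indexOf(a)][docs.indexOf(b)]++;
          }
        }
      }
    }
    return counts;
  }

  /** Connected groups of documents linked in at least minRuns runs. */
  static List<List<String>> consensus(int[][] counts, List<String> docs,
      int minRuns) {
    int n = docs.size();
    boolean[] seen = new boolean[n];
    List<List<String>> clusters = new ArrayList<>();
    for (int start = 0; start < n; start++) {
      if (seen[start]) {
        continue;
      }
      // depth-first search over the "together often enough" links
      List<String> cluster = new ArrayList<>();
      Deque<Integer> stack = new ArrayDeque<>();
      stack.push(start);
      seen[start] = true;
      while (!stack.isEmpty()) {
        int i = stack.pop();
        cluster.add(docs.get(i));
        for (int j = 0; j < n; j++) {
          if (!seen[j] && counts[i][j] >= minRuns) {
            seen[j] = true;
            stack.push(j);
          }
        }
      }
      Collections.sort(cluster);
      clusters.add(cluster);
    }
    return clusters;
  }

  public static void main(String[] args) {
    int[][] counts = coAssociation(RUNS, DOCS);
    // keep pairs that were together in at least 2 of the 3 runs
    System.out.println(consensus(counts, DOCS, 2));
  }
}
```

For these three runs it prints [[D1, D5, D6, D7], [D2, D3, D4]]. D2, D3 and D4 are together in every run, as are D6 and D7, while D1 and D5 move around. With so few runs the outcome is sensitive to the threshold, so this is an illustration rather than a recommendation.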

Nearest Neighbor Algorithm

The TMAP book classifies this algorithm as a Genetic algorithm. A later section of the book, however, describes a genetic clustering algorithm that involves mutations and crossovers, and I guess it is the latter type that is commonly associated with genetic algorithms in general. Still, the Nearest Neighbor algorithm is popular for clustering genes as well, so I guess calling it a genetic algorithm is not incorrect.

The algorithm first sorts the documents by the sum of their similarities with their 2r neighbors. In the book these are r documents on either side; my implementation below simply uses the numNeighbors most similar documents. The algorithm then loops through the documents in descending order of this sum. A document that is already assigned to a cluster is skipped. Otherwise, a new cluster is created with the document as its seed. Every neighbor of the seed that is not yet assigned, and whose similarity to the seed is at least a specified threshold, is then added to the cluster.
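A toy example shows the seed-and-grow procedure in action. The following Java sketch does not use the DocumentCollection and Cluster classes. It works on a hypothetical 5-document similarity matrix, with r = 2 neighbors per document (chosen by similarity, not by position) and a threshold of 0.4. The class and method names are illustrative only.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Hypothetical toy version of the Nearest Neighbor clustering steps. */
public class NearestNeighborSketch {

  static final String[] NAMES = {"A", "B", "C", "D", "E"};

  // Hypothetical symmetric similarity matrix, 1.0 on the diagonal.
  static final double[][] SIM = {
    {1.0, 0.8, 0.6, 0.1, 0.0},
    {0.8, 1.0, 0.5, 0.2, 0.1},
    {0.6, 0.5, 1.0, 0.3, 0.2},
    {0.1, 0.2, 0.3, 1.0, 0.7},
    {0.0, 0.1, 0.2, 0.7, 1.0}
  };

  /** The r documents most similar to document i, excluding i itself. */
  static List<Integer> neighbors(double[][] sim, int i, int r) {
    List<Integer> others = new ArrayList<>();
    for (int j = 0; j < sim.length; j++) {
      if (j != i) {
        others.add(j);
      }
    }
    others.sort(Comparator.comparingDouble((Integer k) -> sim[i][k]).reversed());
    return new ArrayList<>(others.subList(0, Math.min(r, others.size())));
  }

  static List<List<String>> cluster(double[][] sim, String[] names, int r,
      double threshold) {
    int n = sim.length;
    List<List<Integer>> nbrs = new ArrayList<>();
    double[] fitness = new double[n];
    // fitness = sum of similarities with the r nearest neighbors
    for (int i = 0; i < n; i++) {
      List<Integer> nb = neighbors(sim, i, r);
      nbrs.add(nb);
      for (int j : nb) {
        fitness[i] += sim[i][j];
      }
    }
    // visit documents in descending order of fitness
    List<Integer> order = new ArrayList<>();
    for (int i = 0; i < n; i++) {
      order.add(i);
    }
    order.sort(Comparator.comparingDouble((Integer d) -> fitness[d]).reversed());
    boolean[] assigned = new boolean[n];
    List<List<String>> clusters = new ArrayList<>();
    for (int seed : order) {
      if (assigned[seed]) {
        continue; // already in a cluster
      }
      List<String> cluster = new ArrayList<>();
      cluster.add(names[seed]);
      assigned[seed] = true;
      // grow the cluster with unassigned neighbors that are similar enough
      for (int j : nbrs.get(seed)) {
        if (!assigned[j] && sim[seed][j] >= threshold) {
          cluster.add(names[j]);
          assigned[j] = true;
        }
      }
      clusters.add(cluster);
    }
    return clusters;
  }

  public static void main(String[] args) {
    System.out.println(cluster(SIM, NAMES, 2, 0.4));
  }
}
```

Here A seeds the first cluster and pulls in B and C. D seeds the second cluster and pulls in E, so the program prints [[A, B, C], [D, E]]. Raising the threshold to, say, 0.65 would leave C out of A's cluster, and C would then seed a cluster of its own.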

My code for the Nearest Neighbor algorithm is shown below. The algorithm comes from the TMAP book[1].

// Source: src/main/java/com/mycompany/myapp/clustering/NearestNeighborClusterer.java
package com.mycompany.myapp.clustering;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class NearestNeighborClusterer {

  private final Log log = LogFactory.getLog(getClass());
  
  private int numNeighbors;
  private double similarityThreshold;
  
  public void setNumNeighbors(int numNeighbors) {
    this.numNeighbors = numNeighbors;
  }

  public void setSimilarityThreshold(double similarityThreshold) {
    this.similarityThreshold = similarityThreshold;
  }

  public List<Cluster> cluster(DocumentCollection collection) {
    // get neighbors for every document
    Map<String,Double> similarityMap = collection.getSimilarityMap();
    Map<String,List<String>> neighborMap = 
      new HashMap<String,List<String>>();
    for (String documentName : collection.getDocumentNames()) {
      neighborMap.put(documentName, 
        collection.getNeighbors(documentName, similarityMap, numNeighbors));
    }
    // compute sum of similarities of every document with its numNeighbors
    Map<String,Double> fitnesses = 
      getFitnesses(collection, similarityMap, neighborMap);
    List<String> sortedDocNames = new ArrayList<String>();
    // sort by sum of similarities descending
    sortedDocNames.addAll(collection.getDocumentNames());
    Collections.sort(sortedDocNames, Collections.reverseOrder(
      new ByValueComparator<String,Double>(fitnesses)));
    List<Cluster> clusters = new ArrayList<Cluster>();
    int clusterId = 0;
    // Loop through the list of documents in descending order of the sum 
    // of the similarities.
    Map<String,String> documentClusterMap = 
      new HashMap<String,String>();
    for (String docName : sortedDocNames) {
      // skip if document already assigned to cluster
      if (documentClusterMap.containsKey(docName)) {
        continue;
      }
      // create cluster with current document
      Cluster cluster = new Cluster("C" + clusterId);
      cluster.addDocument(docName, collection.getDocument(docName));
      documentClusterMap.put(docName, cluster.getId());
      // find all neighboring documents to the left and right of the current
      // document that are not assigned to a cluster, and have a similarity
      // greater than our threshold. Add these documents to the new cluster
      List<String> neighbors = neighborMap.get(docName);
      for (String neighbor : neighbors) {
        if (documentClusterMap.containsKey(neighbor)) {
          continue;
        }
        double similarity = similarityMap.get(
          StringUtils.join(new String[] {docName, neighbor}, ":"));
        if (similarity < similarityThreshold) {
          continue;
        }
        cluster.addDocument(neighbor, collection.getDocument(neighbor));
        documentClusterMap.put(neighbor, cluster.getId());
      }
      clusters.add(cluster);
      clusterId++;
    }
    return clusters;
  }

  private Map<String,Double> getFitnesses(
      DocumentCollection collection, 
      Map<String,Double> similarityMap,
      Map<String,List<String>> neighbors) {
    Map<String,Double> fitnesses = new HashMap<String,Double>();
    for (String docName : collection.getDocumentNames()) {
      double fitness = 0.0D;
      for (String neighborDoc : neighbors.get(docName)) {
        String key = StringUtils.join(
          new String[] {docName, neighborDoc}, ":");
        fitness += similarityMap.get(key);
      }
      fitnesses.put(docName, fitness);
    }
    return fitnesses;
  }
}

Although some extra pre-sorting work is needed, the algorithm is relatively simple to understand. With the similarity threshold set to 0.25, I get the following results:

=== Clusters from Nearest Neighbor Algorithm === (sim threshold = 0.25)
C0:[D4, D3]
C1:[D7, D6]
C2:[D2]
C3:[D1]
C4:[D5]

Genetic Algorithm

The code for this algorithm is based on the algorithm described in the paper by Maulik and Bandyopadhyay[5]. In the language of genetics, a document is a gene and a cluster is a chromosome. Clusters get fitter by reproducing and passing on their best traits, in a process similar to Darwinian evolution.

The algorithm starts by estimating the number of clusters and partitioning the document collection into them by mod value. It then computes the fitness of the clusters in the current generation. Next, it executes a configurable number of crossover operations, followed by a single mutation operation:

- Crossover selects two random clusters and two cut points. It exchanges the documents that lie outside the cut points between the two clusters, creating two new clusters.
- Mutation selects two random documents from two random clusters and exchanges them.

At the end of each generation, the fitness of the clusters is recomputed. The algorithm terminates when a generation is less fit than its predecessor, or after a maximum number of generations.
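The crossover step is easiest to see on plain lists of document names. The Java sketch below follows the intent of the crossover() method in the code that follows, which leaves the documents between the cut points alone and exchanges the rest. The class name and the fixed cut points are hypothetical. The starting clusters are the ones that mod partitioning produces for D1..D7 with k = 2, assuming the documents are in that order and not shuffled.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical two-point crossover between two clusters of document names. */
public class TwoPointCrossoverSketch {

  /**
   * Exchanges the documents at positions [0, cut1) and [cut2, minSize)
   * between the two parents, leaving positions [cut1, cut2) in place, where
   * minSize is the size of the smaller parent. Returns the two offspring.
   */
  static List<List<String>> crossover(List<String> parent1,
      List<String> parent2, int cut1, int cut2) {
    int minSize = Math.min(parent1.size(), parent2.size());
    if (cut1 < 0 || cut1 > cut2 || cut2 > minSize) {
      throw new IllegalArgumentException("need 0 <= cut1 <= cut2 <= minSize");
    }
    List<String> child1 = new ArrayList<>(parent1);
    List<String> child2 = new ArrayList<>(parent2);
    for (int i = 0; i < minSize; i++) {
      boolean betweenCuts = i >= cut1 && i < cut2;
      if (!betweenCuts) {
        // swap the documents at position i; each document stays in
        // exactly one of the two offspring
        child1.set(i, parent2.get(i));
        child2.set(i, parent1.get(i));
      }
    }
    List<List<String>> offspring = new ArrayList<>();
    offspring.add(child1);
    offspring.add(child2);
    return offspring;
  }

  public static void main(String[] args) {
    // mod partitioning of D1..D7 into k = floor(sqrt(7)) = 2 clusters
    List<String> c0 = Arrays.asList("D1", "D3", "D5", "D7");
    List<String> c1 = Arrays.asList("D2", "D4", "D6");
    System.out.println(crossover(c0, c1, 1, 2));
  }
}
```

With cut points 1 and 2, positions 0 and 2 are exchanged and position 1 is kept, so the call above prints [[D2, D3, D6, D7], [D1, D4, D5]]. Documents are swapped one-for-one, so the cluster sizes are preserved, as in the GeneticClusterer code below.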

My code for the genetic clustering algorithm is shown below:

// Source: src/main/java/com/mycompany/myapp/clustering/GeneticClusterer.java
package com.mycompany.myapp.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class GeneticClusterer {

  private final Log log = LogFactory.getLog(getClass());
  
  private boolean randomizeData;
  private int numCrossoversPerMutation;
  private int maxGenerations;
  
  public void setRandomizeData(boolean randomizeData) {
    this.randomizeData = randomizeData;
  }
  
  public void setNumberOfCrossoversPerMutation(int ncpm) {
    this.numCrossoversPerMutation = ncpm;
  }

  public void setMaxGenerations(int maxGenerations) {
    this.maxGenerations = maxGenerations;
  }
  
  public List<Cluster> cluster(DocumentCollection collection) {
    // get initial clusters
    int k = (int) Math.floor(Math.sqrt(collection.size()));
    List<Cluster> clusters = new ArrayList<Cluster>();
    for (int i = 0; i < k; i++) {
      Cluster cluster = new Cluster("C" + i);
      clusters.add(cluster);
    }
    if (randomizeData) {
      collection.shuffle();
    }
    // load it up using mod partitioning, this is P(0)
    int docId = 0;
    for (String documentName : collection.getDocumentNames()) {
      int clusterId = docId % k;
      clusters.get(clusterId).addDocument(
        documentName, collection.getDocument(documentName));
      docId++;
    }
    log.debug("Initial clusters = " + clusters.toString());
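    // compute the fitness of P(0); fitness here is the sum of the cluster
    // radii, so a lower value means a fitter set of clusters
    double prevFitness = computeFitness(clusters);
    int generations = 0;
    while (generations < maxGenerations) {
      // hold on to a copy of the previous generation P(t-1). A deep copy
      // is needed because crossover and mutation modify clusters in place.
      List<Cluster> prevClusters = copyClusters(clusters);
      // do specified number of crossover operations for this generation
      for (int i = 0; i < numCrossoversPerMutation; i++) {
        crossover(clusters, collection, i);
      }
      // followed by a single mutation per generation
      mutate(clusters, collection);
      generations++;
      log.debug("..Intermediate clusters (" + generations + "): " +
        clusters.toString());
      // compute fitness for P(t); if it got worse, return P(t-1)
      double fitness = computeFitness(clusters);
      if (fitness > prevFitness) {
        return prevClusters;
      }
      prevFitness = fitness;
    }
    return clusters;
  }
  
  /**
   * Come up with something arbitrary. Just compute the sum of the radii of
   * the clusters. Smaller values indicate tighter, i.e. fitter, clusters.
   * @param clusters the clusters to evaluate.
   * @return the sum of the cluster radii.
   */
  private double computeFitness(List<Cluster> clusters) {
    double radius = 0.0D;
    for (Cluster cluster : clusters) {
      // refresh the cached centroid before computing the radius
      cluster.getCentroid();
      radius += cluster.getRadius();
    }
    return radius;
  }

  /**
   * Returns a deep copy of the clusters, so that a previous generation can
   * be restored after crossover and mutation modify the clusters in place.
   * @param clusters the clusters to copy.
   * @return a new list of new Cluster objects holding the same documents.
   */
  private List<Cluster> copyClusters(List<Cluster> clusters) {
    List<Cluster> copies = new ArrayList<Cluster>();
    for (Cluster cluster : clusters) {
      Cluster copy = new Cluster(cluster.getId());
      for (String docName : cluster.getDocumentNames()) {
        copy.addDocument(docName, cluster.getDocument(docName));
      }
      copies.add(copy);
    }
    return copies;
  }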
  
  /**
   * Selects two random clusters from the list, then selects two cut-points
   * based on the minimum cluster size of the two clusters. Exchanges the
   * documents outside the cut points, leaving those between them alone.
   * @param clusters the clusters to operate on.
   * @param collection the document collection the clusters are built from.
   */
  public void crossover(List<Cluster> clusters, 
      DocumentCollection collection, int sequence) {
    IdGenerator clusterIdGenerator = new IdGenerator(clusters.size());
    int[] clusterIds = new int[2];
    clusterIds[0] = clusterIdGenerator.getNextId();
    clusterIds[1] = clusterIdGenerator.getNextId();
    int minSize = Math.min(
      clusters.get(clusterIds[0]).size(), 
      clusters.get(clusterIds[1]).size());
    IdGenerator docIdGenerator = new IdGenerator(minSize);
    int[] cutPoints = new int[2];
    cutPoints[0] = docIdGenerator.getNextId();
    cutPoints[1] = docIdGenerator.getNextId();
    Arrays.sort(cutPoints);
    Cluster cluster1 = clusters.get(clusterIds[0]);
    Cluster cluster2 = clusters.get(clusterIds[1]);
    // Collect the documents to exchange first: removing and re-adding
    // documents shifts their positions, so indexing while swapping would
    // skip documents. Positions between the cut points are left alone.
    List<String> docNames1 = new ArrayList<String>();
    List<String> docNames2 = new ArrayList<String>();
    for (int i = 0; i < minSize; i++) {
      if (i >= cutPoints[0] && i < cutPoints[1]) {
        continue;
      }
      docNames1.add(cluster1.getDocumentName(i));
      docNames2.add(cluster2.getDocumentName(i));
    }
    for (int i = 0; i < docNames1.size(); i++) {
      String docName1 = docNames1.get(i);
      String docName2 = docNames2.get(i);
      cluster1.removeDocument(docName1);
      cluster2.addDocument(docName1, collection.getDocument(docName1));
      cluster2.removeDocument(docName2);
      cluster1.addDocument(docName2, collection.getDocument(docName2));
    }
    // rebuild the Cluster list, replacing the changed clusters.
    List<Cluster> crossoverClusters = new ArrayList<Cluster>();
    int clusterId = 0;
    for (Cluster cluster : clusters) {
      if (clusterId == clusterIds[0]) {
        crossoverClusters.add(cluster1);
      } else if (clusterId == clusterIds[1]) {
        crossoverClusters.add(cluster2);
      } else {
        crossoverClusters.add(cluster);
      }
      clusterId++;
    }
    clusters.clear();
    clusters.addAll(crossoverClusters);
  }
  
  /**
   * Exchanges a random document between two random clusters in the list.
   * @param clusters the clusters to operate on.
   */
  private void mutate(List<Cluster> clusters, 
      DocumentCollection collection) {
    // choose two random clusters
    IdGenerator clusterIdGenerator = new IdGenerator(clusters.size());
    int[] clusterIds = new int[2];
    clusterIds[0] = clusterIdGenerator.getNextId();
    clusterIds[1] = clusterIdGenerator.getNextId();
    Cluster cluster1 = clusters.get(clusterIds[0]);
    Cluster cluster2 = clusters.get(clusterIds[1]);
    // choose two random documents in the clusters
    int minSize = Math.min(
      clusters.get(clusterIds[0]).size(), 
      clusters.get(clusterIds[1]).size());
    IdGenerator docIdGenerator = new IdGenerator(minSize);
    String docName1 = cluster1.getDocumentName(docIdGenerator.getNextId());
    String docName2 = cluster2.getDocumentName(docIdGenerator.getNextId());
    // exchange the documents
    cluster1.removeDocument(docName1);
    cluster1.addDocument(docName2, collection.getDocument(docName2));
    cluster2.removeDocument(docName2);
    cluster2.addDocument(docName1, collection.getDocument(docName1));
    // rebuild the cluster list, replacing changed clusters
    List<Cluster> mutatedClusters = new ArrayList<Cluster>();
    int clusterId = 0;
    for (Cluster cluster : clusters) {
      if (clusterId == clusterIds[0]) {
        mutatedClusters.add(cluster1);
      } else if (clusterId == clusterIds[1]) {
        mutatedClusters.add(cluster2);
      } else {
        mutatedClusters.add(cluster);
      }
      clusterId++;
    }
    clusters.clear();
    clusters.addAll(mutatedClusters);
  }
}

To measure the fitness of a generation (i.e. the list of clusters), I decided on the sum of the radii of the clusters. A smaller sum means tighter, and therefore fitter, clusters. I could probably have used a fancier function, such as the sum of similarities used in the Nearest Neighbor algorithm. In any case, I terminated the algorithm after 500 generations, and the result it came up with is shown below:

=== Clusters from Genetic Algorithm ===
C0:[D2, D1, D3, D4]
C1:[D5, D7, D6]

Supporting classes

In order to make the code for the clusterers clean and readable, a lot of code is factored out into supporting classes. They are shown below:

Cluster.java

This class models a cluster as a list of named documents. It provides various methods that compute properties of a cluster from its members, such as its centroid, or the Euclidean distance or cosine similarity of a document from the cluster centroid. The code is shown below:

// Source: src/main/java/com/mycompany/myapp/clustering/Cluster.java
package com.mycompany.myapp.clustering;

import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math.stat.descriptive.rank.Max;

import Jama.Matrix;

public class Cluster {
  
  private final Log log = LogFactory.getLog(getClass());
  
  private String id;
  private Map<String,Matrix> docs = 
    new LinkedHashMap<String,Matrix>();
  private List<String> docNames = new LinkedList<String>();
  
  private Matrix centroid = null;
  
  public Cluster(String id) {
    super();
    this.id = id;
  }
  
  public String getId() {
    return id;
  }
  
  public Set<String> getDocumentNames() {
    return docs.keySet();
  }

  public String getDocumentName(int pos) {
    return docNames.get(pos);
  }
  
  public Matrix getDocument(String documentName) {
    return docs.get(documentName);
  }

  public Matrix getDocument(int pos) {
    return docs.get(docNames.get(pos));
  }
  
  public void addDocument(String docName, Matrix docMatrix) {
    docs.put(docName, docMatrix);
    docNames.add(docName);
    log.debug("...." + id + " += " + docName);
  }

  public void removeDocument(String docName) {
    docs.remove(docName);
    docNames.remove(docName);
    log.debug("...." + id + " -= " + docName);
  }

  public int size() {
    return docs.size();
  }
  
  public boolean contains(String docName) {
    return docs.containsKey(docName);
  }
  
  /**
   * Returns a document (term vector) consisting of the average of the 
   * coordinates of the documents in the cluster. Returns a null Matrix
   * if there are no documents in the cluster. 
   * @return the centroid of the cluster, or null if no documents have 
   * been added to the cluster.
   */
  public Matrix getCentroid() {
    if (docs.size() == 0) {
      return null;
    }
    Matrix d = docs.get(docNames.get(0));
    centroid = new Matrix(d.getRowDimension(), d.getColumnDimension()); 
    for (String docName : docs.keySet()) {
      Matrix docMatrix = docs.get(docName);
      centroid = centroid.plus(docMatrix);
    }
    centroid = centroid.times(1.0D / docs.size());
    return centroid;
  }

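  /**
   * Returns the radius of the cluster, i.e. the average Euclidean distance
   * of its constituent document term vectors from the centroid. This uses
   * the centroid cached by the last call to getCentroid(), so call that
   * first if documents have been added or removed since. A non-empty
   * cluster whose centroid has not been computed yet has radius 0.
   * @return the radius of the cluster.
   */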
  public double getRadius() {
    double radius = 0.0D;
    if (centroid != null) {
      for (String docName : docNames) {
        Matrix doc = getDocument(docName);
        radius += doc.minus(centroid).normF();
      }
    }
    return radius / docNames.size();
  }
  
  /**
   * Returns the Euclidean distance between the centroid of this cluster
   * and the new document.
   * @param doc the document to be measured for distance.
   * @return the Euclidean distance between the cluster centroid and the 
   * document.
   */
  public double getEucledianDistance(Matrix doc) {
    if (centroid != null) {
      return (doc.minus(centroid)).normF();
    }
    return 0.0D;
  }
  
  /**
   * Returns the maximum distance from the specified document to any of
   * the documents in the cluster.
   * @param doc the document to be measured for distance.
   * @return the complete linkage distance from the cluster.
   */
  public double getCompleteLinkageDistance(Matrix doc) {
    Max max = new Max();
    if (docs.size() ==0) {
      return 0.0D;
    }
    double[] distances = new double[docs.size()];
    for (int i = 0; i < distances.length; i++) {
      Matrix clusterDoc = docs.get(docNames.get(i));
      distances[i] = clusterDoc.minus(doc).normF();
    }
    return max.evaluate(distances);
  }
  
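  /**
   * Returns the cosine similarity between the centroid of this cluster
   * and the new document. The dot product is computed as the 1-norm of the
   * element-wise product, which equals the dot product for (n x 1) term
   * vectors with non-negative weights.
   * @param doc the document to be measured for similarity.
   * @return the similarity of the centroid of the cluster to the document.
   */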
  public double getSimilarity(Matrix doc) {
    if (centroid != null) {
      double dotProduct = centroid.arrayTimes(doc).norm1();
      double normProduct = centroid.normF() * doc.normF();
      return dotProduct / normProduct;
    }
    return 0.0D;
  }

  @Override
  public boolean equals(Object obj) {
    if (!(obj instanceof Cluster)) {
      return false;
    }
    Cluster that = (Cluster) obj;
    String[] thisDocNames = this.getDocumentNames().toArray(new String[0]);
    String[] thatDocNames = that.getDocumentNames().toArray(new String[0]);
    if (thisDocNames.length != thatDocNames.length) {
      return false;
    }
    Arrays.sort(thisDocNames);
    Arrays.sort(thatDocNames);
    return ArrayUtils.isEquals(thisDocNames, thatDocNames);
  }
  
  @Override
  public int hashCode() {
    String[] docNames = getDocumentNames().toArray(new String[0]);
    Arrays.sort(docNames);
    return StringUtils.join(docNames, ",").hashCode();
  }
  
  @Override
  public String toString() {
    return id + ":" + docs.keySet().toString();
  }
}

DocumentCollection.java

DocumentCollection wraps the term-document matrix and exposes each of its columns as a named document. It provides convenience accessor methods. It also provides methods to compute document-to-document similarities and to find a document's neighbors by similarity, which are used by the Nearest Neighbor algorithm.

// Source: src/main/java/com/mycompany/myapp/clustering/DocumentCollection.java
package com.mycompany.myapp.clustering;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang.StringUtils;

import Jama.Matrix;

import com.mycompany.myapp.similarity.CosineSimilarity;

public class DocumentCollection {

  private Matrix tdMatrix;
  private Map<String,Matrix> documentMap;
  private List<String> documentNames;
  // document names in the column order of tdMatrix; unlike documentNames,
  // this list is not affected by shuffle()
  private List<String> columnNames;
  
  public DocumentCollection(Matrix tdMatrix, String[] docNames) {
    int position = 0;
    this.tdMatrix = tdMatrix;
    this.documentMap = new HashMap<String,Matrix>();
    this.documentNames = new ArrayList<String>();
    for (String documentName : docNames) {
      documentMap.put(documentName, tdMatrix.getMatrix(
        0, tdMatrix.getRowDimension() - 1, position, position));
      documentNames.add(documentName);
      position++;
    }
    this.columnNames = new ArrayList<String>(documentNames);
  }

  public int size() {
    return documentMap.keySet().size();
  }
  
  public List<String> getDocumentNames() {
    return documentNames;
  }
  
  public String getDocumentNameAt(int position) {
    return documentNames.get(position);
  }
  
  public Matrix getDocumentAt(int position) {
    return documentMap.get(documentNames.get(position));
  }
  
  public Matrix getDocument(String documentName) {
    return documentMap.get(documentName);
  }
  
  public void shuffle() {
    Collections.shuffle(documentNames);
  }
  
  public Map<String,Double> getSimilarityMap() {
    Map<String,Double> similarityMap = 
      new HashMap<String,Double>();
    CosineSimilarity similarity = new CosineSimilarity();
    Matrix similarityMatrix = similarity.transform(tdMatrix);
    for (int i = 0; i < similarityMatrix.getRowDimension(); i++) {
      for (int j = 0; j < similarityMatrix.getColumnDimension(); j++) {
        // row/column i of the similarity matrix corresponds to column i of
        // tdMatrix, so look names up in column order rather than in the
        // (possibly shuffled) documentNames order
        String sourceDoc = columnNames.get(i);
        String targetDoc = columnNames.get(j);
        similarityMap.put(StringUtils.join(
          new String[] {sourceDoc, targetDoc}, ":"),
          similarityMatrix.get(i, j));
      }
    }
    return similarityMap;
  }
  
  public List<String> getNeighbors(String docName,
      Map<String,Double> similarityMap, int numNeighbors) {
    if (numNeighbors > size()) {
      throw new IllegalArgumentException(
        "numNeighbors too large, max: " + size());
    }
    final Map<String,Double> differenceMap = 
      new HashMap<String,Double>();
    List<String> neighbors = new ArrayList<String>();
    neighbors.addAll(documentNames);
    for (String documentName : documentNames) {
      String key = StringUtils.join(
        new String[] {docName, documentName}, ":");
      double difference = Math.abs(similarityMap.get(key) - 1.0D);
      differenceMap.put(documentName, difference);
    }
    Collections.sort(neighbors, 
      new ByValueComparator<String,Double>(differenceMap));
    return neighbors.subList(0, numNeighbors + 1);
  }
}
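For readers who don't have the CosineSimilarity and Jama-based classes from the earlier posts at hand, here is a minimal, dependency-free sketch of what getSimilarityMap() and getNeighbors() compute. It assumes the term-document matrix is a plain double[][] indexed as [term][document], and it uses the textbook cosine formula rather than the CosineSimilarity class itself. The class and method names (NeighborSketch, cosine, similarityMap, neighbors) are hypothetical, and List.sort with a lambda needs Java 8 or later.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch, not the project's code: pairwise cosine similarity over
// the columns of a term-document matrix, plus nearest-neighbor selection.
final class NeighborSketch {

  // Cosine similarity between document columns a and b of td[term][doc].
  static double cosine(double[][] td, int a, int b) {
    double dot = 0.0, normA = 0.0, normB = 0.0;
    for (double[] row : td) {
      dot += row[a] * row[b];
      normA += row[a] * row[a];
      normB += row[b] * row[b];
    }
    if (normA == 0.0 || normB == 0.0) {
      return 0.0; // an empty document is treated as dissimilar to everything
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
  }

  // Same key convention as getSimilarityMap(): "sourceDoc:targetDoc".
  static Map<String, Double> similarityMap(double[][] td, String[] names) {
    Map<String, Double> sims = new HashMap<>();
    for (int i = 0; i < names.length; i++) {
      for (int j = 0; j < names.length; j++) {
        sims.put(names[i] + ":" + names[j], cosine(td, i, j));
      }
    }
    return sims;
  }

  // Sorts documents by |similarity - 1| (most similar first) and returns the
  // first numNeighbors + 1 of them; the query document normally comes first.
  static List<String> neighbors(String docName, String[] names,
      Map<String, Double> sims, int numNeighbors) {
    if (numNeighbors >= names.length) {
      throw new IllegalArgumentException(
          "numNeighbors too large, max: " + (names.length - 1));
    }
    final Map<String, Double> diff = new HashMap<>();
    List<String> result = new ArrayList<>();
    for (String name : names) {
      diff.put(name, Math.abs(sims.get(docName + ":" + name) - 1.0));
      result.add(name);
    }
    result.sort((x, y) -> Double.compare(diff.get(x), diff.get(y)));
    return result.subList(0, numNeighbors + 1);
  }
}
```

With three documents D1 = (1,0,0), D2 = (1,1,0) and D3 = (0,0,1), asking for one neighbor of D1 returns [D1, D2]: the document itself, then its closest match, which is why getNeighbors() slices numNeighbors + 1 entries.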

IdGenerator.java

IdGenerator is a "safe" random number generator that never returns the same number twice until all of its numbers have been used up. It is constructed with an upper bound, and returns each number from 0 to (upperBound - 1) exactly once, in random order; after that, it forgets the numbers it has returned and starts repeating them in a new random order.

// Source: src/main/java/com/mycompany/myapp/clustering/IdGenerator.java
package com.mycompany.myapp.clustering;

import java.util.HashSet;
import java.util.Random;
import java.util.Set;

public class IdGenerator {

  private int upperBound;
  
  private Random randomizer;
  // ids already handed out in the current pass
  private Set<Integer> ids = new HashSet<Integer>();
  
  public IdGenerator(int upperBound) {
    this.upperBound = upperBound;
    randomizer = new Random();
  }
  
  public int getNextId() {
    // every id in [0, upperBound) has been returned: start a new pass
    if (ids.size() == upperBound) {
      ids.clear();
    }
    // keep drawing until we hit an id not yet returned in this pass
    for (;;) {
      int id = randomizer.nextInt(upperBound);
      if (ids.contains(id)) {
        continue;
      }
      ids.add(id);
      return id;
    }
  }
}
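The retry loop in getNextId() is correct, but each call keeps drawing until it finds an unused number, so the retries grow as the set fills up (on the order of n ln n draws to exhaust n ids, by the coupon collector argument). If that ever matters, a Fisher-Yates shuffle of a precomputed pool gives the same no-repeats-per-pass guarantee with constant amortized work per call. ShuffledIdGenerator below is a hypothetical alternative, not part of the project code; the extra constructor taking a Random is an assumption added so the sketch can be run with a fixed seed.

```java
import java.util.Random;

// Hypothetical alternative to IdGenerator: returns 0..upperBound-1 in random
// order with no repeats, then reshuffles and starts a new pass.
final class ShuffledIdGenerator {
  private final int[] pool;
  private final Random randomizer;
  private int next;

  ShuffledIdGenerator(int upperBound, Random randomizer) {
    if (upperBound <= 0) {
      throw new IllegalArgumentException("upperBound must be positive");
    }
    this.pool = new int[upperBound];
    for (int i = 0; i < upperBound; i++) {
      pool[i] = i;
    }
    this.randomizer = randomizer;
    this.next = upperBound; // forces a shuffle on the first call
  }

  ShuffledIdGenerator(int upperBound) {
    this(upperBound, new Random());
  }

  int getNextId() {
    // pass exhausted: reshuffle once (O(n)), then hand out n ids in O(1) each
    if (next == pool.length) {
      shuffle();
      next = 0;
    }
    return pool[next++];
  }

  // In-place Fisher-Yates shuffle of the whole pool.
  private void shuffle() {
    for (int i = pool.length - 1; i > 0; i--) {
      int j = randomizer.nextInt(i + 1);
      int tmp = pool[i];
      pool[i] = pool[j];
      pool[j] = tmp;
    }
  }
}
```

Like IdGenerator, it returns every value in 0..upperBound-1 once per pass; unlike IdGenerator, it rejects a non-positive upper bound up front instead of failing inside Random.nextInt().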

ByValueComparator.java

The ByValueComparator is a generic comparator that sorts a List of keys by their corresponding values in a supporting Map. The idea for this came from Jeffrey Bigham's blog post Sorting Java Map by Value, although I have used Java Generics so that it works with any Map whose values are Comparable.

// Source: src/main/java/com/mycompany/myapp/clustering/ByValueComparator.java
package com.mycompany.myapp.clustering;

import java.util.Comparator;
import java.util.Map;

// Orders keys by the natural ordering of their values in the supporting map.
public class 
    ByValueComparator<K,V extends Comparable<? super V>> 
    implements Comparator<K> {

  private final Map<K,V> map;
  
  public ByValueComparator(Map<K,V> map) {
    this.map = map;
  }

  public int compare(K k1, K k2) {
    return map.get(k1).compareTo(map.get(k2));
  }
}
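To show how ByValueComparator is meant to be used (getNeighbors() does the same thing with its differenceMap), here is a small self-contained demo. It repeats the comparator as a nested class so the example compiles on its own; keysByValue and keysByValueDescending are hypothetical helper names. Descending order comes from wrapping the same comparator in Collections.reverseOrder(Comparator) rather than writing a second comparator class.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

final class ByValueComparatorDemo {

  // Same shape as ByValueComparator above: orders keys by their map values.
  static final class ByValueComparator<K, V extends Comparable<? super V>>
      implements Comparator<K> {
    private final Map<K, V> map;

    ByValueComparator(Map<K, V> map) {
      this.map = map;
    }

    @Override
    public int compare(K k1, K k2) {
      return map.get(k1).compareTo(map.get(k2));
    }
  }

  // Returns the map's keys sorted by ascending value.
  static <K, V extends Comparable<? super V>> List<K> keysByValue(Map<K, V> map) {
    List<K> keys = new ArrayList<K>(map.keySet());
    Collections.sort(keys, new ByValueComparator<K, V>(map));
    return keys;
  }

  // Returns the map's keys sorted by descending value.
  static <K, V extends Comparable<? super V>> List<K> keysByValueDescending(
      Map<K, V> map) {
    List<K> keys = new ArrayList<K>(map.keySet());
    Collections.sort(keys,
        Collections.reverseOrder(new ByValueComparator<K, V>(map)));
    return keys;
  }
}
```

Note that the comparator throws a NullPointerException if a key being sorted is missing from the map, so the list and the map must cover the same keys, as they do in getNeighbors().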

Test case

The test case is a simple JUnit test that runs each of the clustering algorithms against my small collection of seven document titles, which are used to build the term-document matrix. Here is the code:

// Source: src/test/java/com/mycompany/myapp/clustering/ClusteringTest.java
package com.mycompany.myapp.clustering;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang.StringUtils;
import org.junit.Before;
import org.junit.Test;
import org.springframework.jdbc.datasource.DriverManagerDataSource;

import Jama.Matrix;

import com.mycompany.myapp.indexers.IdfIndexer;
import com.mycompany.myapp.indexers.VectorGenerator;

public class ClusteringTest {

  private Matrix tdMatrix;
  private String[] documentNames;
  
  private DocumentCollection documentCollection;
  
  @Before
  public void setUp() throws Exception {
    VectorGenerator vectorGenerator = new VectorGenerator();
    vectorGenerator.setDataSource(new DriverManagerDataSource(
      "com.mysql.jdbc.Driver", 
      "jdbc:mysql://localhost:3306/tmdb", 
      "irstuff", "irstuff"));
    Map<String,Reader> documents = 
      new LinkedHashMap<String,Reader>();
    BufferedReader reader = new BufferedReader(
      new FileReader("src/test/resources/data/indexing_sample_data.txt"));
    try {
      String line = null;
      // each line has the form docName;docTitle
      while ((line = reader.readLine()) != null) {
        String[] docTitleParts = StringUtils.split(line, ";");
        documents.put(docTitleParts[0], new StringReader(docTitleParts[1]));
      }
    } finally {
      reader.close();
    }
    vectorGenerator.generateVector(documents);
    IdfIndexer indexer = new IdfIndexer();
    tdMatrix = indexer.transform(vectorGenerator.getMatrix());
    documentNames = vectorGenerator.getDocumentNames();
    documentCollection = new DocumentCollection(tdMatrix, documentNames);
  }

  @Test
  public void testKMeansClustering() throws Exception {
    KMeansClusterer clusterer = new KMeansClusterer();
    clusterer.setInitialClusterAssignments(new String[] {"D1", "D3"});
    List<Cluster> clusters = clusterer.cluster(documentCollection);
    System.out.println("=== Clusters from K-Means algorithm ===");
    for (Cluster cluster : clusters) {
      System.out.println(cluster.toString());
    }
  }

  @Test
  public void testQtClustering() throws Exception {
    QtClusterer clusterer = new QtClusterer();
    clusterer.setMaxRadius(0.40D);
    clusterer.setRandomizeDocuments(true);
    List<Cluster> clusters = clusterer.cluster(documentCollection);
    System.out.println("=== Clusters from QT Algorithm ===");
    for (Cluster cluster : clusters) {
      System.out.println(cluster.toString());
    }
  }

  @Test
  public void testSimulatedAnnealingClustering() throws Exception {
    SimulatedAnnealingClusterer clusterer = 
      new SimulatedAnnealingClusterer();
    clusterer.setRandomizeDocs(false);
    clusterer.setNumberOfLoops(5);
    clusterer.setInitialTemperature(100.0D);
    clusterer.setFinalTemperature(1.0D);
    clusterer.setDownhillProbabilityCutoff(0.7D);
    List<Cluster> clusters = clusterer.cluster(documentCollection);
    System.out.println(
      "=== Clusters from Simulated Annealing Algorithm ===");
    for (Cluster cluster : clusters) {
      System.out.println(cluster.toString());
    }
  }
  
  @Test
  public void testNearestNeighborClustering() throws Exception {
    NearestNeighborClusterer clusterer = new NearestNeighborClusterer();
    clusterer.setNumNeighbors(2);
    clusterer.setSimilarityThreshold(0.25);
    List<Cluster> clusters = clusterer.cluster(documentCollection);
    System.out.println("=== Clusters from Nearest Neighbor Algorithm ===");
    for (Cluster cluster : clusters) {
      System.out.println(cluster.toString());
    }
  }
  
  @Test
  public void testGeneticAlgorithmClustering() throws Exception {
    GeneticClusterer clusterer = new GeneticClusterer();
    clusterer.setNumberOfCrossoversPerMutation(5);
    clusterer.setMaxGenerations(500);
    clusterer.setRandomizeData(false);
    List<Cluster> clusters = clusterer.cluster(documentCollection);
    System.out.println("=== Clusters from Genetic Algorithm ===");
    for (Cluster cluster : clusters) {
      System.out.println(cluster.toString());
    }
  }
}

References

References to books and Internet articles that the above code is based on, in no particular order:

  1. Text Mining Application Programming, by Dr. Manu Konchady.
  2. Wikipedia article on QT (Quality Threshold) clustering.
  3. Wikipedia article on the Nearest-Neighbor algorithm.
  4. Hierarchic document clustering using a genetic algorithm, research article by Robertson, Santimetrvirul and Willett, University of Sheffield, UK.
  5. Genetic algorithm-based clustering technique, research article by Maulik and Bandyopadhyay, Government Engineering College, Kalyani, India and Indian Statistical Institute, Calcutta, India.

Update 2009-04-26: In recent posts, I have been building on code written and described in previous posts, so there were (and rightly so) quite a few requests for the code. So I've created a project on SourceForge to host the code. You will find the complete source code built so far in the project's SVN repository.