Sunday, November 30, 2008

IR Math in Java : Citation based Ranking

If you are a regular reader, you know that I have been working my way through Dr Manu Konchady's TMAP book in an effort to teach myself some Information Retrieval theory. This week, I talk about my experience implementing Google's PageRank algorithm in Java, as described in Chapter 6 of this book and the PageRank Wikipedia page. In the process, I also ended up developing a Sparse Matrix implementation in order to compute PageRank for real data collections, which I contributed back to the commons-math project.

The PageRank algorithm was originally proposed by Google's founders, and while it does form part of the core of what SEO types refer to as The Google Algorithm, the Algorithm is significantly more comprehensive and complex. My intent is not to reverse engineer this stuff, nor to hack it. I think the algorithm is interesting, and thought it would be worth figuring out how to code this up in Java.

The PageRank algorithm is based on the citation model (hence the title of this post), i.e., if a scholarly paper is considered to be of interest, other scholarly papers cite it as a reference. Similarly, a page with good information is linked to by other pages on the web. The PageRank of a page is the sum of the normalized PageRanks of the pages that point to it. If a page links out to more than one page, its contribution to each target page's PageRank is its own PageRank divided by the number of pages it links out to. Obviously, this is something of a chicken-and-egg problem, so it has to be solved iteratively.

In addition, there is a damping factor d to simulate a random surfer, who clicks on links but eventually gets bored and does a new search and starts over. To compensate for the damping factor, a constant factor c is added to the PageRank formula. The formula is thus:

  rj = c + (d * Σ ri / ni)
  where:
    rj = PageRank for page j
    d = damping factor, usually 0.85
    c = (1 - d) / N
    ri = PageRank for page i which points to page j
    ni = Number of outbound links from page i
    N = number of documents in the collection
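
To make the formula concrete, here is a minimal self-contained sketch (using a hypothetical three-page graph, not a dataset from the book) that simply iterates the scalar formula above until it settles:

```java
public class ToyPageRank {

  // Iterates rj = c + d * sum(ri / ni) over a hypothetical 3-page
  // graph: page 0 -> 1, page 1 -> 2, page 2 -> 0 and 2 -> 1.
  public static double[] rank() {
    int n = 3;
    double d = 0.85;                          // damping factor
    double c = (1.0 - d) / n;                 // constant term
    int[][] outlinks = { {1}, {2}, {0, 1} };  // outlinks[i] = pages page i links to
    double[] r = { 1.0 / n, 1.0 / n, 1.0 / n };  // initial guess
    for (int iter = 0; iter < 100; iter++) {
      double[] next = new double[n];
      for (int j = 0; j < n; j++) {
        next[j] = c;
      }
      for (int i = 0; i < n; i++) {
        for (int j : outlinks[i]) {
          next[j] += d * r[i] / outlinks[i].length;  // the ri / ni contribution
        }
      }
      r = next;
    }
    return r;
  }

  public static void main(String[] args) {
    double[] r = rank();
    System.out.printf("r = [%.4f, %.4f, %.4f]%n", r[0], r[1], r[2]);
  }
}
```

Since every page in this toy graph has at least one outbound link, the ranks converge to a vector that sums to 1, and page 1, which receives contributions from both page 0 and page 2, ends up ranked higher than page 0.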

This translates to a set of linear equations, and can thus be rewritten as an iterative matrix equation. As much as I would like to say that I arrived at this epiphany all by myself, I really just worked backwards from the formula on the Wikipedia page.

  R = C + d * A * R0
  where:
    R  = a column vector of size N, containing the ranks of pages in the collection.
    C  = a constant column vector containing [ci]
    d  = scalar damping factor
    A  = an NxN square matrix with entry (i,j) set to 1/nj when page(j)
         links to page(i), and 0 for all other (i,j)
    R0 = the initial guess for the page ranks, all set to 1/N.

We populate the matrices on the right-hand side, then compute R. At each step we check for convergence (whether R is close enough to its previous value); if not, we set R0 from R and recompute. Here is the code to do this:

// Source: src/main/java/com/mycompany/myapp/ranking/PageRanker.java
package com.mycompany.myapp.ranking;

import java.util.List;
import java.util.Map;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.SparseRealMatrixImpl;

public class PageRanker {

  private Map<String,Boolean> linkMap;
  private double d;
  private double threshold;
  private List<String> docIds;
  private int numDocs;
  
  public void setLinkMap(Map<String,Boolean> linkMap) {
    this.linkMap = linkMap;
  }
  
  public void setDocIds(List<String> docIds) {
    this.docIds = docIds;
    this.numDocs = docIds.size();
  }
  
  public void setDampingFactor(double dampingFactor) {
    this.d = dampingFactor;
  }
  
  public void setConvergenceThreshold(double threshold) {
    this.threshold = threshold;
  }
  
  private String linkKey(int i, int j) {
    return StringUtils.join(new String[] {
      docIds.get(i), docIds.get(j)
    }, ",");
  }

  public RealMatrix rank() throws Exception {
    // create and initialize the probability matrix A. For a link from
    // page i to page j, entry (j,i) is set to 1/ni, where ni is the
    // number of outbound links on page i; all other entries are 0.
    // Putting the source page in the column means A * R sums the
    // contributions of the pages linking *into* each page, matching
    // the formula above.
    RealMatrix a = new SparseRealMatrixImpl(numDocs, numDocs);
    int[] numOutLinks = new int[numDocs];
    for (int i = 0; i < numDocs; i++) {
      for (int j = 0; j < numDocs; j++) {
        if (linkMap.containsKey(linkKey(i, j))) {
          numOutLinks[i]++;
        }
      }
    }
    for (int i = 0; i < numDocs; i++) {
      for (int j = 0; j < numDocs; j++) {
        if (linkMap.containsKey(linkKey(i, j))) {
          a.setEntry(j, i, 1.0D / numOutLinks[i]);
        }
      }
    }
    // create and initialize the constant matrix
    RealMatrix c = new SparseRealMatrixImpl(numDocs, 1);
    for (int i = 0; i < numDocs; i++) {
      c.setEntry(i, 0, ((1.0D - d) / numDocs));
    }
    // create and initialize the rank matrix
    RealMatrix r0 = new SparseRealMatrixImpl(numDocs, 1);
    for (int i = 0; i < numDocs; i++) {
      r0.setEntry(i, 0, (1.0D / numDocs));
    }
    // solve iteratively for the pagerank vector r
    RealMatrix r;
    for (;;) {
      r = c.add(a.scalarMultiply(d).multiply(r0));
      // check for convergence against the previous estimate
      if (r.subtract(r0).getNorm() < threshold) {
        break;
      }
      r0 = r.copy();
    }
    return r;
  }
}

Here is the JUnit code to call the class. We set up the damping factor and the convergence threshold. We use the picture of the graph on the Wikipedia PageRank article as our initial dataset. The dataset is represented as comma-delimited pairs of linked page IDs, one pair per line.

  @Test
  public void testRankWithToyData() throws Exception {
    Map<String,Boolean> linkMap = getLinkMapFromDatafile(
      "src/test/resources/pagerank_links.txt");
    PageRanker ranker = new PageRanker();
    ranker.setLinkMap(linkMap);
    ranker.setDocIds(Arrays.asList(new String[] {
      "1", "2", "3", "4", "5", "6", "7"
    }));
    ranker.setDampingFactor(0.85D);
    ranker.setConvergenceThreshold(0.001D);
    RealMatrix pageranks = ranker.rank();
    log.debug("pageRank=" + pageranks.toString());
  }

  private Map<String,Boolean> getLinkMapFromDatafile(String filename) 
      throws Exception {
    Map<String,Boolean> linkMap = new HashMap<String,Boolean>();
    BufferedReader reader = new BufferedReader(new FileReader(filename));
    try {
      String line;
      while ((line = reader.readLine()) != null) {
        if (StringUtils.isEmpty(line) || line.startsWith("#")) {
          continue;
        }
        // each line is a "sourceId,targetId" pair, which is exactly
        // the key format PageRanker builds from the docId list
        linkMap.put(StringUtils.trim(line), Boolean.TRUE);
      }
    } finally {
      reader.close();
    }
    return linkMap;
  }

You may have noticed the calls to SparseRealMatrixImpl, which did not exist in the commons-math codebase at the time of this writing. I implemented SparseRealMatrixImpl because when I tried to run the algorithm against a real interlinked data collection of about 6,000 documents, I would consistently get an OutOfMemoryError with code that used RealMatrixImpl (which uses a two-dimensional double array as its backing store).

The SparseRealMatrixImpl subclasses RealMatrixImpl, but uses a Map<Point,Double> as its backing store. The Point class is a simple private data-holder class that encapsulates the row and column numbers of a matrix element. Only non-zero matrix elements are actually stored in the Map. This works out well because the largest matrix (A) is mostly zeros, i.e., comparatively few pages are actually linked. The patch is available here.
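
The backing-store idea can be sketched in a few lines. This is not the actual patch, just a minimal illustration of map-based sparse storage; it encodes (row, col) into a single long key, an equivalent alternative to the Point key class described above:

```java
import java.util.HashMap;
import java.util.Map;

// Minimal sketch of map-backed sparse storage: only non-zero entries
// occupy memory, so an N x N link matrix that is mostly zeros stays small.
public class SparseStore {

  private final Map<Long, Double> entries = new HashMap<Long, Double>();
  private final int rows;
  private final int cols;

  public SparseStore(int rows, int cols) {
    this.rows = rows;
    this.cols = cols;
  }

  public void setEntry(int row, int col, double value) {
    long key = ((long) row) * cols + col;  // pack (row, col) into one key
    if (value == 0.0D) {
      entries.remove(key);  // keep the map sparse: never store zeros
    } else {
      entries.put(key, value);
    }
  }

  public double getEntry(int row, int col) {
    Double value = entries.get(((long) row) * cols + col);
    return (value == null) ? 0.0D : value;  // absent entry reads as zero
  }

  public int nonZeroCount() {
    return entries.size();
  }
}
```

With 6,000 documents, a dense backing store needs 6000 * 6000 doubles regardless of content, while this approach only pays for the links that actually exist.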

Update 2009-04-26: In recent posts, I have been building on code written and described in previous posts, so there were (and rightly so) quite a few requests for the code. So I've created a project on Sourceforge to host the code. You will find the complete source code built so far in the project's SVN repository.

Saturday, November 22, 2008

Jab - Inflict pain on your Java Application

When load testing web applications, I usually take a few URLs and run them through tools such as Apache Bench (ab) or more recently, Siege. Recently, however, I needed to compare performance under load for code querying data from a MySQL database table versus a Lucene index. I could have built a simple web-based interface around this code and used the tools mentioned above, but it seemed like too much work, so I looked around to see if there was anything in library form that I could use to load test Java components.

The first result of my Google search was Mike Clark's JUnitPerf project, which consists of a set of Test decorators designed to work with JUnit 3.x. Since I use JUnit 4.x, I would have to write JUnit 3.x style code and run it under JUnit 4.x, which is something I'd rather not do unless I really, really have to. That was not the biggest problem, however. Since both my components depended on external resources, these would have to be pre-instantiated for the test times to be realistic. Since JUnitPerf wraps an existing Test, which then runs within the JUnitRunner, the instantiation would have to be done either within the @Test or @Before equivalent methods, or I would have to write another @BeforeClass style JUnit 3.8 decorator. In the first two cases, tests run with JUnitPerf's LoadTest would include the resource setup times. So I decided to write my own little framework which was JUnit-agnostic and yet runnable from within JUnit 4.x, and which allowed me to set up resources outside the code being tested.

Overview

My framework borrows the idea of using the Decorator pattern from JUnitPerf. It consists of 2 interfaces, 5 different Test Decorator implementations, and a couple of utility classes. The only dependencies are Java 1.5+, commons-lang, commons-math, commons-logging and log4j. I call it jab (JAva Bench), drawing inspiration for the name from Apache Bench. It can also be thought of as something that inflicts pain on your Java application by putting it under load (hence the title of this post).

Component Descriptions

ITestable

The ITestable interface provides the template that a piece of code must implement in order to be tested with jab. The resources argument passes in all the pre-instantiated resources the ITestable needs to execute. Further down, I show a real-life example, which incidentally was also the code that drove the building of this framework; there are two example implementations of ITestable in there.

// Source: src/main/java/com/mycompany/jab/ITestable.java
package com.mycompany.jab;

import java.util.Map;

public interface ITestable {
  public void execute(Map<String,Object> resources) throws Exception;
}

ITest

ITest is the interface that all our Test instances implement. This is really something internal to the framework, providing a template for people writing new Test implementations.

// Source: src/main/java/com/mycompany/jab/ITest.java
package com.mycompany.jab;

import java.util.List;

public interface ITest extends Cloneable {
  public void runTest() throws Exception;
  public Double getAggregatedObservation();
  public List<Double> getObservations();
  public Object clone();
}

SingleTest

This is the most basic (and central) implementation of ITest. All it does is wrap the ITestable.execute() call between two System.currentTimeMillis() calls to grab the wall-clock time, then compute the elapsed time and record it as an observation.

// Source: src/main/java/com/mycompany/jab/SingleTest.java
package com.mycompany.jab;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Models a single test. All it does is attach timers around the 
 * ITestable.execute() call.
 */
public class SingleTest implements ITest {

  private final Log log = LogFactory.getLog(getClass());
  
  private Class<? extends ITestable> testableClass;
  private Map<String,Object> resources;
  private ITestable testable;
  
  private List<Double> observations = new ArrayList<Double>();
  
  public SingleTest(Class<? extends ITestable> testableClass, 
      Map<String,Object> resources) throws Exception {
    this.testableClass = testableClass;
    this.resources = resources;
    this.testable = testableClass.newInstance();
  }

  public Double getAggregatedObservation() {
    return getObservations().get(0);
  }
  
  public List<Double> getObservations() {
    return observations;
  }

  public void runTest() throws Exception {
    try {
      observations.clear();
      long start = System.currentTimeMillis();
      testable.execute(resources);
      long stop = System.currentTimeMillis();
      observations.add((double) (stop - start));
    } catch (Exception e) {
      observations.add(-1.0D); // a negative number indicates failure
      e.printStackTrace();
    }
  }
  
  @Override
  public Object clone() {
    try {
      return new SingleTest(this.testableClass, this.resources);
    } catch (Exception e) {
      log.error("Cloning object of class: " + this.getClass() + 
        " failed", e);
      return null;
    }
  }
}

This is the only ITest implementation that has access to the ITestable. More complex ITest implementations wrap a SingleTest. Instantiating a SingleTest is simple. The example below shows it being instantiated with an ITestable implementation called MockTestable.

    // instantiate and populate a Map<String,Object> in
    // your @Before annotated method
    resources.put("text", "Some random text");
    ...
    // instantiate a SingleTest in your @Test annotated method
    // and run it
    SingleTest test = new SingleTest(MockTestable.class, resources);
    test.runTest();
    // return the aggregated observation
    double elapsed = test.getAggregatedObservation();

AggregationPolicy

The next two implementations are really decorators for the SingleTest, which can be used to run the underlying test in serial or in parallel. Now that we will have multiple elapsed time observations, we need to be able to control what we will do with these multiple observations. The default is to expose the average of these observations using the getAggregatedObservation() method. However, this is tunable, using the AggregationPolicy argument in the constructor. The AggregationPolicy is a simple enum as shown below:

// Source: src/main/java/com/mycompany/jab/AggregationPolicy.java
package com.mycompany.jab;

/**
 * Enumerates the possible aggregation policies for observations returned
 * from RepeatedTest and ConcurrentTest (and other combo tests in the 
 * future).
 */
public enum AggregationPolicy {

  SUM, AVERAGE, MAX, MIN, VARIANCE, STDDEV, COUNT, FAILED, SUCCEEDED;
  
}

Most of the values are self-explanatory, corresponding to common statistical measures. FAILED and SUCCEEDED signal that the number of failed and successful runs, respectively, should be counted.

Aggregator

The Aggregator provides utility methods to actually do the aggregation that is requested using the AggregationPolicy. We rely on the StatUtils class in commons-math to do the heavy lifting.

// Source: src/main/java/com/mycompany/jab/Aggregator.java
package com.mycompany.jab;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.math.stat.StatUtils;

/**
 * Aggregates a List of Double observations into a single Double value
 * based on the specified aggregation policy.
 */
public class Aggregator {

  private double[] failures;
  private double[] successes;

  public Aggregator(List<Double> observations) {
    List<Double> sobs = new ArrayList<Double>();
    List<Double> fobs = new ArrayList<Double>();
    for (Iterator<Double> sit = observations.iterator(); 
        sit.hasNext();) {
      Double obs = sit.next();
      if (obs < 0.0D) {
        fobs.add(obs);
      } else {
        sobs.add(obs);
      }
    }
    this.successes = ArrayUtils.toPrimitive(sobs.toArray(new Double[0]));
    this.failures = ArrayUtils.toPrimitive(fobs.toArray(new Double[0]));
  }

  public Double aggregate(AggregationPolicy policy) {
    switch(policy) {
    case SUM:
      return StatUtils.sum(successes);
    case MAX:
      return StatUtils.max(successes);
    case MIN:
      return StatUtils.min(successes);
    case VARIANCE:
      return StatUtils.variance(successes);
    case STDDEV:
      return Math.sqrt(StatUtils.variance(successes));
    case COUNT:
      return ((double) (successes.length + failures.length));
    case FAILED:
      return ((double) failures.length);
    case SUCCEEDED:
      return ((double) successes.length);
    case AVERAGE:
    default:
      return StatUtils.mean(successes);
    }
  }
}

RepeatedTest

A RepeatedTest decorates an ITest, usually a SingleTest. All it does is run the decorated ITest a specified number of times, collecting and aggregating the elapsed time observations. The type of aggregation is specified with an AggregationPolicy. The default AggregationPolicy is AVERAGE, meaning that the aggregated observation is the average of the individual aggregated observations from the ITests.

// Source: src/main/java/com/mycompany/jab/RepeatedTest.java
package com.mycompany.jab;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Models a test that consists of running a test a fixed number of times
 * in series.
 */
public class RepeatedTest implements ITest {

  private final Log log = LogFactory.getLog(getClass());
  
  private ITest test;
  private int numIterations;
  private AggregationPolicy policy;
  private long delayMillis;
  
  private List<Double> observations = new ArrayList<Double>();
  
  public RepeatedTest(ITest test, int numIterations) {
    this(test, numIterations, AggregationPolicy.AVERAGE, 0L);
  }
  
  public RepeatedTest(ITest test, int numIterations, 
      AggregationPolicy policy) {
    this(test, numIterations, policy, 0L);
  }
  
  public RepeatedTest(ITest test, int numIterations, 
      AggregationPolicy policy, long delayMillis) {
    this.test = test;
    this.numIterations = numIterations;
    this.policy = policy;
    this.delayMillis = delayMillis;
  }

  public Double getAggregatedObservation() {
    Aggregator aggregator = new Aggregator(getObservations());
    return aggregator.aggregate(policy);
  }

  public List<Double> getObservations() {
    return observations;
  }

  public void runTest() throws Exception {
    ITest clone = (ITest) test.clone();
    for (int i = 0; i < numIterations; i++) {
      clone.runTest();
      observations.add(clone.getAggregatedObservation());
      if (delayMillis > 0L) {
        try { Thread.sleep(delayMillis); }
        catch (InterruptedException e) {;}
      }
    }
  }
  
  @Override
  public Object clone() {
    return new RepeatedTest(this.test, this.numIterations, this.policy, 
      this.delayMillis);
  }
}

As you can see, there are three constructors that you can use. The simplest one specifies the ITest and the number of repetitions, the second one overrides the default AggregationPolicy, and the third one additionally makes the test wait a given number of milliseconds between ITest invocations. Here are some usage examples.

    // run the SingleTest 10 times, no delay, default aggregation
    RepeatedTest test1 = new RepeatedTest(
      new SingleTest(MockTestable.class, resources), 10);

    // run the SingleTest 10 times, no delay, override aggregation
    // policy to return the sum of the 10 observations
    RepeatedTest test2 = new RepeatedTest(
      new SingleTest(MockTestable.class, resources), 10,
      AggregationPolicy.SUM);

    // run the SingleTest 10 times, with default aggregation,
    // and a 10ms delay between each invocation
    RepeatedTest test3 = new RepeatedTest(
      new SingleTest(MockTestable.class, resources), 10,
      AggregationPolicy.AVERAGE, 10L);

ConcurrentTest

A ConcurrentTest decorates an ITest and runs a specified number of these ITests concurrently. Like RepeatedTest, its default AggregationPolicy is AVERAGE, which can be overridden. Also like RepeatedTest, it allows you to specify a delay between spawning successive parallel ITest instances.

// Source: src/main/java/com/mycompany/jab/ConcurrentTest.java
package com.mycompany.jab;

import java.util.List;
import java.util.ArrayList;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Models multiple running concurrent jobs.
 */
public class ConcurrentTest implements ITest {

  private final Log log = LogFactory.getLog(getClass());
  
  private ITest test;
  private int numConcurrent;
  private AggregationPolicy policy;
  private long delayMillis;
  
  private List<Callable<ITest>> callables = null;
  
  private List<Double> observations = new ArrayList<Double>();

  public ConcurrentTest(ITest test, int numConcurrent) throws Exception {
    this(test, numConcurrent, AggregationPolicy.AVERAGE, 0L);
  }
  
  public ConcurrentTest(ITest test, int numConcurrent, 
      AggregationPolicy policy) throws Exception {
    this(test, numConcurrent, policy, 0L);
  }
  
  public ConcurrentTest(ITest test, int numConcurrent, 
      AggregationPolicy policy, long delayMillis) throws Exception {
    this.test = test;
    this.numConcurrent = numConcurrent;
    this.delayMillis = delayMillis;
    this.policy = policy;
    this.callables = 
      new ArrayList<Callable<ITest>>(numConcurrent);
    for (int i = 0; i < numConcurrent; i++) {
      final ITest clone = (ITest) this.test.clone();
      callables.add(new Callable<ITest>() {
        public ITest call() throws Exception {
          clone.runTest();
          return clone;
      }});
    }
  }

  public Double getAggregatedObservation() {
    Aggregator aggregator = new Aggregator(getObservations());
    return aggregator.aggregate(policy);
  }

  public List<Double> getObservations() {
    return observations;
  }

  public void runTest() throws Exception {
    ExecutorService executor = Executors.newFixedThreadPool(numConcurrent);
    List<Future<ITest>> tests = 
      new ArrayList<Future<ITest>>();
    for (int i = 0; i < numConcurrent; i++) {
      Future<ITest> test = executor.submit(callables.get(i));
      tests.add(test);
      if (delayMillis > 0L) {
        try { Thread.sleep(delayMillis); }
        catch (InterruptedException e) {;}
      }
    }
    for (Future<ITest> future : tests) {
      future.get();
    }
    executor.shutdown();
    for (int i = 0; i < numConcurrent; i++) {
      ITest test = tests.get(i).get();
      observations.add(test.getAggregatedObservation());
    }
  }
  
  @Override
  public Object clone() {
    try {
      return new ConcurrentTest(this.test, this.numConcurrent, this.policy,
        this.delayMillis);
    } catch (Exception e) {
      log.error("Cloning object of class: " + this.getClass() + 
        " failed", e);
      return null;
    }
  }
}

I picked up some pointers on the new Java 1.5 threading style from this blog post on recursor. As you can see, the constructors are similar to those for RepeatedTest. Here are some usage examples:

    // run SingleTest in parallel with 10 threads, no delay
    // between thread spawning
    ConcurrentTest test1 = new ConcurrentTest(
      new SingleTest(MockTestable.class, resources), 10);

    // run SingleTest in parallel with 10 threads, override the
    // AggregationPolicy to SUM, no delay between thread spawning
    ConcurrentTest test2 = new ConcurrentTest(
      new SingleTest(MockTestable.class, resources), 10,
      AggregationPolicy.SUM);

    // run SingleTest in parallel with 10 threads, default 
    // AggregationPolicy, with delay of 10ms between thread spawning
    ConcurrentTest test3 = new ConcurrentTest(
      new SingleTest(MockTestable.class, resources), 10,
      AggregationPolicy.AVERAGE, 10L);

TimedTest

A TimedTest is passed an ITest and a maximum allowed time. The underlying ITest is allowed to run to completion; if its aggregated observation exceeds the maximum allowed time, the run is recorded as a failure.

// Source: src/main/java/com/mycompany/jab/TimedTest.java
package com.mycompany.jab;

import java.util.ArrayList;
import java.util.List;

/**
 * Models a test which has an upper time limit. If the test runs beyond
 * that period, it is counted as a failure.
 */
public class TimedTest implements ITest {

  private ITest test;
  private long maxElapsedMillis;
  private AggregationPolicy policy;
  private List<Double> observations = new ArrayList<Double>();
  
  public TimedTest(ITest test, long maxElapsedMillis) {
    this(test, maxElapsedMillis, AggregationPolicy.AVERAGE);
  }

  public TimedTest(ITest test, long maxElapsedMillis, 
      AggregationPolicy policy) {
    this.test = test;
    this.maxElapsedMillis = maxElapsedMillis;
    this.policy = policy;
  }
  
  public Double getAggregatedObservation() {
    Aggregator aggregator = new Aggregator(observations);
    return aggregator.aggregate(policy);
  }

  public List<Double> getObservations() {
    return observations;
  }

  public void runTest() throws Exception {
    test.runTest();
    Double elapsed = test.getAggregatedObservation();
    if (elapsed > maxElapsedMillis) {
      // over the limit: record the run as a failure
      observations.add(-1.0D);
    } else {
      observations.add(elapsed);
    }
  }
  
  @Override
  public Object clone() {
    return new TimedTest(this.test, this.maxElapsedMillis, this.policy);
  }
}

Calling patterns are similar to the RepeatedTest and ConcurrentTest decorators. Here are some examples:

    // Construct a timed test, setting the maximum allowed time to
    // 10ms, and count the number of failures.
    TimedTest test1 = new TimedTest(
      new SingleTest(MockTestable.class, resources), 10L,
      AggregationPolicy.FAILED);

    // Construct a timed test, setting the maximum allowed time to
    // 2000ms (2s).
    TimedTest test2 = new TimedTest(
      new SingleTest(MockTestable.class, resources), 2000L);

ThroughputTest

This test measures throughput, i.e., the number of times the test ran within the maximum allowed time period. This is useful when you want to stress test a component for a given time period, say 10 minutes, and see how many times it ran (use AggregationPolicy.COUNT to report the run count rather than the default average elapsed time).

// Source: src/main/java/com/mycompany/jab/ThroughputTest.java
package com.mycompany.jab;

import java.util.ArrayList;
import java.util.List;

/**
 * Given a test and a maximum time to run, returns the number of times
 * the test was run in the time provided.
 */
public class ThroughputTest implements ITest {

  private ITest test;
  private long maxElapsedMillis;
  private List<Double> observations = new ArrayList<Double>();
  private AggregationPolicy policy;

  public ThroughputTest(ITest test, long maxElapsedMillis) {
    this(test, maxElapsedMillis, AggregationPolicy.AVERAGE);
  }

  public ThroughputTest(ITest test, long maxElapsedMillis, 
      AggregationPolicy policy) {
    this.test = test;
    this.maxElapsedMillis = maxElapsedMillis;
    this.policy = policy;
  }

  public Double getAggregatedObservation() {
    Aggregator aggregator = new Aggregator(this.observations);
    return aggregator.aggregate(policy);
  }

  public List<Double> getObservations() {
    return observations;
  }

  public void runTest() throws Exception {
    long totalElapsed = 0L;
    for (;;) {
      long start = System.currentTimeMillis();
      this.test.runTest();
      long end = System.currentTimeMillis();
      long elapsed = end - start;
      observations.add((double) elapsed);
      totalElapsed += elapsed;
      if (totalElapsed > maxElapsedMillis) {
        break;
      }
    }
  }

  @Override
  public Object clone() {
    return new ThroughputTest(this.test, this.maxElapsedMillis, this.policy);
  }
}

And here is an example of how to call this. As you can see, you can nest decorators fairly deeply, although it is left to you to determine what kind of nesting makes sense.

    // Declare a test that runs for 15s, which consists of 5 parallel
    // invocations of a set of 5 serial invocations of the SingleTest
    ThroughputTest test = new ThroughputTest(
      new ConcurrentTest(new RepeatedTest(
      new SingleTest(MockTestable.class, resources), 5),
      5), 15000L);

A real-life example

I tested the code above with a MockTestable that slept for 10s to simulate some kind of load. But the whole reason I built this was so I could do this sort of thing on real-life components. Here is a JUnit test that runs searches against two components and compares their performance under load. The searchers are modeled as ITestable implementations.

// Source: src/test/java/com/mycompany/jab/example/MySQLSearchTestable.java
package com.mycompany.jab.example;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Queue;

import javax.sql.DataSource;

import com.mycompany.jab.ITestable;

public class MySQLSearchTestable implements ITestable {

  public void execute(Map<String,Object> resources) throws Exception {
    // get references to various resources
    Queue<String> mysqlQueue = 
      (Queue<String>) resources.get("mysqlQueue");
    DataSource dataSource = (DataSource) resources.get("dataSource");
    String imuidQuery = (String) resources.get("sqlQuery");
    Integer preparedStmtFetchSize = 
      (Integer) resources.get("preparedStmtFetchSize");
    String randomImuid = mysqlQueue.poll();
    // do the work
    List<Result> results = new ArrayList<Result>();
    Connection conn = dataSource.getConnection();
    PreparedStatement ps = conn.prepareStatement(imuidQuery);
    ps.setFetchSize(preparedStmtFetchSize);
    ps.setString(1, randomImuid);
    ResultSet rs = null;
    try {
      rs = ps.executeQuery();
      while (rs.next()) {
        // populate a Result object
        Result result = new Result();
        // result.field = rs.getString(n) type calls 
        // deliberately removed
        // ...
        results.add(result);
      }
    } finally {
      if (rs != null) {
        try { rs.close(); } catch (Exception e) { /* ignore */ }
      }
      if (ps != null) {
        try { ps.close(); } catch (Exception e) { /* ignore */ }
      }
      if (conn != null) {
        try { conn.close(); } catch (Exception e) { /* ignore */ }
      }
    }
  }
}
// Source: src/test/java/com/mycompany/jab/example/LuceneSearchTestable.java
package com.mycompany.jab.example;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Queue;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.Hits;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TermQuery;

import com.mycompany.jab.ITestable;

public class LuceneSearchTestable implements ITestable {

  public void execute(Map<String,Object> resources) throws Exception {
    // get references to various resources
    Queue<String> luceneQueue = 
      (Queue<String>) resources.get("luceneQueue");
    IndexSearcher searcher = (IndexSearcher) resources.get("searcher");
    // start the test
    List<Result> results = new ArrayList<Result>();
    String id = luceneQueue.poll();
    Hits hits = searcher.search(new TermQuery(new Term("myId", id)));
    int numHits = hits.length();
    for (int i = 0; i < numHits; i++) {
      Result result = new Result();
      // result.field = doc.get("fieldName") type calls 
      // deliberately removed
      // ...
      results.add(result);
    }
  }
}

As you can see, these two testables are just simple pieces of code that run a SQL query against a database table and a TermQuery against a Lucene index, respectively. All the expensive resources (and some inexpensive ones) are passed to the ITestable via the resources map. The resources are created in the calling JUnit test, which also uses the jab mini-framework to build progressively larger ConcurrentTests by varying the number of users. Each ConcurrentTest runs one RepeatedTest per simulated user, and each RepeatedTest invokes one of the two ITestables shown above 10 times. The observations from each run are aggregated and written out to a flat file in tab-delimited format. Here is the code for the JUnit test.

// Source: src/test/java/com/mycompany/jab/example/JabExampleTest.java
package com.mycompany.jab.example;

import java.io.FileOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Queue;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;

import javax.sql.DataSource;

import org.apache.commons.dbcp.ConnectionFactory;
import org.apache.commons.dbcp.DriverManagerConnectionFactory;
import org.apache.commons.dbcp.PoolableConnectionFactory;
import org.apache.commons.dbcp.PoolingDataSource;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.pool.ObjectPool;
import org.apache.commons.pool.impl.GenericObjectPool;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import com.mycompany.jab.AggregationPolicy;
import com.mycompany.jab.Aggregator;
import com.mycompany.jab.ConcurrentTest;
import com.mycompany.jab.RepeatedTest;
import com.mycompany.jab.SingleTest;

/**
 * Harness to compare Lucene and MySQL cp index performance.
 */
public class JabExampleTest {

  // ======== Configuration Parameters =========
  
  private static final String INDEX_PATH = "/path/to/index";
  private static final String DATA_FILE = "/tmp/output.dat";
  private static final String DB_URL = "jdbc:mysql://localhost:3306/test";
  private static final String DB_USER = "root";
  private static final String DB_PASS = "secret";
  private static final int DB_POOL_INITIAL_SIZE = 1;
  private static final int DB_POOL_MAX_ACTIVE = 50;
  private static final int DB_POOL_MAX_WAIT = 5000;

  private static final int NUM_SEARCHES_PER_USER = 10;
  private static final int[] NUM_CONCURRENT_USERS = 
    new int[] {5,10,15,20,25,30,35,40,45,50,55,60,65,70,75,80,85,90,95,100};
  
  // ========= global vars for internal use ============
  
  private final Log log = LogFactory.getLog(getClass());
  
  private static final DecimalFormat DF = new DecimalFormat("####");
  
  private static IndexSearcher searcher;
  private static DataSource dataSource;
  private static List<String> uniqueIds;
  private static Random randomizer;
  private static PrintWriter outputWriter;

  private static final String MYSQL_QUERY = "select * from foo where ...";
  private static final int PREPARED_STATEMENT_FETCH_SIZE = 200;

  @BeforeClass
  public static void setUpBeforeTest() throws Exception {
    uniqueIds = getUniqueIds(INDEX_PATH);
    randomizer = new Random();
    searcher = new IndexSearcher(INDEX_PATH);
    dataSource = getPoolingDataSource();
    outputWriter = new PrintWriter(new OutputStreamWriter(
      new FileOutputStream(DATA_FILE)));
  }

  @AfterClass
  public static void tearDownAfterClass() throws Exception {
    searcher.close();
    outputWriter.flush();
    outputWriter.close();
  }

  @Test
  public void testCompareSearches() throws Exception {
    // set up reporting
    outputWriter.println(StringUtils.join(new String[] {
      "NUM-USERS",
      "LUCENE-AVG",
      "LUCENE-MAX",
      "LUCENE-MIN",
      "LUCENE-FAIL",
      "MYSQL-AVG",
      "MYSQL-MAX",
      "MYSQL-MIN",
      "MYSQL-FAIL"
    }, "\t"));
    // set up resources
    Map<String,Object> resources = new HashMap<String,Object>();
    resources.put("searcher", searcher);
    resources.put("dataSource", dataSource);
    resources.put("sqlQuery", MYSQL_QUERY);
    resources.put("preparedStmtFetchSize", PREPARED_STATEMENT_FETCH_SIZE);
    for (int numConcurrent : NUM_CONCURRENT_USERS) {
      // compute the random ids
      List<String> randomIds = 
        getRandomIds(numConcurrent * NUM_SEARCHES_PER_USER);
      Queue<String> luceneQueue = 
        new ConcurrentLinkedQueue<String>();
      luceneQueue.addAll(randomIds);
      Queue<String> mysqlQueue = 
        new ConcurrentLinkedQueue<String>();
      mysqlQueue.addAll(randomIds);
      resources.put("luceneQueue", luceneQueue);
      resources.put("mysqlQueue", mysqlQueue);
      // set up the tests
      log.debug("Running test with " + numConcurrent + " users...");
      ConcurrentTest luceneTest = new ConcurrentTest(
        new RepeatedTest(new SingleTest(
        LuceneSearchTestable.class, resources), 
        NUM_SEARCHES_PER_USER), numConcurrent);
      ConcurrentTest mysqlTest = new ConcurrentTest(
        new RepeatedTest(new SingleTest(
        MySQLSearchTestable.class, resources),
        NUM_SEARCHES_PER_USER), numConcurrent);
      // run them
      luceneTest.runTest();
      mysqlTest.runTest();
      // collect information and output to report
      Aggregator luceneAggregator = 
        new Aggregator(luceneTest.getObservations());
      Aggregator mysqlAggregator = 
        new Aggregator(mysqlTest.getObservations());
      outputWriter.println(StringUtils.join(new String[] {
        String.valueOf(numConcurrent),
        DF.format(luceneAggregator.aggregate(AggregationPolicy.AVERAGE)),
        DF.format(luceneAggregator.aggregate(AggregationPolicy.MAX)),
        DF.format(luceneAggregator.aggregate(AggregationPolicy.MIN)),
        DF.format(luceneAggregator.aggregate(AggregationPolicy.FAILED)),
        DF.format(mysqlAggregator.aggregate(AggregationPolicy.AVERAGE)),
        DF.format(mysqlAggregator.aggregate(AggregationPolicy.MAX)),
        DF.format(mysqlAggregator.aggregate(AggregationPolicy.MIN)),
        DF.format(mysqlAggregator.aggregate(AggregationPolicy.FAILED))
      }, "\t"));
    }
  }
  
  // ========= Methods to build and populate resources as applicable ========
  
  private static DataSource getPoolingDataSource() throws Exception {
    ObjectPool connectionPool = new GenericObjectPool(null);
    Properties connProps = new Properties();
    connProps.put("user", DB_USER);
    connProps.put("password", DB_PASS);
    connProps.put("initialSize", String.valueOf(DB_POOL_INITIAL_SIZE));
    connProps.put("maxActive", String.valueOf(DB_POOL_MAX_ACTIVE));
    connProps.put("maxWait", String.valueOf(DB_POOL_MAX_WAIT));
    Class.forName("com.mysql.jdbc.Driver");
    ConnectionFactory connectionFactory = 
      new DriverManagerConnectionFactory(DB_URL, connProps);
    PoolableConnectionFactory pcf = new PoolableConnectionFactory(
      connectionFactory, connectionPool, null, null, false, false);
    return new PoolingDataSource(connectionPool);
  }

  private static List<String> getUniqueIds(String cpIndexPath) 
      throws Exception {
    Set<String> uniqueImuidSet = new HashSet<String>();
    IndexReader reader = IndexReader.open(cpIndexPath);
    int numDocs = reader.maxDoc();
    for (int i = 0; i < numDocs; i++) {
      Document doc = reader.document(i);
      uniqueImuidSet.add(doc.get("myId"));
    }
    List<String> idlist = new ArrayList<String>();
    idlist.addAll(uniqueImuidSet);
    reader.close();
    return idlist;
  }

  private List<String> getRandomIds(int numRandom) {
    List<String> randomImuids = new ArrayList<String>();
    for (int i = 0; i < numRandom; i++) {
      int random = randomizer.nextInt(uniqueIds.size());
      randomImuids.add(uniqueIds.get(random));
    }
    return randomImuids;
  }
}

Test Results

Although the results of this exercise are not really relevant to this post (since I am just describing the framework and how to use it), I thought they would be interesting, so I am including them here. All times are in milliseconds.

NUM-USERS  LUCENE-AVG  LUCENE-MAX  LUCENE-MIN  LUCENE-FAIL  MYSQL-AVG  MYSQL-MAX  MYSQL-MIN  MYSQL-FAIL
        5          39          42          36            0         49         54         43           0
       10          11          19           4            0         14         23          9           0
       15          12          19           3            0         14         21          6           0
       20          15          34           1            0         20         28          6           0
       25          19          35           6            0         23         38         12           0
       30          25          44           4            0         26         37          9           0
       35          25          44           5            0         27         41          8           0
       40          45          81          16            0         44         62          5           0
       45          43          69           9            0         45         64         10           0
       50          33          73           0            0         42         66         11           0
       55          28          67           2            0         33         59          6           0
       60          37          72           3            0         41         69          2           0
       65          43          84           3            0         50         81          4           0
       70          33          63           7            0         38         70          3           0
       75          53          95          17            0         53         84          8           0
       80         105         232           0            0         64        113          5           0
       85          52         102           6            0         62        100          4           0
       90          71         146           4            0         71        110         13           0
       95          56          99          12            0         65        106          2           0
      100          64         128          11            0         73        121          6           0

To visualize the data, I used the following gnuplot script to transform the average time observations into a graph.

set multiplot
set key off
set xlabel '#-users'
set ylabel 'response(ms)'
set yrange [0:150]
plot 'perfcomp.dat' using 1:2 with lines lt 1
plot 'perfcomp.dat' using 1:6 with lines lt 2

The graph is shown below. There are not too many surprises here: quite a few people have reached the same conclusion, namely that it is often just as performant, and frequently more convenient, to serve the results of exact queries from a MySQL database as from a Lucene index.

Conclusion

Prior to this, I would either wrap a component in a web interface and run ab or siege against it, or write JUnit tests that did the multithreading inline with the code being tested. I think this approach is cleaner and perhaps more scalable, since it separates the component being tested from the actual test parameters, allowing you to model more complex scenarios.

I am curious as to what other people do in similar situations. If you have had similar needs, I would appreciate hearing how you approached them. I am also curious whether people think this is complete enough to release as a project. I don't really want the headache of maintaining and improving it; I just figure it may be useful to host it somewhere people can download it, use it, and maybe improve it and contribute fixes and features back.

Also, I don't normally write multi-threaded code, simply because it's not needed that often for the stuff I work on, so there may be obvious bugs here that a reader who does multi-threaded work for a living (and some who don't) may spot immediately. If so, please let me know and I will make the necessary corrections.