Wednesday, November 7, 2012

Belief Propagation for music recommendations with MapReduce and Giraph




I am taking the Probabilistic Graphical Models course on Coursera. One class covers inference on Markov random fields using label propagation, and I wanted to see how this algorithm works on real data, so I applied label propagation to the Million Song Dataset to get a feel for it.

Here is an example of label propagation.

Simple arithmetic gives the following result:

50% from my profile + 40% from X + 10% from Z = 60% rock, 40% jazz.
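The arithmetic above is a weighted sum of genre distributions. The per-user genre mixes from the original example are not in the text, so the sketch below uses hypothetical distributions (my profile 80% rock / 20% jazz, X 50/50, Z 100% jazz), chosen only so that the 50/40/10 weighting reproduces the 60% rock, 40% jazz result. GenreMixExample and mix are illustrative names, not part of the author's code.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class GenreMixExample {

    /** Weighted sum of distributions: result[g] = sum over k of weights[k] * dists.get(k)[g]. */
    public static Map<String, Double> mix(double[] weights, List<Map<String, Double>> dists) {
        Map<String, Double> result = new LinkedHashMap<>();
        for (int k = 0; k < weights.length; k++) {
            for (Map.Entry<String, Double> e : dists.get(k).entrySet()) {
                // Add this source's weighted share of genre e.getKey().
                result.merge(e.getKey(), weights[k] * e.getValue(), Double::sum);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // Hypothetical distributions, NOT taken from the original example.
        Map<String, Double> myProfile = Map.of("rock", 0.8, "jazz", 0.2);
        Map<String, Double> userX = Map.of("rock", 0.5, "jazz", 0.5);
        Map<String, Double> userZ = Map.of("rock", 0.0, "jazz", 1.0);
        Map<String, Double> result =
            mix(new double[] {0.5, 0.4, 0.1}, List.of(myProfile, userX, userZ));
        System.out.printf("rock=%.2f jazz=%.2f%n", result.get("rock"), result.get("jazz"));
        // prints rock=0.60 jazz=0.40
    }
}
```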


Generalization


Here is the actual formulation, where the Markov random field is represented as a factor graph.



The above formulation can be implemented as repeated matrix multiplication using Map/Reduce.

Definition


Let G be the graph structure and C = {c1, c2, ...} a set of concepts, where each ci consists of a set of vertices in G. ci represents the prior for concept i. We then want to compute the pairwise (concept, vertex) posterior given G and C.
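A minimal sketch of turning the concept sets into the initial |C| x |V| matrix CV, assuming vertices are numbered 0..|V|-1 and assuming a uniform prior of 1/|ci| on each seed vertex of ci (the post does not say which prior values it uses). InitialConceptMatrix and build are hypothetical names.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Set;

class InitialConceptMatrix {

    /**
     * Builds the |C| x |V| concept-vertex matrix CV. Row c holds the prior of concept c:
     * ASSUMED uniform, 1/|c| on each seed vertex of c and 0 on every other vertex.
     */
    public static double[][] build(List<Set<Integer>> concepts, int numVertices) {
        double[][] cv = new double[concepts.size()][numVertices];
        for (int c = 0; c < concepts.size(); c++) {
            Set<Integer> seeds = concepts.get(c);
            for (int v : seeds) {
                cv[c][v] = 1.0 / seeds.size();
            }
        }
        return cv;
    }

    public static void main(String[] args) {
        // Two hypothetical concepts over 5 vertices: c0 = {0, 1}, c1 = {4}.
        double[][] cv = build(List.of(Set.of(0, 1), Set.of(4)), 5);
        System.out.println(Arrays.deepToString(cv));
        // prints [[0.5, 0.5, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]]
    }
}
```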

Implementation


iteration 0: CV(|C| x |V|) x G(|V| x |V|) = CV'(|C| x |V|)
iteration 1: CV'(|C| x |V|) x G(|V| x |V|) = CV''(|C| x |V|)
..
..

Each iteration is simply a matrix multiplication between the concept-vertex matrix CV and the static graph structure G.
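Written out element by element, one iteration updates each (concept, vertex) entry from that vertex's neighbors. This assumes G is built from a symmetric edge-weight matrix w and normalized so that each column sums to 1 (the post only says G is normalized, so this choice is an assumption); it matches the Giraph rule of dividing edge(Vi, Vj) by the sum of Vi's edge weights.

```latex
CV^{(t+1)} = CV^{(t)} G,
\qquad
CV^{(t+1)}_{c,i} = \sum_{j} CV^{(t)}_{c,j}\, G_{j,i},
\qquad
G_{j,i} = \frac{w_{ij}}{\sum_{k} w_{ik}}
```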

This operation quickly becomes huge, so I first implemented it with the DistributedRowMatrix class in Mahout.

The DistributedRowMatrix class provides the following APIs:
1. transpose: this.transpose()
2. times(DistributedRowMatrix other): computes this.transpose().times(other)
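Because times(other) multiplies by the transpose of this, getting CV x G requires transposing CV first: (CV^T)^T x G = CV x G. The sketch below reproduces that convention with small dense in-memory arrays rather than Mahout; transposeTimes and transpose are hypothetical helpers, not Mahout methods.

```java
import java.util.Arrays;

class TransposeTimesExample {

    /** Returns A^T * B, mirroring the this.transpose().times(other) convention; A and B need the same row count. */
    public static double[][] transposeTimes(double[][] a, double[][] b) {
        int n = a.length, p = a[0].length, q = b[0].length;
        double[][] out = new double[p][q];
        for (int k = 0; k < n; k++) {          // shared row index of A and B
            for (int i = 0; i < p; i++) {
                for (int j = 0; j < q; j++) {
                    out[i][j] += a[k][i] * b[k][j];
                }
            }
        }
        return out;
    }

    /** Plain matrix transpose. */
    public static double[][] transpose(double[][] a) {
        double[][] t = new double[a[0].length][a.length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < a[0].length; j++) {
                t[j][i] = a[i][j];
            }
        }
        return t;
    }

    public static void main(String[] args) {
        double[][] cv = {{1, 0, 0}};                      // |C| = 1, |V| = 3
        double[][] g = {{0, 1, 0}, {0, 0, 1}, {1, 0, 0}}; // |V| x |V|
        // CV.transpose().times(G) = (CV^T)^T * G = CV * G
        double[][] next = transposeTimes(transpose(cv), g);
        System.out.println(Arrays.deepToString(next));
        // prints [[0.0, 1.0, 0.0]]
    }
}
```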

Using the above APIs, label propagation in Map/Reduce becomes the following:

1. Initialize the concept-vertex matrix CV.
2. Normalize the graph G for convenience if it is not already normalized.
3. Create CV and G as DistributedRowMatrix instances.
4. For each iteration, compute CV = CV x G.
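The four steps can be sketched in memory with dense arrays in place of DistributedRowMatrix. The normalization in step 2 is an assumption: each column i of a symmetric weight matrix w is divided by vertex i's total edge weight, so one multiplication replaces each vertex's concept scores with the weighted average of its neighbors' scores. LabelPropagationLoop and its methods are hypothetical, and the initial CV is passed in directly.

```java
import java.util.Arrays;

class LabelPropagationLoop {

    /** Step 2 (assumed normalization): divide column i of symmetric w by vertex i's edge-weight sum. */
    public static double[][] normalizeColumns(double[][] w) {
        int n = w.length;
        double[][] g = new double[n][n];
        for (int i = 0; i < n; i++) {
            double sum = 0;
            for (int j = 0; j < n; j++) {
                sum += w[j][i];
            }
            if (sum == 0) {
                continue; // isolated vertex: leave its column as zeros
            }
            for (int j = 0; j < n; j++) {
                g[j][i] = w[j][i] / sum;
            }
        }
        return g;
    }

    /** One iteration: CV'(|C| x |V|) = CV(|C| x |V|) x G(|V| x |V|), skipping zero entries of CV. */
    public static double[][] multiply(double[][] cv, double[][] g) {
        double[][] out = new double[cv.length][g[0].length];
        for (int c = 0; c < cv.length; c++) {
            for (int j = 0; j < g.length; j++) {
                if (cv[c][j] == 0) {
                    continue;
                }
                for (int i = 0; i < g[0].length; i++) {
                    out[c][i] += cv[c][j] * g[j][i];
                }
            }
        }
        return out;
    }

    /** Step 4: for each iteration, CV = CV x G. The caller's CV array is not modified. */
    public static double[][] propagate(double[][] cv, double[][] w, int iterations) {
        double[][] g = normalizeColumns(w);
        for (int t = 0; t < iterations; t++) {
            cv = multiply(cv, g);
        }
        return cv;
    }

    public static void main(String[] args) {
        // Hypothetical path graph 0 - 1 - 2 with unit weights; one concept seeded on vertex 0.
        double[][] w = {{0, 1, 0}, {1, 0, 1}, {0, 1, 0}};
        double[][] cv = {{1, 0, 0}};
        System.out.println(Arrays.deepToString(propagate(cv, w, 1))); // prints [[0.0, 0.5, 0.0]]
        System.out.println(Arrays.deepToString(propagate(cv, w, 2))); // prints [[0.5, 0.0, 0.5]]
    }
}
```

After one iteration, vertex 1 holds half of the concept because it averages its two neighbors (0 and 2); after two, the concept has spread to both ends of the path.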

The following code demonstrates how handy the DistributedRowMatrix class in the Mahout library is.

    DistributedRowMatrix CV =
        createInitialCV(numItems, getTempPath("initial.class"), getConf());
    for (int i = startIteration; i < iterations; i++) {
      log.info("current iteration: {}", i);
      /*
       * GNorm is a DistributedRowMatrix containing the normalized
       * vertex-vertex graph structure. DistributedRowMatrix.times(other)
       * computes this.transpose().times(other), so transpose CV first:
       * CV.transpose().times(GNorm) = CV x GNorm.
       */
      CV = CV.transpose().times(GNorm);
    }


Note that Map/Reduce is inefficient for iterative jobs, since every iteration runs as a separate job that rereads its input, so why not try Giraph?

In a graph-parallel environment, the problem becomes the following:

1. Each vertex holds its neighbor edges in G.
2. At superstep 0, some vertices have a C vector as their value ([ci:prior, cj:prior, ...]). A vertex that has a C vector sends it to all of its neighbors; otherwise it sends nothing.
3. After superstep 0, every vertex receives messages ([vertex_id j, C vector]) from each of its neighbors.
If the current vertex is Vi and a message is [Vj, Cvj], then edge(Vi, Vj) / (sum of all of Vi's edge weights) * Cvj is added to Vi's value vector Cvi. All concept-prior vectors sent to a vertex are merged this way to update its value Cvi.
4. If the iterations are not done, send value(Cvi) to all neighbors.
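To see these superstep rules without a Giraph cluster, here is a single-machine simulation in plain Java. It is not the Giraph API: VertexCentricSimulation and run are hypothetical names, edge weights live in a symmetric matrix w, and messages sent in superstep s are read in superstep s + 1, as in Giraph's bulk synchronous model. Following step 3, each received vector is added to the vertex's existing value.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

class VertexCentricSimulation {

    /**
     * Simulates supersteps 0 .. supersteps-1 of the rules above (plain Java, not Giraph).
     * w is a symmetric edge-weight matrix (0 = no edge); prior[v] is vertex v's initial C vector.
     */
    public static double[][] run(double[][] w, double[][] prior, int supersteps) {
        int n = w.length;
        double[][] value = new double[n][];
        for (int v = 0; v < n; v++) {
            value[v] = prior[v].clone();
        }

        // Step 2, superstep 0: only vertices that hold a non-zero C vector send it.
        Map<Integer, double[]> sent = new HashMap<>();
        for (int v = 0; v < n; v++) {
            if (hasMass(value[v])) {
                sent.put(v, value[v].clone());
            }
        }

        for (int step = 1; step < supersteps; step++) {
            for (int i = 0; i < n; i++) {
                double edgeSum = 0;
                for (int j = 0; j < n; j++) {
                    edgeSum += w[i][j];
                }
                // Step 3: Cvi += edge(Vi, Vj) / (sum of Vi's edges) * Cvj for each message [Vj, Cvj].
                for (Map.Entry<Integer, double[]> msg : sent.entrySet()) {
                    int j = msg.getKey();
                    if (w[i][j] == 0) {
                        continue; // Vj is not a neighbor, so Vi never receives its message
                    }
                    double scale = w[i][j] / edgeSum;
                    double[] cvj = msg.getValue();
                    for (int c = 0; c < cvj.length; c++) {
                        value[i][c] += scale * cvj[c];
                    }
                }
            }
            // Step 4: every vertex sends its updated value, read in the next superstep.
            sent.clear();
            for (int v = 0; v < n; v++) {
                sent.put(v, value[v].clone());
            }
        }
        return value;
    }

    private static boolean hasMass(double[] vector) {
        for (double x : vector) {
            if (x != 0) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        // Hypothetical path graph 0 - 1 - 2, one concept seeded on vertex 0.
        double[][] w = {{0, 1, 0}, {1, 0, 1}, {0, 1, 0}};
        double[][] prior = {{1.0}, {0.0}, {0.0}};
        System.out.println(Arrays.deepToString(run(w, prior, 2)));
        // prints [[1.0], [0.5], [0.0]]
    }
}
```

Because values accumulate across supersteps, their totals grow with each iteration; the pruning threshold in the Giraph code below only limits which entries are sent, not their magnitude.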

The following code is the compute method of the vertex program that implements the steps above.


@Override
public void compute(Iterable<MultiLabelVectorWritable> messages) throws IOException {
  /*
   * Each vertex has a Vector as its value:
   * [concept_id:probability, ...].
   * A MultiLabelVectorWritable message represents
   * (vertex j that sent this message, vertex j's value Vector).
   */
  long step = getSuperstep();
  if (step < iteration) {
    // We still need to compute on this vertex.
    Vector currentVector = getValue().get();
    // Create a new message from this vertex and set its source to this vertex's id.
    MultiLabelVectorWritable newMessage = new MultiLabelVectorWritable();
    newMessage.setLabels(new int[]{(int) getId().get()});
    // The vertex value vector [concept_id:probability] is sparse.
    Vector newMessageVector = new RandomAccessSparseVector(minNonConceptVertexId);
    // Merge the messages sent to this vertex into this vertex's vector.
    for (MultiLabelVectorWritable message : messages) {
      int messageId = message.getLabels()[0];
      Vector conceptProbs = message.getVector();
      // edge(Vi, Vj) is used as-is, so edge values are assumed to be
      // already normalized by the sum of Vi's edge weights.
      float weight = getEdgeValue(new LongWritable(messageId)).get();
      Iterator<Vector.Element> probs = conceptProbs.iterateNonZero();
      while (probs.hasNext()) {
        Vector.Element prob = probs.next();
        int conceptId = prob.index();
        // Accumulate (not overwrite) so that contributions from all neighbors are summed.
        currentVector.set(conceptId, currentVector.get(conceptId) + prob.get() * weight);
      }
    }
    // Pruning: only probabilities >= gamma are propagated to the neighbors.
    Iterator<Vector.Element> iter = currentVector.iterateNonZero();
    while (iter.hasNext()) {
      Vector.Element e = iter.next();
      if (e.get() < gamma) {
        continue;
      }
      newMessageVector.setQuick(e.index(), e.get());
    }
    newMessage.setVector(newMessageVector);
    sendMessageToAllEdges(newMessage);
  } else {
    voteToHalt();
  }
}


I implemented a demo of label propagation on the open dataset from the Million Song Dataset Challenge on Kaggle. For demonstration purposes, the demo loads the taste profile graph data into memory and computes on the fly rather than using Giraph. Here is the GitHub repository for the Giraph/Mahout code and the demo code.

TODO: since the test set for this data has been released (the competition is over), I will measure truncated mean average precision to evaluate label propagation.
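A sketch of the metric for reference. The Kaggle challenge scored submissions with mAP truncated at 500 recommendations; the version below divides each user's summed precision by min(k, number of relevant songs), which is a common definition, but the challenge's evaluation page should be checked for the exact variant. TruncatedMap and its methods are hypothetical names.

```java
import java.util.List;
import java.util.Set;

class TruncatedMap {

    /**
     * AP@k for one user: sum of precision@(i+1) over ranks i < k whose item is relevant,
     * divided by min(k, |relevant|). Assumes the ranked list contains no duplicates.
     */
    public static double averagePrecisionAtK(List<String> ranked, Set<String> relevant, int k) {
        if (relevant.isEmpty()) {
            return 0.0;
        }
        int hits = 0;
        double sum = 0.0;
        for (int i = 0; i < Math.min(k, ranked.size()); i++) {
            if (relevant.contains(ranked.get(i))) {
                hits++;
                sum += (double) hits / (i + 1); // precision at rank i + 1
            }
        }
        return sum / Math.min(k, relevant.size());
    }

    /** Mean of AP@k over users; ranked.get(u) and relevant.get(u) belong to the same user u. */
    public static double meanAveragePrecisionAtK(
            List<List<String>> ranked, List<Set<String>> relevant, int k) {
        if (ranked.isEmpty()) {
            return 0.0;
        }
        double total = 0.0;
        for (int u = 0; u < ranked.size(); u++) {
            total += averagePrecisionAtK(ranked.get(u), relevant.get(u), k);
        }
        return total / ranked.size();
    }

    public static void main(String[] args) {
        // Hypothetical user whose hidden test songs are a and c.
        double ap = averagePrecisionAtK(List.of("a", "b", "c"), Set.of("a", "c"), 3);
        System.out.println(ap); // equals (1/1 + 2/3) / 2
    }
}
```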