用 Spark GraphX 的 Graph 实现了 ALS(交替最小二乘)算法,原理和实现都比较简单。
代码如下:
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.StorageLevels;
import org.apache.spark.graphx.Edge;
import org.apache.spark.graphx.EdgeContext;
import org.apache.spark.graphx.Graph;
import org.apache.spark.graphx.TripletFields;
import org.apache.spark.graphx.VertexRDD;
import org.apache.spark.rdd.RDD;
import org.apache.spark.util.random.XORShiftRandom;
import org.netlib.util.intW;
import com.github.fommil.netlib.BLAS;
import com.github.fommil.netlib.LAPACK;
import scala.Option;
import scala.Tuple2;
import scala.reflect.ClassManifestFactory;
import scala.reflect.ClassTag;
import scala.runtime.AbstractFunction1;
import scala.runtime.AbstractFunction2;
import scala.runtime.AbstractFunction3;
import scala.runtime.BoxedUnit;
import scala.util.hashing.package$;
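/**
 * Small demo that runs alternating least squares (ALS) on a Spark GraphX graph.
 * Users and items are vertices carrying a factor vector, ratings are edge
 * attributes; {@code main} trains on a hard-coded graph in local mode.
 */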
public class ALSGraphTest
{
// Seed for the generator that produces the seed of the initial user factors in main().
private static final Long SEED = 10L;
// Shared netlib-java BLAS / LAPACK entry points used by the factor arithmetic in this class.
private static final BLAS blas = BLAS.getInstance( );
private static final LAPACK lapack = LAPACK.getInstance( );
// Number of alternating rounds executed by computerFactors().
private static final int MAXITER = 10;
// Length of every factor vector (also the size of each NormalEquation).
private static final int RANK = 2;
// Regularisation weight; computer() multiplies it by the vertex degree before solving.
private static final double LAMBDA = 0.01;
// Class tags passed to the GraphX APIs for the vertex, message and edge attribute types.
private static final ClassTag<NodeInfo> nodeTag = ClassManifestFactory.classType( NodeInfo.class );
private static final ClassTag<SendValue> sendValueTag = ClassManifestFactory.classType( SendValue.class );
private static final ClassTag<Float> floatTag = ClassManifestFactory.classType( Float.class );
/**
 * Builds a small hard-coded rating graph (users 1-3, items 11-13), attaches
 * the in/out degree to every vertex, trains the factors with
 * {@code computerFactors} and prints the top two recommendations for user 1.
 *
 * @param args command-line arguments (unused)
 */
public static void main( String[] args )
{
    SparkConf conf = new SparkConf( ).setMaster( "local" ).setAppName( "ALS Graph Test" );
    JavaSparkContext ctx = new JavaSparkContext( conf );
    try
    {
        List<Long> ids = Arrays.asList( 1L, 2L, 3L );
        Random seedGen = new XORShiftRandom( SEED );
        // Only the user side starts with factors; item factors are produced by the first half-round.
        List<Tuple2<Long, float[]>> userNodes = initFactors( ids, RANK, seedGen.nextLong( ) );
        List<Tuple2<Long, float[]>> itemNodes = Arrays.asList( new Tuple2<Long, float[]>( 11L, null ), new Tuple2<Long, float[]>( 12L, null ), new Tuple2<Long, float[]>( 13L, null ) );
        List<Tuple2<Object, NodeInfo>> allNodes = createNode( userNodes, 0 );
        allNodes.addAll( createNode( itemNodes, 1 ) );
        List<Edge<Float>> edges = Arrays.asList( new Edge<Float>( 1L, 11L, 3.0f ), new Edge<Float>( 1L, 12L, 4.0f ), new Edge<Float>( 2L, 12L, 3.0f ), new Edge<Float>( 2L, 13L, 4.5f ), new Edge<Float>( 3L, 11L, 3.0f ), new Edge<Float>( 3L, 12L, 2.0f ) );
        RDD<Tuple2<Object, NodeInfo>> nodesRDD = ctx.parallelize( allNodes ).rdd( );
        RDD<Edge<Float>> edgesRDD = ctx.parallelize( edges ).rdd( );
        Graph<NodeInfo, Float> g = Graph.apply( nodesRDD, edgesRDD, new NodeInfo( -1, null ), StorageLevels.MEMORY_ONLY_SER, StorageLevels.MEMORY_ONLY_SER, nodeTag, floatTag );
        // Record the in-degree on each vertex; it scales the regularisation in computer().
        g = g.outerJoinVertices( g.ops( ).inDegrees( ), new MyFunction3<Object, NodeInfo, Option<Object>, NodeInfo>( )
        {
            @Override
            public NodeInfo apply( Object t0, NodeInfo t1, Option<Object> t2 )
            {
                if ( t2.isDefined( ) )
                {
                    t1.inDegrees = (Integer) t2.get( );
                }
                return t1;
            }
        }, ClassManifestFactory.classType( Object.class ), nodeTag, null );
        // Same for the out-degree.
        g = g.outerJoinVertices( g.ops( ).outDegrees( ), new MyFunction3<Object, NodeInfo, Option<Object>, NodeInfo>( )
        {
            @Override
            public NodeInfo apply( Object t0, NodeInfo t1, Option<Object> t2 )
            {
                if ( t2.isDefined( ) )
                {
                    t1.outDegrees = (Integer) t2.get( );
                }
                return t1;
            }
        }, ClassManifestFactory.classType( Object.class ), nodeTag, null );
        g.cache( );
        materialize( g );
        Graph<NodeInfo, Float> result = computerFactors( g );
        List<Tuple2<Integer, Float>> resultList = recommendProducts( 1, 2, result );
        // Show the outcome instead of silently discarding it.
        for ( Tuple2<Integer, Float> recommendation : resultList )
        {
            System.out.println( "item " + recommendation._1 + " score " + recommendation._2 );
        }
    }
    finally
    {
        // Always release the Spark context, even when training fails.
        ctx.stop( );
    }
    System.out.println( "Done" );
}
/**
 * Scores every item vertex (type == 1) against one user's factor vector with
 * a dot product and returns the best-scoring items.
 *
 * @param user vertex id of the user to recommend for
 * @param num maximum number of recommendations to return
 * @param g graph whose vertices carry trained factor vectors
 * @return at most {@code num} (item id, score) pairs, highest score first
 * @throws IllegalArgumentException if no vertex with a factor vector exists for {@code user}
 */
private static List<Tuple2<Integer, Float>> recommendProducts( int user,
        int num, Graph<NodeInfo, Float> g )
{
    JavaRDD<Tuple2<Object, NodeInfo>> vertices = g.vertices( ).toJavaRDD( );
    JavaPairRDD<Object, NodeInfo> verticesById = vertices.mapToPair( s -> s );
    // Vertex ids are Long values; look up with a Long key instead of an autoboxed Integer.
    List<NodeInfo> matches = verticesById.lookup( Long.valueOf( user ) );
    if ( matches.isEmpty( ) || matches.get( 0 ).factor == null )
    {
        throw new IllegalArgumentException( "No factor vector found for user " + user );
    }
    final float[] uFactors = matches.get( 0 ).factor;
    // Items that never received a factor vector cannot be scored, so they are skipped.
    JavaRDD<Tuple2<Integer, Float>> result = vertices.filter( s -> s._2.type == 1 && s._2.factor != null ).map( s -> {
        float[] dFactors = s._2.factor;
        float value = blas.sdot( dFactors.length, uFactors, 1, dFactors, 1 );
        return new Tuple2<Integer, Float>( ( (Long) s._1 ).intValue( ), value );
    } ).sortBy( s -> s._2, false, g.vertices( ).getNumPartitions( ) );
    // take() returns fewer elements when the RDD is shorter, so no separate count() job is needed.
    return result.take( num );
}
/**
 * Runs MAXITER alternating rounds. Each round first calls computer() with
 * type 0 (source factors are sent to destinations), then with type 1
 * (destination factors are sent back to sources). Every intermediate graph
 * is cached and materialised before its predecessor is unpersisted.
 *
 * @param g cached starting graph; it is unpersisted during the first round
 * @return the graph produced by the last type-1 step
 */
private static Graph<NodeInfo, Float>
computerFactors( Graph<NodeInfo, Float> g )
{
    Graph<NodeInfo, Float> current = g;
    int round = 0;
    while ( round < MAXITER )
    {
        // Half-round 1: solve the destination side.
        Graph<NodeInfo, Float> dstSolved = computer( current, 0 );
        dstSolved.cache( );
        materialize( dstSolved );
        current.unpersist( true );

        // Half-round 2: solve the source side.
        Graph<NodeInfo, Float> srcSolved = computer( dstSolved, 1 );
        srcSolved.cache( );
        materialize( srcSolved );
        dstSolved.unpersist( true );

        current = srcSolved;
        round++;
    }
    return current;
}
/**
 * Performs one half-round of ALS on the graph.
 * <p>
 * For {@code type == 0} every edge sends the source factor and the edge
 * attribute to its destination; for {@code type == 1} it sends the
 * destination factor to its source. Messages arriving at a vertex are folded
 * into one NormalEquation, which is then solved to give that vertex a new
 * factor vector. Vertices that receive no message are returned unchanged.
 *
 * @param g graph whose sending side already carries factor vectors
 * @param type 0 to update destination vertices, 1 to update source vertices
 * @return a graph in which the receiving vertices carry freshly solved factors
 */
private static Graph<NodeInfo, Float> computer( Graph<NodeInfo, Float> g,
        int type )
{
    VertexRDD<SendValue> vRDD = g.aggregateMessages( new MyFunction1<EdgeContext<NodeInfo, Float, SendValue>, BoxedUnit>( )
    {
        @Override
        public BoxedUnit apply( EdgeContext<NodeInfo, Float, SendValue> t )
        {
            if ( type == 0 )
            {
                t.sendToDst( new SendValue( t.srcAttr( ).factor, t.attr( ) ) );
            }
            else if ( type == 1 )
            {
                t.sendToSrc( new SendValue( t.dstAttr( ).factor, t.attr( ) ) );
            }
            return BoxedUnit.UNIT;
        }
    }, new MyFunction2<SendValue, SendValue, SendValue>( )
    {
        @Override
        public SendValue apply( SendValue t1, SendValue t2 )
        {
            // "first" marks a raw message whose (factor, rating) pair has not
            // yet been added to a NormalEquation accumulator.
            SendValue value = null;
            if ( t1.first && t2.first )
            {
                value = new SendValue( );
                value.ne.add( t1.tolFactors, t1.b, 1.0 );
                value.ne.add( t2.tolFactors, t2.b, 1.0 );
            }
            else if ( t1.first && !t2.first )
            {
                value = t2;
                value.ne.add( t1.tolFactors, t1.b, 1.0 );
            }
            else if ( !t1.first && t2.first )
            {
                value = t1;
                value.ne.add( t2.tolFactors, t2.b, 1.0 );
            }
            else
            {
                value = t1;
                t1.ne.merge( t2.ne );
            }
            value.first = false;
            return value;
        }
    }, TripletFields.All, sendValueTag );
    // The debug collect() of vRDD was dropped: it pulled every message to the
    // driver and made Spark evaluate the uncached aggregation a second time.
    return g.outerJoinVertices( vRDD, new MyFunction3<Object, NodeInfo, Option<SendValue>, NodeInfo>( )
    {
        @Override
        public NodeInfo apply( Object t0, NodeInfo t1,
                Option<SendValue> t2 )
        {
            if ( !t2.isDefined( ) )
            {
                return t1;
            }
            SendValue value = t2.get( );
            if ( value.first )
            {
                // Single message: the merge function never ran, so accumulate it here.
                value.ne.add( value.tolFactors, value.b, 1.0 );
            }
            double lambda = type == 0 ? LAMBDA * t1.inDegrees
                    : LAMBDA * t1.outDegrees;
            // Build a new attribute instead of mutating the one owned by the input graph.
            NodeInfo updated = new NodeInfo( t1.type, choleskySolver( value.ne, lambda ) );
            updated.inDegrees = t1.inDegrees;
            updated.outDegrees = t1.outDegrees;
            return updated;
        }
    }, sendValueTag, nodeTag, null );
}
/**
 * Wraps (id, factor) pairs into graph vertices of the given type.
 *
 * @param list vertex ids paired with their initial factor vector (may be null)
 * @param type type tag stored in every created NodeInfo
 * @return a new mutable list with one (id, NodeInfo) entry per input pair, in input order
 */
private static List<Tuple2<Object, NodeInfo>>
createNode( List<Tuple2<Long, float[]>> list, int type )
{
    final List<Tuple2<Object, NodeInfo>> nodes = new ArrayList<>( );
    list.forEach( entry -> {
        NodeInfo info = new NodeInfo( type, entry._2 );
        nodes.add( new Tuple2<Object, NodeInfo>( entry._1, info ) );
    } );
    return nodes;
}
/**
 * Creates one random factor vector per id: Gaussian components scaled so the
 * vector has unit Euclidean norm.
 *
 * @param list ids to create factors for
 * @param rank length of each factor vector
 * @param seed seed; it is byte-swapped before seeding the generator
 * @return (id, factor) pairs in the order of {@code list}
 */
private static List<Tuple2<Long, float[]>> initFactors( List<Long> list,
        int rank, long seed )
{
    Random gaussianSource = new XORShiftRandom( package$.MODULE$.byteswap64( seed ) );
    List<Tuple2<Long, float[]>> factors = new ArrayList<>( );
    for ( Long id : list )
    {
        float[] vector = new float[rank];
        int d = 0;
        while ( d < rank )
        {
            vector[d] = (float) gaussianSource.nextGaussian( );
            d++;
        }
        // Normalise to length 1.
        float norm = blas.snrm2( rank, vector, 1 );
        blas.sscal( rank, 1.0f / norm, vector, 1 );
        factors.add( new Tuple2<Long, float[]>( id, vector ) );
    }
    return factors;
}
/**
 * Forces evaluation of a graph by counting its vertices and its edges.
 *
 * @param g graph to evaluate; attribute types are irrelevant here
 */
private static void materialize( Graph<?, ?> g )
{
    g.vertices( ).count( );
    g.edges( ).count( );
}
/**
 * Message exchanged along an edge in computer(). A freshly sent message
 * carries the sender's factor vector and the edge attribute; once messages
 * are combined, the accumulated sums live in {@code ne}.
 */
private static class SendValue implements Serializable
{
    // Factor vector of the sending vertex (null for an accumulator created by the merge step).
    float[] tolFactors;
    // Edge attribute sent with the factor; -1 when created without one.
    float b;
    // Accumulator for the contributions merged so far.
    NormalEquation ne = new NormalEquation( RANK );
    // True while this message has not yet been added to a NormalEquation.
    boolean first = true;

    public SendValue( )
    {
        this( null, -1 );
    }

    public SendValue( float[] tolFactors, float b )
    {
        this.tolFactors = tolFactors;
        this.b = b;
    }
}
/**
 * Vertex attribute: a type tag, the current factor vector and the vertex
 * degrees. main() uses type 0 for users, 1 for items and -1 for the default
 * vertex.
 */
private static class NodeInfo implements Serializable
{
    // Kind of vertex (see class comment).
    int type;
    // Current factor vector; may be null before the vertex has been solved.
    float[] factor;
    // Degrees, filled in by main() after the graph is built; 0 until then.
    long outDegrees;
    long inDegrees;

    public NodeInfo( int type, float[] factor )
    {
        this.factor = factor;
        this.type = type;
    }
}
/**
 * Adds {@code lambda} to the diagonal of the accumulated matrix, solves the
 * system through solve() and returns the solution as floats. The equation is
 * modified in place and reset afterwards.
 *
 * @param ne accumulated normal equation; cleared on return
 * @param lambda value added to every diagonal entry before solving
 * @return solution vector of length {@code ne.k}
 */
private static float[] choleskySolver( NormalEquation ne, double lambda )
{
    // In the packed upper-triangular layout the diagonal sits at 0, 2, 5, 9, ...:
    // the distance to the next diagonal entry grows by one each time.
    for ( int diag = 0, step = 2; diag < ne.trik; diag += step, step++ )
    {
        ne.ata[diag] += lambda;
    }
    // solve() overwrites ne.atb with the solution.
    solve( ne.ata, ne.atb );
    final int rank = ne.k;
    float[] solution = new float[rank];
    for ( int idx = 0; idx < rank; idx++ )
    {
        solution[idx] = (float) ne.atb[idx];
    }
    ne.reset( );
    return solution;
}
/**
 * Solves {@code A * x = bx} with LAPACK {@code dppsv}, where {@code A} is
 * given as a packed upper triangle. Both arrays are overwritten by the call;
 * on success {@code bx} holds the solution.
 *
 * @param A packed upper triangle of the matrix; overwritten
 * @param bx right-hand side on entry, solution on return
 * @return {@code bx}
 * @throws RuntimeException if LAPACK reports a non-zero info code
 */
private static double[] solve( double[] A, double[] bx )
{
    int k = bx.length;
    intW info = new intW( 0 );
    lapack.dppsv( "U", k, 1, A, bx, k, info );
    // Report the info code so the two failure modes can be told apart.
    if ( info.val < 0 )
    {
        throw new RuntimeException( "LAPACK run error: dppsv argument " + ( -info.val ) + " had an illegal value" );
    }
    if ( info.val > 0 )
    {
        throw new RuntimeException( "LAPACK run error: leading minor of order " + info.val + " is not positive definite" );
    }
    return bx;
}
/**
 * Accumulator for a k-dimensional least-squares problem: {@code ata} holds
 * the packed upper triangle of the accumulated outer products and
 * {@code atb} the accumulated right-hand side.
 */
private static class NormalEquation implements Serializable
{
    private static final String upper = "U";
    // Dimension of the problem.
    private int k;
    // Number of entries in the packed triangle: k * (k + 1) / 2.
    private int trik;
    private double[] ata;
    private double[] atb;
    // Scratch buffer holding the double-precision copy of the vector being added.
    private double[] da;

    public NormalEquation( int k )
    {
        this.k = k;
        this.trik = k * ( k + 1 ) / 2;
        this.ata = new double[trik];
        this.atb = new double[k];
        this.da = new double[k];
    }

    /** Widens the first k entries of {@code a} into the scratch buffer. */
    private void copyToDouble( float[] a )
    {
        for ( int idx = 0; idx < k; idx++ )
        {
            da[idx] = a[idx];
        }
    }

    /**
     * Adds one observation: {@code ata += c * a * a^T} and, unless
     * {@code b} is zero, {@code atb += c * b * a}.
     *
     * @return this equation, for chaining
     */
    NormalEquation add( float[] a, double b, double c )
    {
        copyToDouble( a );
        blas.dspr( upper, k, c, da, 1, ata );
        if ( b == 0.0 )
        {
            // Nothing to add to the right-hand side.
            return this;
        }
        double scale = c * b;
        blas.daxpy( k, scale, da, 1, atb, 1 );
        return this;
    }

    /**
     * Adds the sums of {@code other} into this equation.
     *
     * @return this equation, for chaining
     */
    NormalEquation merge( NormalEquation other )
    {
        int matrixLen = ata.length;
        int vectorLen = atb.length;
        blas.daxpy( matrixLen, 1.0, other.ata, 1, ata, 1 );
        blas.daxpy( vectorLen, 1.0, other.atb, 1, atb, 1 );
        return this;
    }

    /** Zeroes both accumulators so the instance can be reused. */
    void reset( )
    {
        Arrays.fill( ata, 0.0 );
        Arrays.fill( atb, 0.0 );
    }
}
/**
 * Serializable base for one-argument Scala functions, so anonymous subclasses
 * can be handed to the GraphX calls in this file.
 */
private static abstract class MyFunction1<T1, R> extends AbstractFunction1<T1, R> implements Serializable
{
}
/**
 * Serializable base for two-argument Scala functions, so anonymous subclasses
 * can be handed to the GraphX calls in this file.
 */
private static abstract class MyFunction2<T1, T2, R> extends AbstractFunction2<T1, T2, R> implements Serializable
{
}
/**
 * Serializable base for three-argument Scala functions, so anonymous
 * subclasses can be handed to the GraphX calls in this file.
 */
private static abstract class MyFunction3<T1, T2, T3, R> extends AbstractFunction3<T1, T2, T3, R> implements Serializable
{
}
}