#define __WINDOWS__ 1
#include "libidx.h"
#include "libeblearn.h"
#include "libeblearntools.h"
//#include "netconf.h"
#include <iostream>
#include <stdio.h>
using namespace std;
using namespace ebl;
// Global counter. NOTE(review): not referenced anywhere in this file; presumably
// read via `extern` by a linked eblearn component -- confirm before removing.
uint dump_count = 0;
// Softmax module smoke test; defined below main() and declared here so main()
// can call it.
void test_softmax();
// Program entry point: runs the softmax module test and exits with status 0.
// Command-line arguments are accepted but not used.
int main(int argc, char** argv)
{
  (void) argc;  // unused
  (void) argv;  // unused

  test_softmax();
  return 0;
}
void test_softmax()
{
state<double> in(2,2,2,2,2,2);
state<double> out(1,1,1,1,1,1);
double beta = 1;
softmax_module<double> module(beta);
// init
dseed(1);
cout<<"fixed seed:"<<endl;
cout<<"in:"<<endl;
in.printElems();
cout<<endl;
module.fprop(in, out);
cout<<"out:"<<endl;
out.printElems();
cout<<endl;
dynamic_init_drand();
idx_bloop2(i, in, double, o, out, double)
{
idx_bloop2(ii, i, double, oo, o, double)
{
idx_bloop2(iii, ii, double, ooo, oo, double)
{
idx_bloop2(iiii, iii, double, oooo, ooo, double)
{
idx_bloop2(iiiii, iiii, double, ooooo, oooo, double)
{
idx_bloop2(iiiiii, iiiii, double, oooooo, ooooo, double)
{
iiiiii.set(drand((double)1));
oooooo.set(drand((double)1));
}
}
}
}
}
}
// fprop, bprop, bbprop
cout<<endl;
cout<<"dynamic seed:"<<endl;
cout<<"in:"<<endl;
in.printElems();
cout<<endl;
module.fprop(in, out);
cout<<"out:"<<endl;
out.printElems();
cout<<endl;
in.zero_dx();
module.bprop(in, out);
in.zero_ddx();
module.bbprop(in, out);
/*print
printf(" Input\n");
in.pretty();
printf(" Output\n");
out.pretty();
printf(" Input dx\n");
in.dx.pretty();
printf(" Output dx\n");
out.dx.pretty();
printf(" Input ddx\n");
in.ddx.pretty();
printf(" Output ddx\n");
out.ddx.pretty();
printf("\n");*/
idx<double> ib3 = in.select(0,0).select(0,0),
calc_out = out.select(0,0).select(0,0);
idx<double> ib(new srg<double>(), ib3.spec),
des_out(new srg<double>(), ib3.spec);
idx_dotc(ib3, module.beta, ib);
idx_exp(ib);
double ib2 = 1/idx_sum(ib);
idx_dotc(ib, ib2, des_out);
//printf("Fprop error 1 : %3.3e \n", idx_sqrdist(calc_out, des_out));
ib3 = in.select(0,1).select(0,0);
calc_out = out.select(0,1).select(0,0);
idx_dotc(ib3, module.beta, ib);
idx_exp(ib);
ib2 = 1/idx_sum(ib);
idx_dotc(ib, ib2, des_out);
//printf("Fprop error 2 : %3.3e \n", idx_sqrdist(calc_out, des_out));
ib3 = in.select(0,0).select(0,1);
calc_out = out.select(0,0).select(0,1);
idx_dotc(ib3, module.beta, ib);
idx_exp(ib);
ib2 = 1/idx_sum(ib);
idx_dotc(ib, ib2, des_out);
//printf("Fprop error 3 : %3.3e \n", idx_sqrdist(calc_out, des_out));
ib3 = in.select(0,1).select(0,1);
calc_out = out.select(0,1).select(0,1);
idx_dotc(ib3, module.beta, ib);
idx_exp(ib);
ib2 = 1/idx_sum(ib);
idx_dotc(ib, ib2, des_out);
//printf("Fprop error 4 : %3.3e \n", idx_sqrdist(calc_out, des_out));
/*
Bprop_tester *bproptest = new Bprop_tester();
bproptest->test(module);
Bbprop_tester *bbproptest = new Bbprop_tester();
bbproptest->test(module);
Jacobian_tester *test= new Jacobian_tester();
test->test(module);
*/
}