#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <g6util.h>

/*
 * Piecewise-polynomial cutoff factor as a function of the scaled
 * distance re.
 *
 *   0 <= re < 1 : 1 - re^2 * f_inner(re)
 *   1 <= re < 2 : 1 - re^2 * f_outer(re)
 *   otherwise   : 0   (this includes negative and NaN arguments,
 *                      because both range tests fail for them)
 *
 * The two polynomials agree at re == 1 and the result reaches 0 at
 * re == 2.  The arithmetic is kept in Horner form; both polynomials
 * share the common denominator 35 * 4.
 */
static double
g_p3m(double re)
{
    double poly;

    if (re >= 0 && re < 1) {
        /* inner branch */
        poly = re*(224.+re*re*(-224.+re*(70.+re*(48.-re*21.))))/(35.*4.0);
        return 1.0-re*re*poly;
    }

    if (re >= 1 && re < 2) {
        /* outer branch; note the 12/re^2 term */
        poly = (12./(re*re)-224.+re*(896.+re*(-840.+re*(224.+re*(70.+re*(-48.+re*7.))))))/(35.*4.0);
        return 1.0-re*re*poly;
    }

    /* outside [0, 2): no contribution */
    return 0;
}

static void
put_two_particles(double r, double rmax,
                  double xi[3], double xj[3], double r_dr,
                  double vi[3], double vj[3], double v, double v_dv)
{
    double phi, costheta, sintheta;
    int k;

    r = (drand48() - 0.5) * (r_dr - r) + r;
    v = (drand48() - 0.5) * (v_dv - v) + v;

#if 1
    phi = 2.0 * M_PI * drand48();
    costheta = 2.0 * drand48() - 1.0;
    sintheta = sqrt(1 - costheta * costheta);
    xj[0] = r * sintheta * cos(phi);
    xj[1] = r * sintheta * sin(phi);
    xj[2] = r * costheta;

    phi = 2.0 * M_PI * drand48();
    costheta = 2.0 * drand48() - 1.0;
    sintheta = sqrt(1 - costheta * costheta);
    vj[0] = v * sintheta * cos(phi);
    vj[1] = v * sintheta * sin(phi);
    vj[2] = v * costheta;
#else
    xj[0] = r;
    xj[1] = 0.0;
    xj[2] = 0.0;

    vj[0] = v;
    vj[1] = 0.0;
    vj[2] = 0.0;
#endif
    for (k = 0; k < 3; k++) {
	double axj = fabs(xj[k]);
	xi[k] = (rmax - 2.0 * axj) * (drand48() - 0.5);
	xj[k] += xi[k];

        double vmax = rmax * 1e-3; // may be too large
	vi[k] = vmax * (drand48() - 0.5);
	vj[k] += vi[k];
    }

    //	fprintf(stderr, "xi: %e %e %e  xj: %e %e %e\n",
    //		xi[0], xi[1], xi[2], xj[0], xj[1], xj[2]);
}




/* Power-of-two shift used for the jerk output in pairwise_force_grape(). */
#define JSHIFT (195)
/*
 * Velocity scale factor, 2^5.
 * The expansion is parenthesized and carries no trailing semicolon, so
 * the macro can be used inside any expression.
 */
#define VSCALE (pow(2.0, 5.0))

/*
 * Derive the scale factors from the coordinate range [xmin, xmax] and
 * hand them to the g6_set_range_* / g6_set_scale_* calls.
 *
 *   xscale    = 2^64 / (xmax - xmin)
 *   vscale    = VSCALE
 *   eps2scale = xscale^2
 *   mscale    = 1
 *   ascale    = xscale^2 / mscale
 *   jscale    = xscale^3 / mscale / vscale
 *
 * The same range and velocity scale are applied to both i- and
 * j-particles.  xmax must differ from xmin, otherwise xscale is not
 * finite.
 * NOTE(review): the exact meaning of each g6_* setter is defined in
 * g6util.h, which is not visible here -- confirm against that header.
 */
static void
set_range(double xmin, double xmax)
{
    double xsize = xmax - xmin;
    double xscale = pow(2.0, 64.0) / xsize;
    double vscale = VSCALE;
    double eps2scale = xscale * xscale;
    double mscale = 1.0;
    double ascale = xscale * xscale / mscale;
    double jscale = xscale * xscale * xscale / mscale / vscale;

    g6_set_range_xj(xmin, xmax);
    g6_set_range_xi(xmin, xmax);
    g6_set_scale_vj(vscale);
    g6_set_scale_vi(vscale);
    g6_set_scale_epsi2(eps2scale);
    g6_set_scale_mj(mscale);
    g6_set_scale_acc(ascale);
    g6_set_scale_jerk(jscale);
}

/*
 * Evaluate the acceleration and jerk on particle i due to a single
 * particle j through the g6_* interface.
 *
 *   m       : mass passed for particle j
 *   xi, vi  : position / velocity of particle i
 *   xj, vj  : position / velocity of particle j
 *   eps     : softening length; eps * eps is what is passed on
 *   ag, jg  : outputs, filled by g6_get_fout() and then multiplied by
 *             2^-fshift and 2^-jshift respectively
 *
 * NOTE(review): the g6_* calls come from g6util.h, which is not visible
 * here; the sequence below looks like "load 1 j-particle, set n = 1,
 * load 1 i-particle, run, read back" -- confirm against that header.
 */
static void
pairwise_force_grape(double m, double xi[3], double xj[3], double eps, double vi[3], double vj[3],
                     double ag[3], double jg[3])
{
    int acc_shift = 150;
    int jerk_shift = JSHIFT;
    double softening2 = eps * eps;
    double acc_unscale, jerk_unscale;
    int axis;

    g6_set_jp(0, 1, &m, (double (*)[3])xj, (double (*)[3])vj);
    g6_set_n(1);
    g6_set_ip(1, (double (*)[3])xi, (double (*)[3])vi, &softening2, &acc_shift, &jerk_shift);
    g6_run();
    g6_get_fout(1, (double (*)[3])ag, (double (*)[3])jg);

    /* undo the power-of-two shifts that were handed to g6_set_ip() */
    acc_unscale = pow(2.0, -acc_shift);
    jerk_unscale = pow(2.0, -jerk_shift);
    for (axis = 0; axis < 3; axis++) {
        ag[axis] *= acc_unscale;
        jg[axis] *= jerk_unscale;
    }
}

/*
 * Host-side softened acceleration on particle i from particle j,
 * multiplied by the cutoff factor g_p3m(r / eta).
 *
 *   m      : mass of particle j
 *   xi, xj : positions of particle i and j
 *   eps    : softening length; r^2 = eps^2 + |xj - xi|^2
 *   ag     : output acceleration, g * m * (xj - xi) / r^3
 *   eta    : length by which r is divided before evaluating g_p3m()
 */
static void
pairwise_force_host_p3m(double m, double xi[3], double xj[3], double eps, double ag[3], double eta)
{
    double sep[3];
    double dist2 = eps * eps;
    double dist, dist3, cutoff;
    int axis;

    for (axis = 0; axis < 3; axis++) {
        sep[axis] = xj[axis] - xi[axis];
        dist2 += sep[axis] * sep[axis];
    }

    dist = sqrt(dist2);
    dist3 = dist2 * dist;
    cutoff = g_p3m(dist / eta);

    for (axis = 0; axis < 3; axis++) {
        ag[axis] = cutoff * m * sep[axis] / dist3;
    }
}

static void
pairwise_force_host(double m, double xi[3], double xj[3], double eps, double vi[3], double vj[3],
                    double ag[3], double jg[3])
{
    double r, r2, r3, r5, g;
    double dx[3], dv[3];
    int k;

    for (k = 0; k < 3; k++) {
        dx[k] = xj[k] - xi[k];
        dv[k] = vj[k] - vi[k];
    }
    r2 = eps * eps;
    for (k = 0; k < 3; k++) {
	r2 += dx[k] * dx[k];
    }
    r = sqrt(r2);
    r3 = r2 * r;
    r5 = r3 * r2;
    for (k = 0; k < 3; k++) {
	ag[k] = m * dx[k] / r3;
    }
    double vdotx = 0;
    for (k = 0; k < 3; k++) {
        vdotx += dv[k] * dx[k];
    }
    for (k = 0; k < 3; k++) {
	jg[k] = m * (dv[k] / r3 - 3.0 * vdotx * dx[k]/ r5);
    }
}

/*
 * a0: force0 (may have cutoff)
 * a1: force1 (may have cutoff)
 * a2: force without cutoff (i.e. pure gravity)
 */
static double
compare_force(double a0[3], double a1[3], double a2[3])
{
    int k;
    double e, e2 = 0.0, absa = 0.0;

    for (k = 0; k < 3; k++) {
        absa += a2[k] * a2[k];
        e2 += (a0[k] - a1[k]) * (a0[k] - a1[k]);
    }
    absa = sqrt(absa);
    e = sqrt(e2);
    return e/absa; // relative error
}

/*
 * a0: force0 (may have cutoff)
 * a1: force1 (may have cutoff)
 * a2: force without cutoff (i.e. pure gravity)
 */
static double
compare_force_ave(double a0[3], double a1[3], double a2[3])
{
    int k;
    double e, abs0 = 0.0, abs1 = 0.0, abs2 = 0.0;

    for (k = 0; k < 3; k++) {
        abs0 += a0[k] * a0[k];
        abs1 += a1[k] * a1[k];
        abs2 += a2[k] * a2[k];
    }
    abs0 = sqrt(abs0);
    abs1 = sqrt(abs1);
    abs2 = sqrt(abs2);
    e = abs0 - abs1;
    return e/abs2;
}

/*
 * Sweep the pair separation r from rmax * 1e-15 up to (but excluding)
 * rmax * 2, multiplying by dr each step.  For every r, 'ntry' random
 * pairs are generated; the acceleration and jerk from
 * pairwise_force_grape() are compared with pairwise_force_host(), and
 * one line is printed with:
 *
 *   r, rms relative acceleration error, mean signed relative difference
 *   of the acceleration magnitudes, the same two figures for the jerk,
 *   then the x components of the host and GRAPE acceleration and of the
 *   host and GRAPE jerk from the last trial.
 *
 * Command-line arguments are not used.
 */
int
main(int argc, char **argv)
{
    const int ntry = 1; // average over 'ntry' trials
    const double dr = 1.05;
    const double rmax = 1.0;
    const double m = 1.0;
    /* the original assigned 3.3e-4 * rmax first and overwrote it at once */
    const double eps = 1e-6 * rmax;
    const double v = 1e-2;
    const double dv = 1.05;
    double xi[3], xj[3];
    double vi[3], vj[3];
    double ag[3], jg[3];
    double ah0[3], jh0[3]; // pure gravity
    double r;
    int trial;

    (void)argc;
    (void)argv;

    srand48(1234);

    /* force cut-off is not supported. */
    g6_open();
    set_range(-rmax, rmax);

    for (r = rmax * 1e-15 ; r < rmax * 2.0; r *= dr) {
        double acc_sq_sum = 0.0, acc_ave = 0.0;
        double jerk_sq_sum = 0.0, jerk_ave = 0.0;
        double acc_rms, jerk_rms;

        for (trial = 0; trial < ntry; trial++) {
            double acc_err, jerk_err;

            put_two_particles(r, rmax, xi, xj, r * dr, vi, vj, v, v * dv);

            pairwise_force_grape(m, xi, xj, eps, vi, vj, ag, jg);
            pairwise_force_host(m, xi, xj, eps, vi, vj, ah0, jh0);

            acc_err = compare_force(ah0, ag, ah0);
            acc_sq_sum += acc_err * acc_err;
            acc_ave += compare_force_ave(ah0, ag, ah0);

            jerk_err = compare_force(jh0, jg, jh0);
            jerk_sq_sum += jerk_err * jerk_err;
            jerk_ave += compare_force_ave(jh0, jg, jh0);
        }

        acc_rms = sqrt(acc_sq_sum / ntry);
        acc_ave /= ntry;

        jerk_rms = sqrt(jerk_sq_sum / ntry);
        jerk_ave /= ntry;

        printf("% 15.13E % 15.13E % 15.13E % 15.13E % 15.13E % 15.13E % 15.13E % 15.13E % 15.13E\n",
               r, acc_rms, acc_ave, jerk_rms, jerk_ave, ah0[0], ag[0], jh0[0], jg[0]);
    }

    g6_close();
    return 0;
}
