#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <malloc.h>
#define REAL double
#define DIM 3
#include "jcode.h"

#pragma goose func
double rsqrt(double r2)
{
    return 1.0/sqrt(r2);
}

/*
 * Direct-summation softened gravity for one group of i-particles.
 *
 * For each k in [0, nilist) with i = ilist[k], a[i][0..2] and pot[i] are
 * reset to zero and then accumulate over the first njlist entries of jlist:
 *
 *     dx      = jlist->x[j] - x[i]
 *     r2      = |dx|^2 + eps^2
 *     a[i]   += jlist->mass[j] * dx / r2^(3/2)
 *     pot[i] -= jlist->mass[j] / sqrt(r2)
 *
 * No gravitational constant is applied. Only the rows of a and pot named in
 * ilist are written.
 *
 * Parameters:
 *   x       positions, read at the indices listed in ilist
 *   m       not read by this function (j-masses come from jlist->mass)
 *   a, pot  outputs: acceleration and potential
 *   ilist   indices of the i-particles, nilist entries
 *   jlist   source particles; jlist->x[j][0..2] and jlist->mass[j] are used
 *   njlist  number of jlist entries to use
 *   eps     softening length
 *
 * Edge cases visible from the arithmetic:
 *   - with eps != 0, a j-particle located exactly at x[i] adds nothing to
 *     a[i] but -mass/eps to pot[i] (calculate_force_using_tree adds m/eps
 *     back, presumably to cancel this self term);
 *   - with eps == 0 such a coincident pair gives r2 == 0, and the results
 *     become inf/NaN.
 *
 * The goose pragma below marks the i/j loop nest for parallel code
 * generation (presumably on the accelerator).
 * NOTE(review): keep the loop nest in this simple i/j form; restructuring it
 * may break goose's translation.
 */
void calculate_force(REAL x[][3], REAL m[], REAL a[][3], REAL pot[],
                     int ilist[], int nilist, struct jlist_t *jlist, int njlist,
                     REAL eps)
{
    int i,j,k;
    double dx[3],r2,rinv,mrinv,mr3inv;

    // Host-only OpenMP alternative, kept for reference:
    // #pragma omp parallel for private(j,k,dx,r2,rinv,mrinv,mr3inv)

    // The goose pragma below is the minimal form; goose presumably infers the
    // i-particle, j-particle and result variables on its own. They can also
    // be spelled out explicitly (continue the pragma across lines with a
    // trailing backslash):
    //
    //   #pragma goose parallel for precision ("double") loopcounter(i, j)
    //       ip (x[ilist[i]][0..2]) jp (jlist->x[j][0..2],jlist->mass[j])
    //       result (a[ilist[i]][0..2],pot[ilist[i]])
#pragma goose parallel for precision ("double") loopcounter(i, j)
    for(i=0;i<nilist;i++){
        for(k=0;k<3;k++) a[ilist[i]][k] = 0.0;
        pot[ilist[i]] = 0.0;
        for(j=0;j<njlist;j++){      
            for(k=0;k<3;k++) dx[k] = jlist->x[j][k] - x[ilist[i]][k];
            r2 = dx[0]*dx[0] + dx[1]*dx[1] + dx[2]*dx[2] + eps*eps;
            rinv = rsqrt(r2);
            mrinv = jlist->mass[j] *rinv;
            mr3inv = mrinv*rinv*rinv;
            a[ilist[i]][0] += mr3inv * dx[0];
            a[ilist[i]][1] += mr3inv * dx[1];
            a[ilist[i]][2] += mr3inv * dx[2];      
            pot[ilist[i]] += -mrinv;      
        }
    }

}

/* Set-up hook for the force backend. This implementation has nothing to
 * initialize, so it returns immediately. */
void initialize_force_function(void)
{
    return;
}

/* Non-zero until the first call of calculate_force_using_tree has printed
 * its one-time statistics report. */
static int logout_firstflag=1;

/*
 * Tree-based force calculation driver.
 *
 * For every cell index c = walklist[ii], 0 <= ii < nwalk, it:
 *   1. copies that cell's particle indices,
 *      index[clist[c].ifirst .. clist[c].ifirst + clist[c].n - 1],
 *      into an i-list;
 *   2. builds the cell's interaction list with make_interaction_list();
 *   3. evaluates accelerations and potentials for the i-list with
 *      calculate_force();
 *   4. when eps != 0, adds m/eps to each i-particle's potential.
 *      NOTE(review): this appears to cancel the -m/eps self term that
 *      calculate_force produces when the i-particle is also in the j-list;
 *      confirm that make_interaction_list always includes it.
 *
 * Parameters:
 *   n         total particle count; only used for the average printed on
 *             the first call
 *   x, m      particle positions and masses
 *   a, pot    outputs, written for the particles of the walked cells
 *   eps       softening length
 *   clist     cell table; walklist[0..nwalk-1] holds indices into it
 *   index     particle index array addressed by clist[].ifirst
 *   maxx      passed through to center_of_cell()
 *   st        passed to get_wcputime() (presumably its timer reference)
 *   dninter   out: total number of i-j pair interactions evaluated
 *
 * On the first call it also prints interaction counts, timings and a `ps`
 * snapshot to stdout. The process exits if scratch memory cannot be
 * allocated.
 */
void calculate_force_using_tree(int n, REAL x[][3], REAL m[], REAL a[][3], REAL pot[],
                                REAL eps, struct clist_t clist[], int nwalk, int walklist[],
                                int index[], double maxx, double *st, double *dninter)
{
    int ii, j;
    int nilistsum = 0;
    int ilist_capacity = 0;
    double tlist = 0.0, tgrape = 0.0, lt = 0.0;
    long long int ninter = 0, sumjlist = 0;
    int *ilist = NULL;
    struct jlist_t *jlist = (struct jlist_t *)malloc(sizeof *jlist);

    if (jlist == NULL) {
        fprintf(stderr, "calculate_force_using_tree: cannot allocate jlist\n");
        exit(EXIT_FAILURE);
    }

    for (ii = 0; ii < nwalk; ii++) {
        int i = walklist[ii];
        int nilist = clist[i].n;
        int tmpif = clist[i].ifirst;
        int njlist = 0;
        long long int current_key = 1;
        double coc[3];

        /* Grow the i-list buffer only when this cell is larger than any seen
         * so far. This also never calls realloc(p, 0) for empty cells, and it
         * keeps the old block valid if realloc fails. */
        if (nilist > ilist_capacity) {
            int *grown = (int *)realloc(ilist, sizeof *grown * (size_t)nilist);
            if (grown == NULL) {
                fprintf(stderr,
                        "calculate_force_using_tree: cannot allocate ilist (%d entries)\n",
                        nilist);
                free(ilist);
                free(jlist);
                exit(EXIT_FAILURE);
            }
            ilist = grown;
            ilist_capacity = nilist;
        }
        for (j = 0; j < nilist; j++) ilist[j] = index[j + tmpif];

        center_of_cell(clist[i].key, clist[i].key_level, clist[i].length, coc, maxx);
        make_interaction_list(clist[i].key, coc, clist[i].length, current_key, clist,
                              &njlist, index, jlist, x, m);

        /* NOTE(review): this only reports the condition; by now
         * make_interaction_list may already have written past jlist's
         * capacity. */
        if (njlist > NJMAX) printf("njlist %d > NJMAX\n", njlist);

        /* Widen before multiplying so the pair count cannot overflow int. */
        ninter += (long long int)njlist * nilist;
        sumjlist += njlist;
        nilistsum += nilist;

        get_wcputime(&lt, st);   /* accumulated as "list creation" time */
        tlist += lt;

        calculate_force(x, m, a, pot, ilist, nilist, jlist, njlist, eps);

        if (eps != 0) {
            for (j = 0; j < nilist; j++) pot[ilist[j]] += m[ilist[j]] / eps;
        }

        get_wcputime(&lt, st);   /* accumulated as "accelerator" time */
        tgrape += lt;
    }

    *dninter = (double)ninter;

    if (logout_firstflag == 1) {
        /* Guard against n == 0, which would be a division by zero. */
        long long int nave = (n > 0) ? ninter / n : 0;

        printf("ninter %lld ave %lld sumjlist %lld sumilist %d\n",ninter,nave,sumjlist,nilistsum);
        printf("time spent for list creation on the host : %g\n", tlist);
        printf("time spent on the accelerator            : %g\n", tgrape);
        /* The child process writes straight to the file descriptor, so flush
         * the buffered report first or it would appear after the ps output
         * when stdout is redirected. */
        fflush(stdout);
        system("ps augx | grep jcode | grep -v grep");
        system("ps augx | grep j9 | grep -v grep");
        printf("***************************************\n");
        logout_firstflag = 0;
    }

    free(jlist);
    free(ilist);
}

