#include <cmath>
#include <cstdio>

#include "TGraph.h"

#include "ODEIntegrator.h"

// Integrates dy/dx = f(x, y) with the explicit (forward) Euler method:
//   y_{n+1} = y_n + step*f(x_n, y_n),   x_n = xmin + n*step.
//
// gr   : optional output graph; when non-null, point j is set to (x_j, y_j)
//        for j = 0..nsteps (point 0 is the initial condition).
// f    : right-hand side f(x, y) of the ODE.
// xmin : initial abscissa, where y(xmin) = y0.
// xmax : requested final abscissa. Only a whole number of steps is taken, so
//        the interval actually integrated is [xmin, xmin + nsteps*step].
// y0   : initial value.
// step : integration step (may be negative to integrate towards smaller x).
// Returns the approximation of y at xmin + nsteps*step (y0 if no step is taken).
double metodoEulero(TGraph* gr, double f(double, double), double xmin, double xmax, double y0, double step){

  // Number of steps. The ratio is taken from the REQUESTED xmax, before xmax
  // is adjusted; otherwise the check below compares nsteps with itself and
  // the warning can never fire. The small relative tolerance keeps round-off
  // (e.g. 0.3/0.1 = 2.9999999999999996) from truncating away a whole step.
  const double ratio = (xmax-xmin)/step;
  const double tol = 1.0E-9*(1.0+std::fabs(ratio));
  int nsteps = 0;
  if (std::isfinite(ratio) && ratio > -tol) {
    nsteps = static_cast<int>(std::floor(ratio+tol));
    xmax = xmin + nsteps*step;
    if (std::fabs(nsteps-ratio)>0.01*nsteps) {
      printf("The closest integer number of steps is: %d\n", nsteps);
      printf("The equation will be solved in [%f, %f]\n", xmin, xmax);
    }
  } else {
    // step == 0 (division by zero), non-finite input, or a step pointing away
    // from xmax (negative step count): take no step instead of invoking UB or
    // looping ~2^32 times.
    printf("Invalid step %f for the interval [%f, %f]: no integration step taken\n", step, xmin, xmax);
  }

  double xn = xmin;
  double yn = y0;
  if (gr) gr->SetPoint(0, xn, yn);
  for (int j=1; j<=nsteps; j++) {
    yn += step*f(xn, yn);
    // Recompute x from xmin instead of accumulating xn += step, which drifts.
    xn = xmin + j*step;
    if (gr) gr->SetPoint(j, xn, yn);
  }

  return yn;
}

// Integrates dy/dx = f(x, y) with the explicit midpoint method:
//   k1      = f(x_n, y_n)
//   y_{n+1} = y_n + step*f(x_n + step/2, y_n + (step/2)*k1),
//   x_n     = xmin + n*step.
//
// gr   : optional output graph; when non-null, point j is set to (x_j, y_j)
//        for j = 0..nsteps (point 0 is the initial condition).
// f    : right-hand side f(x, y) of the ODE.
// xmin : initial abscissa, where y(xmin) = y0.
// xmax : requested final abscissa. Only a whole number of steps is taken, so
//        the interval actually integrated is [xmin, xmin + nsteps*step].
// y0   : initial value.
// step : integration step (may be negative to integrate towards smaller x).
// Returns the approximation of y at xmin + nsteps*step (y0 if no step is taken).
double metodoMidpointEsplicito(TGraph* gr, double f(double, double), double xmin, double xmax, double y0, double step){

  // Number of steps. The ratio is taken from the REQUESTED xmax, before xmax
  // is adjusted; otherwise the check below compares nsteps with itself and
  // the warning can never fire. The small relative tolerance keeps round-off
  // (e.g. 0.3/0.1 = 2.9999999999999996) from truncating away a whole step.
  const double ratio = (xmax-xmin)/step;
  const double tol = 1.0E-9*(1.0+std::fabs(ratio));
  int nsteps = 0;
  if (std::isfinite(ratio) && ratio > -tol) {
    nsteps = static_cast<int>(std::floor(ratio+tol));
    xmax = xmin + nsteps*step;
    if (std::fabs(nsteps-ratio)>0.01*nsteps) {
      printf("The closest integer number of steps is: %d\n", nsteps);
      printf("The equation will be solved in [%f, %f]\n", xmin, xmax);
    }
  } else {
    // step == 0 (division by zero), non-finite input, or a step pointing away
    // from xmax (negative step count): take no step instead of invoking UB or
    // looping ~2^32 times.
    printf("Invalid step %f for the interval [%f, %f]: no integration step taken\n", step, xmin, xmax);
  }

  double xn = xmin;
  double yn = y0;
  if (gr) gr->SetPoint(0, xn, yn);
  for (int j=1; j<=nsteps; j++) {
    const double k1 = f(xn, yn);
    yn += step*f(xn+0.5*step, yn+0.5*step*k1);
    // Recompute x from xmin instead of accumulating xn += step, which drifts.
    xn = xmin + j*step;
    if (gr) gr->SetPoint(j, xn, yn);
  }

  return yn;
}

// Integrates dy/dx = f(x, y) with the implicit midpoint rule. The implicit
// equation of each step is handed to ImplicitMidpointEquationSolver (defined
// elsewhere); presumably it solves
//   y_{n+1} = y_n + step*f(x_n + step/2, (y_n + y_{n+1})/2)
// for y_{n+1} inside the given bracket -- TODO confirm against its source.
// For f(x, y) = y the step has the closed form
//   y_{n+1} = y_n*(1 + step/2)/(1 - step/2).
//
// gr   : optional output graph; when non-null, point j is set to (x_j, y_j)
//        for j = 0..nsteps (point 0 is the initial condition).
// f    : right-hand side f(x, y) of the ODE.
// xmin : initial abscissa, where y(xmin) = y0.
// xmax : requested final abscissa. Only a whole number of steps is taken, so
//        the interval actually integrated is [xmin, xmin + nsteps*step].
// y0   : initial value.
// step : integration step (may be negative to integrate towards smaller x).
// Returns the approximation of y at xmin + nsteps*step (y0 if no step is taken).
double metodoMidpointImplicito(TGraph* gr, double f(double, double), double xmin, double xmax, double y0, double step){

  // Number of steps. The ratio is taken from the REQUESTED xmax, before xmax
  // is adjusted; otherwise the check below compares nsteps with itself and
  // the warning can never fire. The small relative tolerance keeps round-off
  // (e.g. 0.3/0.1 = 2.9999999999999996) from truncating away a whole step.
  const double ratio = (xmax-xmin)/step;
  const double tol = 1.0E-9*(1.0+std::fabs(ratio));
  int nsteps = 0;
  if (std::isfinite(ratio) && ratio > -tol) {
    nsteps = static_cast<int>(std::floor(ratio+tol));
    xmax = xmin + nsteps*step;
    if (std::fabs(nsteps-ratio)>0.01*nsteps) {
      printf("The closest integer number of steps is: %d\n", nsteps);
      printf("The equation will be solved in [%f, %f]\n", xmin, xmax);
    }
  } else {
    // step == 0 (division by zero), non-finite input, or a step pointing away
    // from xmax (negative step count): take no step instead of invoking UB or
    // looping ~2^32 times.
    printf("Invalid step %f for the interval [%f, %f]: no integration step taken\n", step, xmin, xmax);
  }

  double xn = xmin;
  double yn = y0;
  if (gr) gr->SetPoint(0, xn, yn);
  for (int j=1; j<=nsteps; j++) {
    // A fresh solver per step, as before (it may carry per-solve state).
    ImplicitMidpointEquationSolver imes;

    // Bracket half-width scale taken from the explicit Euler increment. fabs
    // keeps the bracket ordered (lower < upper) when the increment is
    // negative; a (near-)zero increment falls back to 1.
    double euleroincrement = std::fabs(step*f(xn, yn));
    if (euleroincrement<1.0E-20) euleroincrement=1.0;

    // Solver precision. fabs prevents a negative precision for x < 0, which
    // could stop a "width > precision" loop from ever terminating.
    // NOTE(review): this scales with x rather than with y; confirm that is
    // what the solver expects.
    const double precision = (xn==0) ? 1.0E-6 : std::fabs(xn)*1.0E-6;

    yn = imes.FindSolution(f, yn, xn, step, yn-100.0*euleroincrement, yn+100.0*euleroincrement, precision);

    // Recompute x from xmin instead of accumulating xn += step, which drifts.
    xn = xmin + j*step;
    if (gr) gr->SetPoint(j, xn, yn);
  }

  return yn;
}

// Integrates dy/dx = f(x, y) with the classical fourth-order Runge-Kutta method:
//   k1 = f(x_n,          y_n)
//   k2 = f(x_n + step/2, y_n + (step/2)*k1)
//   k3 = f(x_n + step/2, y_n + (step/2)*k2)
//   k4 = f(x_n + step,   y_n + step*k3)
//   y_{n+1} = y_n + (step/6)*(k1 + 2*k2 + 2*k3 + k4),   x_n = xmin + n*step.
//
// gr   : optional output graph; when non-null, point j is set to (x_j, y_j)
//        for j = 0..nsteps (point 0 is the initial condition).
// f    : right-hand side f(x, y) of the ODE.
// xmin : initial abscissa, where y(xmin) = y0.
// xmax : requested final abscissa. Only a whole number of steps is taken, so
//        the interval actually integrated is [xmin, xmin + nsteps*step].
// y0   : initial value.
// step : integration step (may be negative to integrate towards smaller x).
// Returns the approximation of y at xmin + nsteps*step (y0 if no step is taken).
double metodoRK4(TGraph* gr, double f(double, double), double xmin, double xmax, double y0, double step){

  // Number of steps. The ratio is taken from the REQUESTED xmax, before xmax
  // is adjusted; otherwise the check below compares nsteps with itself and
  // the warning can never fire. The small relative tolerance keeps round-off
  // (e.g. 0.3/0.1 = 2.9999999999999996) from truncating away a whole step.
  const double ratio = (xmax-xmin)/step;
  const double tol = 1.0E-9*(1.0+std::fabs(ratio));
  int nsteps = 0;
  if (std::isfinite(ratio) && ratio > -tol) {
    nsteps = static_cast<int>(std::floor(ratio+tol));
    xmax = xmin + nsteps*step;
    if (std::fabs(nsteps-ratio)>0.01*nsteps) {
      printf("The closest integer number of steps is: %d\n", nsteps);
      printf("The equation will be solved in [%f, %f]\n", xmin, xmax);
    }
  } else {
    // step == 0 (division by zero), non-finite input, or a step pointing away
    // from xmax (negative step count): take no step instead of invoking UB or
    // looping ~2^32 times.
    printf("Invalid step %f for the interval [%f, %f]: no integration step taken\n", step, xmin, xmax);
  }

  double xn = xmin;
  double yn = y0;
  if (gr) gr->SetPoint(0, xn, yn);
  for (int j=1; j<=nsteps; j++) {
    const double k1 = f(xn         , yn);
    const double k2 = f(xn+0.5*step, yn+0.5*step*k1);
    const double k3 = f(xn+0.5*step, yn+0.5*step*k2);
    const double k4 = f(xn+step    , yn+step*k3);
    yn += (step/6.0)*(k1+2.0*k2+2.0*k3+k4);
    // Recompute x from xmin instead of accumulating xn += step, which drifts.
    xn = xmin + j*step;
    if (gr) gr->SetPoint(j, xn, yn);
  }

  return yn;
}

// Creates an integrator bound to the stepping method named `method`.
// Starts with no stepping function; SetMethod then selects one from the name.
// An unrecognised name leaves the function unset.
ODEIntegrator::ODEIntegrator(const std::string method)
  : _fi(nullptr)
{
  SetMethod(method);
}

// Records `method` and selects the matching stepping function.
// Accepted names: "Eulero", "MidpointImplicito", "MidpointEsplicito", "RK4".
// Any other name prints a message and clears the stepping function.
void ODEIntegrator::SetMethod(const std::string method){

  _method = method;

  // Name -> stepping-function lookup table.
  struct MethodEntry {
    const char* name;
    decltype(&metodoEulero) stepper;
  };
  static const MethodEntry kMethods[] = {
    {"Eulero",            metodoEulero},
    {"MidpointImplicito", metodoMidpointImplicito},
    {"MidpointEsplicito", metodoMidpointEsplicito},
    {"RK4",               metodoRK4},
  };

  for (const MethodEntry& entry : kMethods) {
    if (_method == entry.name) {
      _fi = entry.stepper;
      return;
    }
  }

  printf("metodo %s non implementato\n", _method.c_str());
  _fi = nullptr;
}

double ODEIntegrator::odeint(const std::string method, double f(double, double), double xmin, double xmax, double y0, double step) {
  ODEIntegrator oi(method);
  return oi.odeint(NULL, f, xmin, xmax, y0, step); 
}

// One-shot helper: integrates with a temporary integrator for `method`.
// The trajectory is recorded in `gr` when it is non-null.
double ODEIntegrator::odeint(const std::string method, TGraph* gr, double f(double, double), double xmin, double xmax, double y0, double step) {
  return ODEIntegrator(method).odeint(gr, f, xmin, xmax, y0, step);
}

double ODEIntegrator::odeint(double f(double, double), double xmin, double xmax, double y0, double step) {
  return odeint(NULL, f, xmin, xmax, y0, step); 
}

// Integrates with this integrator's stepping function.
// The trajectory is recorded in `gr` when it is non-null.
// Returns the solution at the end of the integrated interval. If no stepping
// function is set (unknown method name), prints a message and returns the
// sentinel -99999.9.
double ODEIntegrator::odeint(TGraph* gr, double f(double, double), double xmin, double xmax, double y0, double step) {

  if (_fi) return _fi(gr, f, xmin, xmax, y0, step);

  printf("Method %s not valid...\n", _method.c_str());
  return -99999.9;
}
