#include <mpi.h>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <math.h>

/* Rank of this process inside cart_comm, and size of MPI_COMM_WORLD. */
int my_rank;
int nprocs;

/* Extent of the process grid along Cartesian dimension 1 (nprocs_y) and
   dimension 0 (nprocs_x), as chosen by MPI_Dims_create in mpi_setup(). */
int nprocs_y;
int nprocs_x;

/* Neighbour ranks from MPI_Cart_shift: prev_y/next_y along dimension 0,
   next_x/prev_x along dimension 1.  The grid is non-periodic, so ranks on
   the boundary get MPI_PROC_NULL here. */
int prev_y;
int next_y;
int next_x;
int prev_x;

/* Committed halo datatypes (two columns / two rows) and the Cartesian
   communicator, all created in mpi_setup(). */
MPI_Datatype vertSlice;
MPI_Datatype horizSlice;
MPI_Comm cart_comm;

/* Global (undecomposed) domain sizes, and the global index of this rank's
   first local cell in the i and j directions. */
int imax_full;
int jmax_full;
int gbl_i_begin;
int gbl_j_begin;

/* Arrays whose halo freshness is tracked (presumably registered by code
   outside this view -- TODO confirm).  dat_dirty[k] != 0 means dat_ptrs[k]
   needs a halo exchange; every slot starts out dirty. */
double *dat_ptrs[4];
int dat_dirty[4] = {1, 1, 1, 1};

/*
 * Initialise MPI, build a 2-D non-periodic Cartesian communicator and
 * decompose the global imax x jmax domain across it.
 *
 * On entry *imax / *jmax are the global interior sizes; on return they hold
 * this rank's local interior sizes.  The global sizes are saved in
 * imax_full / jmax_full and the global index of this rank's first cell in
 * gbl_i_begin / gbl_j_begin.  Every rank gets imax_full/nprocs_x cells
 * (resp. jmax_full/nprocs_y) except the last one in each dimension, which
 * also absorbs the remainder of the integer division.
 *
 * Also commits the halo datatypes used by exchange_halo(), which indexes a
 * local array of (imax+4) x (jmax+4) doubles as arr[i*(jmax+4)+j], i.e. two
 * ghost layers on every side, rows of constant i stored contiguously.
 *
 * Note: argc/argv are received by value, so arguments stripped by MPI_Init
 * are not removed from the caller's copies.
 */
void mpi_setup(int argc, char **argv, int *imax, int *jmax) {
  MPI_Init(&argc, &argv);
  MPI_Comm_size(MPI_COMM_WORLD, &nprocs);
  int sides[2] = {0, 0};
  MPI_Dims_create(nprocs, 2, sides);

  nprocs_x = sides[0];
  nprocs_y = sides[1];
  int dims[2] = {sides[0], sides[1]};
  int periods[2] = {0, 0};
  int reorder = 0;
  int coords[2];
  MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, reorder, &cart_comm);
  MPI_Comm_rank(cart_comm, &my_rank);
  MPI_Cart_coords(cart_comm, my_rank, 2, coords);
  int my_rank_x = coords[0];
  int my_rank_y = coords[1];

  /* Dimension 0 splits the i index (imax), dimension 1 the j index (jmax).
     exchange_halo() swaps contiguous i-rows with the dimension-0 neighbours,
     which are stored in prev_y/next_y, and j-columns with the dimension-1
     neighbours stored in prev_x/next_x. */
  MPI_Cart_shift(cart_comm, 0, 1, &prev_y, &next_y);
  MPI_Cart_shift(cart_comm, 1, 1, &prev_x, &next_x);

  /* Save original full sizes in both directions. */
  imax_full = *imax;
  jmax_full = *jmax;

  /* Local sizes: the last rank in each dimension takes the remainder left
     over by the integer division. */
  *imax = (my_rank_x != nprocs_x - 1) ? imax_full / nprocs_x
                                      : imax_full - my_rank_x * (imax_full / nprocs_x);
  *jmax = (my_rank_y != nprocs_y - 1) ? jmax_full / nprocs_y
                                      : jmax_full - my_rank_y * (jmax_full / nprocs_y);

  /* Global index of this rank's first cell. */
  gbl_i_begin = my_rank_x * (imax_full / nprocs_x);
  gbl_j_begin = my_rank_y * (jmax_full / nprocs_y);

  /* Halo datatypes for two ghost layers on each side.  A local array has
     (*imax)+4 rows of row_len = (*jmax)+4 doubles. */
  const int row_len = *jmax + 4;
  /* Two adjacent columns (j, j+1) across every row, ghost rows included:
     imax+4 blocks of 2 doubles, one block per row. */
  MPI_Type_vector(*imax + 4, 2, row_len, MPI_DOUBLE, &vertSlice);
  /* Two adjacent full rows (i, i+1): contiguous in memory. */
  MPI_Type_contiguous(2 * row_len, MPI_DOUBLE, &horizSlice);
  MPI_Type_commit(&vertSlice);
  MPI_Type_commit(&horizSlice);
}

void exchange_halo(int imax, int jmax, double *arr) {
  int dirty = -1;
  
  for (int i = 0; i < 4; i++) {
    if ((double*)arr == dat_ptrs[i]) {
      if (dat_dirty[i]) dirty = i;
      break;
    }
  }
  if (dirty!=-1) {
    //Homework: ghost cells are not 1 on each side, but 2!
    // since we are sending 2 rows/columns, make sure the offsets into arr are right!
    //Exchange halos: top, bottom, left, right
    MPI_Request requests[8];
    int counter = 0;
    if (next_x >= 0)
        MPI_Isend(&arr[0*(jmax+4)+jmax],    1,vertSlice,next_x,0, cart_comm,&requests[counter++]);
    if (prev_x >= 0)
        MPI_Irecv(&arr[0*(jmax+4)+0],       1,vertSlice,prev_x,my_rank,cart_comm,&requests[counter++]);

    if (prev_x >= 0)
        MPI_Isend(&arr[0*(jmax+4)+2],       1,vertSlice,prev_x,0, cart_comm,&requests[counter++]);
    if (next_x >= 0)
        MPI_Irecv(&arr[0*(jmax+4)+jmax+2],  1,vertSlice,next_x,my_rank,cart_comm,&requests[counter++]);

    if (next_y >= 0)
        MPI_Isend(&arr[(imax)*(jmax+4)+0],  1,horizSlice,next_y,0, cart_comm,&requests[counter++]);
    if (prev_y >= 0)
        MPI_Irecv(&arr[0*(jmax+4)+0]     ,  1,horizSlice,prev_y,my_rank,cart_comm,&requests[counter++]);

    if (prev_y >= 0)
        MPI_Isend(&arr[2*(jmax+4)+0],       1,horizSlice,prev_y,0, cart_comm,&requests[counter++]);
    if (next_y >= 0)
        MPI_Irecv(&arr[(imax+2)*(jmax+4)+0],1,horizSlice,next_y,my_rank,cart_comm,&requests[counter++]);

    //printf("%d %d %d %d %d %d\n", my_rank,counter, requests[0], requests[1], requests[2], requests[3]);
    MPI_Status statuses[8];
    //MPI_Waitall(counter,requests,statuses);
    dat_dirty[dirty] = 0;
  }
}

/*
 * Flag `arr` as needing a halo exchange.
 *
 * Scans dat_ptrs for the first slot holding `arr` and sets its dat_dirty
 * entry to 1.  Pointers that are not registered in dat_ptrs are ignored.
 */
void set_dirty(double *arr) {
  int slot = 0;
  while (slot < 4 && dat_ptrs[slot] != arr)
    slot++;

  if (slot < 4)
    dat_dirty[slot] = 1;
}
