/*

============================== tmpi.c =========================================

   The following MPI code exposes a bug on some platforms (Abe and probably 
   Ranger). At least on Abe, the bug appears regardless of the setting of
   SET_ATOMICITY, REOPEN, DO_IND_WRITE, and DISABLE_AGGREGATION (all of which
   should be set to either 0 or 1).

   The bug expresses itself as file (and thus data set) corruption that is 
   generally consistent with writes being executed out of order despite
   what (we hope) is the correct use of sync/barrier/sync or file close/open 
   calls that should enforce correct ordering.

============================== tmpi.c =========================================

*/

#include <mpi.h>
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
#include <string.h>

/* number of ints in each block of the file; main() uses it as block_len */
#define BLOCK 10
/* number of write / read-back / verify iterations performed by main() */
#define NITER 100
/* size in bytes of the text record each rank writes in do_independant_write() */
#define IND_WRITE_BUF_SIZE	80

/* set to 1 to call MPI_File_set_atomicity(fh, 1) after each file open;
 * when set, the sync/barrier/sync sequences in main() are compiled out */
#define SET_ATOMICITY 0

/* set to 1 to close and reopen the file (with a barrier in between) after
 * the collective write and again after the read/verify step */
#define REOPEN 0

/* set to 1 to do an independent write after the collective read */
#define DO_IND_WRITE 1

/* set to 1 to pass the "cb_config_list" = "*:*" hint with every
 * MPI_File_set_view call, which is intended to disable aggregation */
#define DISABLE_AGGREGATION 1


/*
 * Build and commit the MPI file datatype this rank uses for the collective
 * write in main().
 *
 * The file is treated as (mpi_size + 1) consecutive blocks of block_len ints,
 * each int being described as sizeof(int) MPI_BYTEs:
 *
 *   - rank 0 selects two blocks: block 0 and block mpi_size;
 *   - every other rank selects the single block whose index is mpi_rank,
 *     inside a type whose lower bound is 0 and whose extent spans all
 *     (mpi_size + 1) blocks.
 *
 * Parameters:
 *   mpi_rank       rank of the calling process
 *   mpi_size       number of processes
 *   block_len      number of ints per block
 *   file_type_ptr  out: receives the committed datatype; the caller releases
 *                  it with MPI_Type_free()
 *
 * MPI_Type_struct, MPI_LB and MPI_UB were removed in MPI-3.0, so the bounds
 * of the non-zero-rank type are set with MPI_Type_create_resized instead.
 * The resulting typemap, lower bound and extent are the same.
 *
 * Return codes of the MPI calls are not checked.
 */
void construct_file_mpi_datatype(int mpi_rank,
                                 int mpi_size,
                                 int block_len,
                                 MPI_Datatype * file_type_ptr)
{
    MPI_Datatype elem_type;     /* one int, as sizeof(int) bytes */
    MPI_Datatype filetype;      /* MPI File datatype handed to the caller */

    /* Create base contiguous type */
    MPI_Type_contiguous(sizeof(int), MPI_BYTE, &elem_type);

    if ( mpi_rank == 0 ) {
        /* Rank 0 operates on 2 blocks, other processes only operate on 1 */

        /* Select the first and last blocks for mpi_rank 0 */
        MPI_Type_vector(2, block_len, block_len * mpi_size, elem_type, &filetype);
        MPI_Type_free(&elem_type);

    } else {
        MPI_Datatype block_type;    /* a single block of block_len ints */
        MPI_Datatype placed_type;   /* that block moved to this rank's offset */
        int          one = 1;
        /* Widen before multiplying so the byte offsets are computed in MPI_Aint */
        MPI_Aint     block_bytes = (MPI_Aint)block_len * (MPI_Aint)sizeof(int);
        MPI_Aint     displacement = (MPI_Aint)mpi_rank * block_bytes;
        MPI_Aint     extent = ((MPI_Aint)mpi_size + 1) * block_bytes;

        /* Select the block corresponding to the mpi_rank */
        MPI_Type_vector(1, block_len, 1, elem_type, &block_type);
        MPI_Type_free(&elem_type);

        /* Place the block at this rank's byte displacement ... */
        MPI_Type_create_struct(1, &one, &displacement, &block_type, &placed_type);
        MPI_Type_free(&block_type);

        /* ... and pin lower bound 0 / extent of (mpi_size + 1) blocks */
        MPI_Type_create_resized(placed_type, 0, extent, &filetype);
        MPI_Type_free(&placed_type);
    }

    MPI_Type_commit(&filetype);

    *file_type_ptr = filetype;

    return;

} /* construct_file_mpi_datatype() */


/*
 * Write one fixed-size text record per rank with an independent
 * (non-collective) MPI_File_write_at.
 *
 * The record is IND_WRITE_BUF_SIZE bytes: the string
 * "Independent write <generation>/<mpi_rank>." followed by NUL padding.
 * It is written at byte offset base_offset + mpi_rank * IND_WRITE_BUF_SIZE
 * after the file view has been reset to plain bytes.
 *
 * Parameters:
 *   fh           open file handle; MPI_File_set_view is called on it, so
 *                every rank that opened the file must call this function
 *   mpi_rank     rank of the calling process
 *   mpi_size     number of processes (currently unused)
 *   generation   value embedded in the record; main() passes its iteration
 *                number
 *   base_offset  byte offset of rank 0's record
 *
 * Return codes of the MPI calls are not checked.
 */
void do_independant_write(MPI_File fh,
                          int mpi_rank,
                          int mpi_size,
                          int generation,
                          MPI_Offset base_offset)
{
    char	 write_buf[IND_WRITE_BUF_SIZE];
    int		 len;
#if DISABLE_AGGREGATION
    MPI_Info	 info;
#endif /* DISABLE_AGGREGATION */

    (void)mpi_size;     /* kept in the signature for existing callers */

    /* The whole buffer is written, so pad everything after the text with NULs */
    memset(write_buf, 0, sizeof(write_buf));

    /* Bounded formatting: the buffer cannot be overrun, and the length
     * check below fires before anything has been damaged. */
    len = snprintf(write_buf, sizeof(write_buf),
                   "Independent write %d/%d.", generation, mpi_rank);

    assert(len >= 0 && len < IND_WRITE_BUF_SIZE);
    (void)len;          /* silence unused warning when NDEBUG is defined */

#if DISABLE_AGGREGATION
    MPI_Info_create(&info);
    MPI_Info_set(info, "cb_config_list", "*:*");
    MPI_File_set_view(fh, 0, MPI_BYTE, MPI_BYTE, "native", info);
#else /* DISABLE_AGGREGATION */
    MPI_File_set_view(fh, 0, MPI_BYTE, MPI_BYTE, "native", MPI_INFO_NULL);
#endif /* DISABLE_AGGREGATION */

    /* Widen the rank before multiplying so the offset is computed in MPI_Offset */
    MPI_File_write_at(fh,
                      base_offset + ((MPI_Offset)mpi_rank * IND_WRITE_BUF_SIZE),
                      write_buf,
                      IND_WRITE_BUF_SIZE,
                      MPI_BYTE,
                      MPI_STATUS_IGNORE);
#if DISABLE_AGGREGATION
    MPI_Info_free(&info);
#endif /* DISABLE_AGGREGATION */

    return;

} /* do_independant_write() */


int main(int argc, char *argv[])
{
    int          *wbuf = NULL;  /* Write buffer */
    int          *rbuf = NULL;  /* Read buffer */
    int          mpi_rank;      /* MPI Rank */
    int          mpi_size;      /* MPI Size */
    int		 block_len = BLOCK;
    MPI_File     fh;            /* File */
    MPI_Datatype filetype;      /* MPI File datatype */
    int          failed = 0;
    int          failure_point;
    int          i, j, k;
#if DISABLE_AGGREGATION
    MPI_Info	 info;
#endif /* DISABLE_AGGREGATION */


    /* Setup */
    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
    MPI_Comm_size(MPI_COMM_WORLD, &mpi_size);


    if ( mpi_rank == 0 ) {

	fprintf(stdout, 
                "NITER = %d, SET_ATOMICITY = %d, REOPEN = %d, DO_IND_WRITE = %d, DISABLE_AGGREGATION = %d.\n",
                NITER, SET_ATOMICITY, REOPEN, DO_IND_WRITE, DISABLE_AGGREGATION);
    }

    /* Loop NITER times */
    for(i=0; i<NITER; i++) {

        if ( mpi_rank == 0 ) {

            fprintf(stdout, "Itteration %d: block size == %d.\n", i, block_len);
        }

        /* construct the file mpi derived type */
        construct_file_mpi_datatype(mpi_rank, mpi_size, block_len, &filetype);

        /* Allocate buffers */
        /* All processes read the entire file */
        rbuf = (int *)malloc((mpi_size + 1) * block_len * sizeof(int));

        if(mpi_rank == 0) {
            /* Rank 0 operates on 2 blocks, other processes only operate on 1 */
            wbuf = (int *)malloc(2 * block_len * sizeof(int));

            for(j=0; j<block_len; j++) {
                wbuf[j] = j + i;
                wbuf[j + block_len] = j + (mpi_size * block_len) + i;
            }
        } else {
            wbuf = (int *)malloc(block_len * sizeof(int));

            /* Fill buffer: final file will be simply a series of increasing
             * integers: 0, 1, 2, 3... */
            for(j=0; j<block_len; j++)
                wbuf[j] = j + (mpi_rank * block_len) + i;
        }

        /* Barrier */
        MPI_Barrier(MPI_COMM_WORLD);

        /* Open file collectively */
        MPI_File_open(MPI_COMM_WORLD, "tmpi.dat", MPI_MODE_RDWR
                | MPI_MODE_CREATE, MPI_INFO_NULL, &fh);

#if SET_ATOMICITY
        MPI_File_set_atomicity(fh, 1);
#endif

#if DISABLE_AGGREGATION
        MPI_Info_create(&info);
        MPI_Info_set(info, "cb_config_list", "*:*");
        MPI_File_set_view(fh, 0, MPI_BYTE, filetype, "native", info);
#else /* DISABLE_AGGREGATION */
        /* Set the file view */
        MPI_File_set_view(fh, 0, MPI_BYTE, filetype, "native", MPI_INFO_NULL);
#endif /* DISABLE_AGGREGATION */

        /* Write the data */
        MPI_File_write_at_all(fh, 0, wbuf, 
                              (mpi_rank == 0 ? 2 : 1) * block_len * sizeof(int), 
                              MPI_BYTE, MPI_STATUS_IGNORE);
#if DISABLE_AGGREGATION
        MPI_Info_free(&info);
#endif /* DISABLE_AGGREGATION */

#if REOPEN
        MPI_File_close(&fh);
        MPI_Barrier(MPI_COMM_WORLD);
        MPI_File_open(MPI_COMM_WORLD, "tmpi.dat", MPI_MODE_RDWR,
                MPI_INFO_NULL, &fh);
#if SET_ATOMICITY
        MPI_File_set_atomicity(fh, 1);
#endif
#else
#if ( !( REOPEN || SET_ATOMICITY ) )
        /* Sync/Barrier/Sync */
        MPI_File_sync(fh);
        MPI_Barrier(MPI_COMM_WORLD);
        MPI_File_sync(fh);
#endif
#endif


#if DISABLE_AGGREGATION
        MPI_Info_create(&info);
        MPI_Info_set(info, "cb_config_list", "*:*");
        MPI_File_set_view(fh, 0, MPI_BYTE, MPI_BYTE, "native", info);
#else /* DISABLE_AGGREGATION */
        MPI_File_set_view(fh, 0, MPI_BYTE, MPI_BYTE, "native", MPI_INFO_NULL);
#endif /* DISABLE_AGGREGATION */

        /* Read the data */
        MPI_File_read_at_all(fh, 0, rbuf, (mpi_size + 1) * block_len * sizeof(int),
                MPI_BYTE, MPI_STATUS_IGNORE);

#if DISABLE_AGGREGATION
        MPI_Info_free(&info);
#endif /* DISABLE_AGGREGATION */

        /* Verify the read data */
        failed = 0;
        for(j = 0; !failed && j < (mpi_size + 1) * block_len; j++)
            if(rbuf[j] != j + i) {
                failed = 1;
                failure_point = j;
                printf("Rank %d detected error on iteration %d at location %d!\n",
                        mpi_rank, i, j);
            }

	if ( ( mpi_rank == 0 ) && ( failed ) ) {

            k = 0;
            fprintf(stdout, "\n");
            for ( j = 0; j < (mpi_size + 1) * block_len; j++ ) {

                fprintf(stdout, " %d", rbuf[j]);
                k++;
                if ( k >= 10 ) {

                    k = 0;
                    fprintf(stdout, "\n");
                }
            }
            fprintf(stdout, "\n");

            fprintf(stdout, 
               "String representation of receive buffer starting at rbuf[%d]: \"%s\"\n\n",
               failure_point, (char *)(&(rbuf[failure_point])));
        }

#if REOPEN
        MPI_File_close(&fh);
        MPI_Barrier(MPI_COMM_WORLD);
        MPI_File_open(MPI_COMM_WORLD, "tmpi.dat", MPI_MODE_RDWR,
                MPI_INFO_NULL, &fh);
#if SET_ATOMICITY
        MPI_File_set_atomicity(fh, 1);
#endif
#else
#if ( ! ( REOPEN || SET_ATOMICITY ) )
        /* Sync/Barrier/Sync */
        MPI_File_sync(fh);
        MPI_Barrier(MPI_COMM_WORLD);
        MPI_File_sync(fh);
#endif
#endif

#if DO_IND_WRITE
        do_independant_write(fh, mpi_rank, mpi_size, i, 0);
#endif
        MPI_Type_free(&filetype);

        MPI_File_close(&fh);

        free(wbuf);
        free(rbuf);
    }

    MPI_Finalize();

    return 0;
}
