#include "aes.h"
#include "scs_lib.h"

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>

#include <x86intrin.h> // _mm_clflush, _mm_mfence, _mm_lfence, __rdtscp (GCC/Clang on x86)

// Cycle threshold for the cache timing measurement; presumably reload times
// below it count as cache hits and times at or above it as misses (it is not
// referenced in the code shown here — confirm at its use site).
// This number varies on different systems and should be calibrated per machine.
#define MIN_CACHE_MISS_CYCLES (160)

// Number of encryptions performed per key-byte guess in recover_key_byte();
// more encryptions show features (the hit-ratio gap) more clearly.
#define NUMBER_OF_ENCRYPTIONS (1000)

// Overwrite all 16 bytes of `block` with pseudo-random values taken from
// rand() (low 8 bits of each draw), in order from block[0] to block[15].
void randomize_block(uint8_t block[16]) {
    uint8_t *cursor = block;
    uint8_t *const end = block + 16;
    while ( cursor != end ) {
        *cursor++ = rand() % 256;
    }
}

// Print the 16 bytes of `block` to stdout as space-separated "0x.." tokens
// (two lowercase hex digits each), followed by a newline.
void print_block_hex(uint8_t block[16]) {
    const uint8_t *const stop = block + 16;
    for ( const uint8_t *cur = block; cur != stop; ++cur ) {
        printf("0x%02x ", *cur);
    }
    printf("\n");
}

// Print the 16 bytes of `block` to stdout verbatim as characters (no escaping
// of non-printable bytes), followed by a newline.
void print_block_readable(uint8_t block[16]) {
    size_t pos = 0;
    while ( pos < 16 ) {
        printf("%c", (char)block[pos]);
        ++pos;
    }
    printf("\n");
}

// Return the address of the T-table (Te0-Te3) that the first AES round indexes
// with the key byte at position `byte_index`.
//
// The following snippet shows the implementation of the first round AES with
// T-tables:
//
// t0 = Te0[key[ 0] ^ pt[ 0]] ^ Te1[key[ 5] ^ pt[ 5]] ^ Te2[key[10] ^ pt[10]] ^ Te3[key[15] ^ pt[15]] ^ ...
// t1 = Te0[key[ 4] ^ pt[ 4]] ^ Te1[key[ 9] ^ pt[ 9]] ^ Te2[key[14] ^ pt[14]] ^ Te3[key[ 3] ^ pt[ 3]] ^ ...
// t2 = Te0[key[ 8] ^ pt[ 8]] ^ Te1[key[13] ^ pt[13]] ^ Te2[key[ 2] ^ pt[ 2]] ^ Te3[key[ 7] ^ pt[ 7]] ^ ...
// t3 = Te0[key[12] ^ pt[12]] ^ Te1[key[ 1] ^ pt[ 1]] ^ Te2[key[ 6] ^ pt[ 6]] ^ Te3[key[11] ^ pt[11]] ^ ...
//
// Each key byte indexes exactly one table. Reading the snippet column by column:
// bytes 0, 4, 8, 12 use Te0; bytes 1, 5, 9, 13 use Te1; bytes 2, 6, 10, 14 use
// Te2; bytes 3, 7, 11, 15 use Te3. The table is therefore byte_index % 4.
//
// @param libaes_base: the address where we mapped the shared library
// @param byte_index: the key byte index (0-15) that should be recovered
//
// @return: the associated T-table address for the byte index, or nullptr when
//          byte_index is not a valid AES-128 key byte index (>= 16).
//
uint8_t *get_ttable_address(uint8_t *libaes_base, size_t byte_index) {

    // TODO: extract the T-table addresses from libaes.so with readelf
    // "Te0", "Te1", "Te2", "Te3"
    ptrdiff_t Te0_offset = 0;
    ptrdiff_t Te1_offset = 0;
    ptrdiff_t Te2_offset = 0;
    ptrdiff_t Te3_offset = 0;

    if ( byte_index >= 16 ) {
        return nullptr;
    }

    // Te<n> is indexed by every key byte whose index is congruent to n mod 4.
    const ptrdiff_t ttable_offsets[4] = { Te0_offset, Te1_offset, Te2_offset, Te3_offset };
    return libaes_base + ttable_offsets[byte_index % 4];
}

// This function performs one round of the AES first round Flush & Reload attack.
// 1) First we extract the T-table address for the given byte_index.
// 2) We generate a random Plaintext except the current key byte we want to recover
// 3) We set the key byte we want ot recover to your guess
// 4) We check if we observe a Flush & Reload hit on the 0th entry of the corresponding
//    T-table, indicating that 0 <= (key[byte_index] ^ plaintext[byte_index]) < 16
// 5) We repeat this process until we got enough confidence (parent function)
//
// @param libaes_base: is the address where we  mapped the shared library
// @param byte_index: is the current key byte index that should be recovered
// @param byte_guess: the current byte guess for the key byte at position byte_index
//
// @return: returns true, if Flush & Reload observed a hit on the corresponding T-table
// (for the current byte_index) at postion 0
//
bool perform_one_round(uint8_t *libaes_base, size_t byte_index, uint8_t byte_guess) {

    uint8_t *probe = get_ttable_address(libaes_base, byte_index);

    uint8_t plaintext[16]  = {};
    uint8_t ciphertext[16] = {};

    randomize_block(plaintext);

    plaintext[byte_index] = byte_guess;

    encrypt(plaintext, ciphertext);

    bool zeroth_ttable_entry_hit = false;

    // TODO: implement flush&reload on the given zeroth T-table entry.

    return zeroth_ttable_entry_hit;
}

// This function iterates over all the candidate key byte values and returns
// the one whose upper 4 bits match the key byte at byte_index.
//
// A guess is accepted immediately once its Flush & Reload hit ratio exceeds
// 95%. If no guess reaches that threshold (e.g. because of timing noise), the
// guess with the highest hit count is returned instead of silently returning 0,
// and a warning is printed to stderr.
//
// @param libaes_base: the address where we mapped the shared library
// @param byte_index: the key byte index that should be recovered
//
// @return: the recovered key byte (only the upper 4 bits are meaningful)
//
uint8_t recover_key_byte(uint8_t *libaes_base, size_t byte_index) {

    // best fallback candidate if no guess crosses the confidence threshold;
    // ties keep the earliest guess (matches the original scan order)
    size_t best_guess = 0;
    uint64_t best_hits = 0;

    // Iterate over all the possible key byte values.
    // Remember: The first round attack can only recover
    // the upper 4 bits of each byte. Therefore, we skip 16 values as
    // these values do not provide additional information.
    for ( size_t byte_guess = 0; byte_guess < 256; byte_guess += 16 ) {

        // number of overall cache hits
        uint64_t hits = 0;

        // repeat the attack multiple times to gain confidence
        for ( size_t i = 0; i < NUMBER_OF_ENCRYPTIONS; ++i ) {
            hits += perform_one_round(libaes_base, byte_index, (uint8_t)byte_guess);
        }

        // calculate the hit ratio
        double hit_ratio = (double)hits / NUMBER_OF_ENCRYPTIONS;

        // if we observe a byte_guess with more than 95% hit ratio
        // we found the correct key byte
        if ( hit_ratio > 0.95 ) {
            return (uint8_t)byte_guess;
        }

        if ( hits > best_hits ) {
            best_hits  = hits;
            best_guess = byte_guess;
        }
    }

    // no guess was conclusive: fall back to the most likely one
    fprintf(stderr,
            "warning: no guess for key byte %zu exceeded a 95%% hit ratio; "
            "using best guess 0x%02zx (%.1f%%)\n",
            byte_index, best_guess,
            100.0 * (double)best_hits / NUMBER_OF_ENCRYPTIONS);
    return (uint8_t)best_guess;
}

// Recover the upper 4 bits of every key byte with the first-round Flush &
// Reload attack, combine them with the known lower 4 bits, and use the
// resulting key to decrypt the flag.
int main() {
    // seed the random number generator used for the random plaintexts
    srand((unsigned int)time(nullptr));

    // map the shared library into our process
    uint8_t *libaes_base = open_shared("libaes.so");
    // NOTE(review): assumes open_shared() reports failure with nullptr — confirm in scs_lib.h
    if ( libaes_base == nullptr ) {
        fprintf(stderr, "failed to map libaes.so\n");
        return 1;
    }

    // these are the known key parts, which we provide to you
    // your goal in this exercise is to recover the upper 4 bit of each of the key bytes using
    // the first round flush & reload attack on the AES T-table implementation
    uint8_t key[16] = { 0xd, 0xb, 0x2, 0xf, 0xd, 0x3, 0x2, 0x2, 0xf, 0xb, 0xd, 0xc, 0x5, 0xe, 0xc, 0x7 };

    // for each of the 16 key bytes recover the byte value
    for ( size_t byte_index = 0; byte_index < 16; ++byte_index ) {

        // recover the key byte
        uint8_t key_byte = recover_key_byte(libaes_base, byte_index);

        // fill in the upper 4 bits of each key byte into the final key
        key[byte_index] |= (key_byte & 0xF0);
    }

    // print the recovered key
    printf("The recovered key:\n");
    print_block_hex(key);

    // try to decrypt the flag for the EDX page
    // (17 bytes: 16 data bytes plus a zero terminator from the initializer)
    uint8_t decrypted_flag[17] = {};
    verify(key, decrypted_flag);
    printf("The decrypted flag:\n");
    print_block_readable(decrypted_flag);

    return 0;
}
