<template>
  <div class="text-center select-none bg-gray-50 pt-4">
    <!-- Who moves first: switch off = you, switch on = the AI. -->
    <div>
      <div class="text-2xl font-bold">Who is player one?</div>
      <div>
        <span
          class="text-3xl font-bold mr-5"
          :class="[!enabled ? 'text-indigo-600' : 'text-gray-400']"
          >You</span
        >
        <Switch
          v-model="enabled"
          class="
            flex-shrink-0
            group
            relative
            rounded-full
            inline-flex
            items-center
            justify-center
            h-5
            w-10
            cursor-pointer
            focus:outline-none
            focus:ring-2
            focus:ring-offset-2
            focus:ring-indigo-500
          "
        >
          <span class="sr-only">Select Player</span>
          <span
            aria-hidden="true"
            class="
              pointer-events-none
              absolute
              bg-white
              w-full
              h-full
              rounded-md
            "
          />
          <span
            aria-hidden="true"
            :class="[
              enabled ? 'bg-indigo-600' : 'bg-gray-200',
              'pointer-events-none absolute h-4 w-9 mx-auto rounded-full transition-colors ease-in-out duration-200',
            ]"
          />
          <span
            aria-hidden="true"
            :class="[
              enabled ? 'translate-x-5' : 'translate-x-0',
              'pointer-events-none absolute left-0 inline-block h-5 w-5 border border-gray-200 rounded-full bg-white shadow transform ring-0 transition-transform ease-in-out duration-200',
            ]"
          />
        </Switch>
        <span
          class="text-3xl font-bold ml-5"
          :class="[enabled ? 'text-indigo-600' : 'text-gray-400']"
          >AI</span
        >
      </div>
    </div>
    <!-- Model picker: the darker button (aria-pressed="true") is the active
         model. The background colour comes only from the :class binding so
         the selected/unselected shades never conflict. -->
    <div class="mb-5">
      <div class="text-2xl font-bold">Choose a model</div>
      <button
        type="button"
        class="
          text-m
          lg:text-xl
          mt-2
          rounded-l-lg
          border-r-2
          hover:bg-green-400
          border-black
          p-2
          text-white
          mx-auto
        "
        :class="[modelName !== 'jos' ? 'bg-green-600' : 'bg-green-800']"
        :aria-pressed="modelName === 'jos'"
        @click="selectModel('jos')"
      >
        Jose 1
      </button>
      <button
        type="button"
        class="
          text-m
          lg:text-xl
          mt-2
          border-r-2
          hover:bg-green-400
          border-black
          p-2
          text-white
          mx-auto
        "
        :class="[modelName !== 'jos-2' ? 'bg-green-600' : 'bg-green-800']"
        :aria-pressed="modelName === 'jos-2'"
        @click="selectModel('jos-2')"
      >
        Jose 2
      </button>
      <button
        type="button"
        class="
          text-m
          lg:text-xl
          mt-2
          p-2
          text-white
          border-r-2 border-black
          hover:bg-green-400
          mx-auto
        "
        :class="[modelName !== 'win' ? 'bg-green-600' : 'bg-green-800']"
        :aria-pressed="modelName === 'win'"
        @click="selectModel('win')"
      >
        Winthrop
      </button>
      <button
        type="button"
        class="
          text-m
          lg:text-xl
          mt-2
          border-r-2
          hover:bg-green-400
          border-black
          p-2
          text-white
          mx-auto
        "
        :class="[modelName !== 'win-2' ? 'bg-green-600' : 'bg-green-800']"
        :aria-pressed="modelName === 'win-2'"
        @click="selectModel('win-2')"
      >
        Winthrop 2
      </button>
      <button
        type="button"
        class="
          text-m
          lg:text-xl
          mt-2
          p-2
          border-black
          text-white
          border-r-2
          hover:bg-green-400
          mx-auto
        "
        :class="[modelName !== 'q-learner_1' ? 'bg-green-600' : 'bg-green-800']"
        :aria-pressed="modelName === 'q-learner_1'"
        @click="selectModel('q-learner_1')"
      >
        Q-Learner 1
      </button>
      <button
        type="button"
        class="
          text-m
          lg:text-xl
          mt-2
          p-2
          rounded-r-lg
          text-white
          border-r-2
          hover:bg-green-400
          mx-auto
        "
        :class="[modelName !== 'q-learner_2' ? 'bg-green-600' : 'bg-green-800']"
        :aria-pressed="modelName === 'q-learner_2'"
        @click="selectModel('q-learner_2')"
      >
        Q-Learner 2
      </button>
    </div>
    <!-- The board is drawn by script; clicking a column drops a piece. -->
    <canvas
      id="canvas"
      class="mx-auto w-3/5"
      @mouseleave="stopHover"
      @mousemove="displayHover"
      @click="play"
    >
      Connect 4 board: click a column to drop a piece.
    </canvas>
    <button
      type="button"
      class="text-xl mt-2 rounded p-2 text-white bg-green-600 mx-auto"
      @click="resetGame"
    >
      Reset
    </button>
  </div>
</template>

<script>
import * as tf from "@tensorflow/tfjs";
import { toRaw } from "vue";
import * as CryptoJS from "crypto-js";
import { get } from "../helper";

import { ref } from "vue";
import { Switch } from "@headlessui/vue";

export default {
  name: "Connect4",
  components: {
    Switch,
  },
  async mounted() {
    this.canvas = document.getElementById("canvas");
    // This will scale the width based on the screen size while maintaining the
    // board's aspect ratio.
    this.width = Math.min(window.outerWidth, 950);
    this.height = this.width * (800 / 950);
    window.requestAnimationFrame(this.drawBoard);

    // Update the board dimensions when/if the screen resizes
    window.addEventListener("resize", () => {
      this.width = Math.min(window.outerWidth, 950);
      this.height = this.width * (800 / 950);
      window.requestAnimationFrame(this.drawBoard);
    });

    await this.selectModel("jos");
    this.resetGame();
  },
  computed: {
    offsetY() {
      return this.radius + 10;
    },
  },

    watch: {
      // whenever question changes, this function will run
      enabled: function () {
        console.log("hi");
        this.resetGame();
      },
    },
  data() {
    return {
      enabled: ref(false),
      model: null,
      gameReady: false,
      canvas: undefined,
      hoverColumNumber: -1,
      radius: 60,
      pad: 10,
      offsetX: 30,
      scaleFactor: 1,
      // Why 800 / 950? Honestly because I just hardcoded those dimensions while
      // expirementing with the canvas and the board looked fine so I kept them 🤷🏽‍♂️
      height: 800,
      width: 950,
      // [0][5] is the bottom left; yeah it's transposed but it made the render
      // logic easy doing it this way; just transpose
      boardState: [
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
      ],
      turn: 1,
      mappingFunc: null,
      round: 1,
      modelName: null,
      gameOver: false,
      aiMarker: 0,
    };
  },
  methods: {
    async selectModel(modelName) {
      if (this.gameReady) {
        this.resetGame();
      }
      this.gameReady = false;
      const t0 = Date.now();
      this.drawLoadingScreen();
      this.modelName = modelName;
      await this.loadModel(modelName);

      const deltaTime = Date.now() - t0;

      const delay = 250 - deltaTime;
      setTimeout(() => {
        this.resetGame();
      }, Math.max(delay, 0));
    },
    async predict(boardState) {
      const model = toRaw(this.model);
      let pred;

      if (toRaw(this.modelName) == "jos" || toRaw(this.modelName) == "jos-2") {
        pred = model.predict({ "input/Ob": boardState });
      } else if (toRaw(this.modelName).startsWith("win")) {
        pred = await model.executeAsync({ "input/Ob": boardState });
      }

      return pred;
    },
    async loadModel(modelName) {
      this.gameReady = false;
      this.mappingFunc = null;
      const versions = {
        jos: "100",
        "jos-2": "100",
        win: "100",
        "win-2": "100",
      };
      if (
        modelName == "jos" ||
        modelName == "win" ||
        modelName == "jos-2" ||
        modelName == "win-2"
      ) {
        const version = localStorage.getItem(`${modelName}-version`);
        let invalidCache = false;
        if (version !== versions[modelName]) {
          invalidCache = true;
          localStorage.removeItem(
            `tensorflowjs_models/model-${modelName}/info`
          );
        }
        const modelExist = localStorage.getItem(
          `tensorflowjs_models/model-${modelName}/info`
        );

        if (modelExist && !invalidCache) {
          console.log(`loading from local storage`);
          this.model = await tf.loadGraphModel(
            `localstorage://model-${modelName}`
          );
        } else {
          console.log(`loading from network`);
          const url = new URL(
            `/models/connect4/${modelName}/model.json`,
            process.env.VUE_APP_BASE_URL
          );
          this.model = await tf.loadGraphModel(url.href);
          await this.model.save(`localstorage://model-${modelName}`);
          localStorage.setItem(`${modelName}-version`, versions[modelName]);
        }

        this.mappingFunc = modelName.startsWith("win")
          ? this.winMap
          : this.josMap;
        this.gameReady = true;
      } else {
        console.log("Loading Q-Learner from network");
        let q_table = await get(
          `/models/connect4/${modelName}/hashedLearner.json`
        );
        q_table = await q_table.json();
        console.log(q_table);
        this.model = null;
        this.q_table = q_table;
        this.gameReady = true;
      }
    },
    nextTurn() {
      this.turn = this.turn === 1 ? 2 : 1;
      this.currentPlayer = this.currentPlayer === "user" ? "ai" : "user";
    },
    checkLine(a, b, c, d) {
      return a && a === b && a === c && a === d;
    },
    checkWin() {
      for (let r = 0; r < 3; r++)
        for (let c = 0; c < 7; c++)
          if (
            this.checkLine(
              this.boardState[c][r],
              this.boardState[c][r + 1],
              this.boardState[c][r + 2],
              this.boardState[c][r + 3]
            )
          )
            return this.boardState[c][r];

      // Check right
      for (let r = 0; r < 6; r++)
        for (let c = 0; c < 4; c++)
          if (
            this.checkLine(
              this.boardState[c][r],
              this.boardState[c + 1][r],
              this.boardState[c + 2][r],
              this.boardState[c + 3][r]
            )
          )
            return this.boardState[c][r];

      // Check down-right
      for (let r = 0; r < 3; r++)
        for (let c = 0; c < 4; c++)
          if (
            this.checkLine(
              this.boardState[c][r],
              this.boardState[c + 1][r + 1],
              this.boardState[c + 2][r + 2],
              this.boardState[c + 3][r + 3]
            )
          )
            return this.boardState[c][r];

      // Check down-left
      for (let r = 3; r < 6; r++)
        for (let c = 0; c < 4; c++)
          if (
            this.checkLine(
              this.boardState[c][r],
              this.boardState[c + 1][r - 1],
              this.boardState[c + 2][r - 2],
              this.boardState[c + 3][r - 3]
            )
          )
            return this.boardState[c][r];

      return 0;
    },
    drawPiece(columnNumber, rowNumber, piece) {
      const x =
        this.offsetX +
        this.radius +
        columnNumber * (this.pad + this.radius * 2);

      const y = this.offsetY + rowNumber * (this.pad + 2 * this.radius);
      const ctx = this.canvas.getContext("2d");
      ctx.beginPath();
      if (piece === 1) {
        // player 1 (red)
        ctx.fillStyle = "rgba(255, 0, 0, 1)";
        ctx.strokeStyle = "rgba(255, 0, 0, 1)";
      } else if (piece === 2) {
        // player 2 (yellow)
        ctx.fillStyle = "rgba(255, 255, 0, 1)";
        ctx.strokeStyle = "rgba(255, 255, 0, 1)";
      } else {
        // empty (white)
        ctx.fillStyle = "rgba(255, 255, 255, 1)";
        ctx.strokeStyle = "rgba(255, 255, 255, 1)";
      }
      ctx.arc(x, y, this.radius, 0, 2 * Math.PI);
      ctx.fill();
      ctx.stroke();
    },
    josMap() {
      // Now let the AI go
      const mappedState = this.boardState.reduce((prev, curr) => {
        prev.push(curr.map((val) => val));
        return prev;
      }, []);

      return tf.reshape(
        tf.cast(tf.tensor(this.transpose(mappedState)), "int32"),
        [-1, 6, 7, 1]
      );
    },
    winMap() {
      // Now let the AI go
      const channel1 = this.transpose(
        this.boardState.reduce((prev, curr) => {
          prev.push(curr.map((val) => (val === toRaw(this.aiMarker) ? 1 : 0)));
          return prev;
        }, [])
      );
      console.log("\nChannel 1:");
      console.log(channel1);

      const channel2 = this.transpose(
        this.boardState.reduce((prev, curr) => {
          prev.push(
            curr.map((val) => (val === toRaw(this.playerMarker) ? 1 : 0))
          );
          return prev;
        }, [])
      );
      console.log("\nChannel 2:");
      console.log(channel2);

      const channel3 = this.transpose(
        this.boardState.map((column) => {
          const top1 = column.indexOf(1);
          const top2 = column.indexOf(2);

          const newColumn = [0, 0, 0, 0, 0, 0];
          // Empty column
          if (top1 == -1 && top2 == -1) {
            newColumn[5] = 1;
            return newColumn;
          }

          // Must be 1
          let topIndex;
          if (top1 == -1) {
            // If Only yellow
            topIndex = top2;
          } else if (top2 == -1) {
            // if only red
            topIndex = top1;
          } else {
            // Column has both colors
            topIndex = Math.min(top1, top2); // figure out which is the highest
            if (topIndex == 0) {
              // if the column is full
              return newColumn;
            }
          }

          topIndex = topIndex - 1; // minus 1 for the open spot

          newColumn[topIndex] = 1;
          return newColumn;
        })
      );
      console.log("\nChannel 3: ");
      console.log(channel3);

      const mapped = [];
      for (let i = 0; i < channel1.length; i++) {
        mapped.push(this.transpose([channel1[i], channel2[i], channel3[i]]));
      }

      const input = tf.reshape(tf.tensor(mapped), [-1, 6, 7, 3]);

      console.log("Input tensor:");
      input.print();
      return input;
    },
    userMove(event) {
      if (!this.gameReady) {
        return;
      }

      // Check the player's move
      let columnNumber = this.getColumn(event);
      if (columnNumber === undefined) {
        return;
      }

      // Validate the positions
      let openPosition = this.boardState[columnNumber].lastIndexOf(0);
      if (openPosition < 0) {
        return;
      }

      // Update internal state
      this.boardState[columnNumber][openPosition] = this.turn;

      // Render the move
      this.drawPiece(columnNumber, openPosition, this.turn);

      // Check for a winner
      let winner = this.checkWin(columnNumber, openPosition);
      if (winner) {
        alert(`winner: ${winner}`);
        return true;
      }

      return false;
    },
    async aiMove() {
      this.gameReady = false;
      console.log(`\n\nBoard State - Round${this.round++}`);
      tf.tensor(this.transpose(toRaw(this.boardState))).print();

      // Get the AI's predictions
      const prediction = await this.getAIMove();

      // Figure out which column to use
      const modelName = toRaw(this.modelName);
      let columnNumber = 3;
      if (
        modelName === "jos" ||
        modelName === "jos-2" ||
        modelName === "win-2"
      ) {
        columnNumber = prediction.arraySync()[0];
      } else if (modelName === "win") {
        const probs = prediction.arraySync()[0];
        console.log(`\nProbs:`);
        console.log(probs);

        columnNumber = probs.indexOf(Math.max(...probs));
        console.log(`\nAI Play column: ${columnNumber}`);
      } else if (modelName.startsWith("q-learner")) {
        columnNumber = prediction;
      }

      // Validate the move
      const openPosition = this.boardState[columnNumber].lastIndexOf(0);
      if (openPosition < 0) {
        return;
      }

      // Update the internal state
      this.boardState[columnNumber][openPosition] = this.turn;

      // Render the move
      await this.sleep(150);
      this.drawPiece(columnNumber, openPosition, this.turn);
      const winner = this.checkWin(columnNumber, openPosition);
      this.gameReady = true;

      // Check for a winner
      if (winner) {
        alert(`winner: ${winner}`);
        return true;
      }
      return false;
    },
    async play(event) {
      if (this.gameOver || !this.gameReady) {
        return;
      }
      if (this.currentPlayer === "user") {
        const userWon = this.userMove(event);
        if (userWon) {
          this.gameOver = true;
          return;
        }
        this.nextTurn();
        this.play();
      } else {
        const aiWon = await this.aiMove(event);
        if (aiWon) {
          this.gameOver = true;
          return;
        }

        this.nextTurn();
      }
    },
    resetGame() {
      this.round = 1;
      this.boardState = [
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
      ];
      this.turn = 1;
      this.drawBoard();
      this.currentPlayer = !this.enabled ? "user" : "ai";
      this.aiMarker = 2;
      this.playerMarker = 1;

      this.gameOver = false;
      this.gameReady = true;
      if (this.currentPlayer === "ai") {
        this.aiMarker = 1;
        this.playerMarker = 2;
        this.play();
      }
    },
    getColumn(event) {
      const pos = this.getMousePos(event);
      // I hardcoded pixel positions here before I implemented scaling so I
      // just replaced them with this array. columnBounds is recalculated on
      // draw calls
      if (pos.x < this.columnBounds[0]) {
        return 0;
      } else if (pos.x < this.columnBounds[1]) {
        return 1;
      } else if (pos.x < this.columnBounds[2]) {
        return 2;
      } else if (pos.x < this.columnBounds[3]) {
        return 3;
      } else if (pos.x < this.columnBounds[4]) {
        return 4;
      } else if (pos.x < this.columnBounds[5]) {
        return 5;
      } else if (pos.x < this.columnBounds[6]) {
        return 6;
      } else {
        return;
      }
    },
    // This displays they greyish overlay on the column you're hovering over
    displayHover(event) {
      if (!this.gameReady) {
        return;
      }

      const columnNumber = this.getColumn(event);

      if (
        columnNumber === undefined ||
        columnNumber === this.hoverColumNumber
      ) {
        return;
      }
      this.drawBoard();
      this.hoverDisplayed = true;
      this.hoverColumNumber = columnNumber;

      window.requestAnimationFrame(() => {
        const ctx = document.getElementById("canvas").getContext("2d");

        ctx.fillStyle = "rgba(0, 0, 0, 0.2)";
        ctx.strokeStyle = "rgba(0, 0, 0, 0.2)";
        ctx.fillRect(
          this.offsetX + columnNumber * (this.pad + this.radius * 2),
          0,
          this.radius * 2,
          this.height
        );
      });
    },
    // And this gets rid of it by redrawing the board
    stopHover() {
      this.hoverDisplayed = false;
      this.hoverColumNumber = -1;
      this.drawBoard();
    },
    drawLoadingScreen() {
      const ctx = this.canvas.getContext("2d");

      this.canvas.style.width = this.width + "px";
      this.canvas.style.height = this.height + "px";
      const scale = window.devicePixelRatio;
      this.canvas.width = Math.floor(this.width * scale);
      this.canvas.height = Math.floor(this.height * scale);
      ctx.scale(scale, scale);

      ctx.fillStyle = "rgba(0, 0, 0, 1)";
      ctx.fillRect(0, 0, this.width, this.height);
      ctx.font = "48px serif";
      ctx.fillStyle = "white";
      ctx.fillText("Loading...", 10, 50);
    },
    drawBoard() {
      const ctx = this.canvas.getContext("2d");

      this.canvas.style.width = this.width + "px";
      this.canvas.style.height = this.height + "px";
      const scale = window.devicePixelRatio;
      this.canvas.width = Math.floor(this.width * scale);
      this.canvas.height = Math.floor(this.height * scale);
      ctx.scale(scale, scale);

      const columnBounds = [];
      for (let i = 0; i < 7; i++) {
        columnBounds.push(
          this.offsetX + this.radius * 2 + (this.radius * 2 + this.pad) * i
        );
      }

      this.columnBounds = columnBounds;

      this.offsetX = (this.width * 0.05) / 2;
      this.radius = (this.width - this.pad * 7 - this.offsetX * 2) / 14; // 14: 7 columns /2 for diameter -> radius

      ctx.fillStyle = "rgba(0, 54, 201, 1)";
      ctx.fillRect(0, 0, this.width, this.height);

      for (let i = 0; i < 7; i++) {
        for (let j = 0; j < 6; j++) {
          this.drawPiece(i, j, this.boardState[i][j]);
        }
      }
    },
    sleep(ms) {
      return new Promise((resolve) => setTimeout(resolve, ms));
    },
    async getAIMove() {
      if (
        toRaw(this.modelName) === "q-learner_1" ||
        toRaw(this.modelName) === "q-learner_2"
      ) {
        return this.get_q_learner_move(this.q_table);
      } else {
        const mappedState = this.mappingFunc(toRaw(this.boardState));
        if (
          toRaw(this.modelName) === "jos" ||
          toRaw(this.modelName) === "jos-2" ||
          toRaw(this.modelName) === "win-2"
        ) {
          return await this.predict(mappedState);
        } else if (toRaw(this.modelName) === "win") {
          return await this.predict(mappedState);
        }
      }
    },
    getMousePos(evt) {
      const rect = this.canvas.getBoundingClientRect();
      return {
        x: evt.clientX - rect.left,
        y: evt.clientY - rect.top,
      };
    },
    transpose(m) {
      return m[0].map((x, i) => m.map((x) => x[i]));
    },
    hashKey(key) {
      let newKey = [];
      for (let val of key) {
        val = Object.is(val, -0) ? 0 : parseInt(val);
        val = val === 2 ? -1 : val;
        newKey.push(parseInt(val));
      }
      console.log(newKey);

      newKey = JSON.stringify(newKey);
      console.log(newKey);
      const md5Key = CryptoJS.MD5(newKey).toString(CryptoJS.enc.Base64);
      return md5Key;
    },
    get_q_learner_move(q_table) {
      function randomInt(upper) {
        return Math.floor(Math.random() * upper);
      }
      const key = this.hashKey(this.transpose(toRaw(this.boardState)).flat());
      console.log(key);
      let move;

      if (q_table[key] === undefined) {
        move = randomInt(7);
      } else {
        const max = Math.max(...q_table[key]);
        const indices = q_table[key].reduce((acc, el, i) => {
          if (el == max) acc.push(i);
          return acc;
        }, []);
        move = indices[randomInt(indices.length)];
      }

      return move;
    },
  },
};
</script>

<style lang="scss" scoped></style>
