#include "solver.hpp"
#include <iostream>
#include <algorithm>
#include <memory>

// Square n x n grid of boolean flags stored in one row-major heap array.
// Every cell starts out false; callers decide what `true` means (e.g. a
// placed queen, or a square marked as attacked).
struct board {
  int n;
  std::unique_ptr<bool[]> fields;

  // `explicit`: an int is a size, not a board, so forbid silent conversions.
  // The trailing `()` value-initializes every element to false.
  explicit board(int n) : n(n), fields(new bool[n * n]()) {}

  // Mutable access to the cell at (row, col). No bounds checking.
  bool& at(int row, int col) {
    return fields[row * n + col];
  }

  // Read-only access, so inspection-only code can take `const board&`.
  bool at(int row, int col) const {
    return fields[row * n + col];
  }
};

// Returns true iff no two queens on `board` share a row, a column or a
// diagonal. Queens are visited in row-major order. Each one is first tested
// against the squares covered by the queens visited before it, and then its
// own lines of attack are added to the coverage map. Because attack is
// symmetric, this checks every pair of queens. `board` is only read.
static bool check_valid(board& board) {
  const int size = board.n;
  ::board attacked(size);  // `board` names the parameter here, hence `::`

  // The four diagonal directions, as (row step, column step).
  const int steps[4][2] = {{1, 1}, {1, -1}, {-1, 1}, {-1, -1}};

  // Mark the row, column and both diagonals through (qr, qc) as attacked.
  auto cover_lines = [&](int qr, int qc) {
    for (int k = 0; k < size; ++k) {
      attacked.at(k, qc) = true;
      attacked.at(qr, k) = true;
    }
    for (const auto& step : steps) {
      int r = qr + step[0];
      int c = qc + step[1];
      while (r >= 0 && r < size && c >= 0 && c < size) {
        attacked.at(r, c) = true;
        r += step[0];
        c += step[1];
      }
    }
  };

  for (int r = 0; r < size; ++r) {
    for (int c = 0; c < size; ++c) {
      if (!board.at(r, c)) {
        continue;
      }
      if (attacked.at(r, c)) {
        return false;  // an earlier queen already attacks this square
      }
      cover_lines(r, c);
    }
  }
  return true;
}

// Counts the ways to place `queens_todo` more queens on empty cells at or
// after (start_row, start_col) in row-major order such that the resulting
// position passes check_valid. Cells are chosen in increasing row-major
// order, so each set of cells is enumerated exactly once. Every cell this
// call sets is cleared again before the call returns normally.
static int find_solutions(board& board, int queens_todo, int start_row, int start_col) {
  if (queens_todo == 0) {
    return check_valid(board) ? 1 : 0;
  }

  int total = 0;
  int first_col = start_col;  // only the starting row begins mid-row
  for (int r = start_row; r < board.n; ++r, first_col = 0) {
    for (int c = first_col; c < board.n; ++c) {
      bool& square = board.at(r, c);
      if (square) {
        // Occupied. When called recursively, this is the start cell, where
        // the caller just placed a queen.
        continue;
      }
      square = true;
      total += find_solutions(board, queens_todo - 1, r, c);
      square = false;
    }
  }
  return total;
}

// Returns the number of ways to place n queens on an n x n board so that no
// two share a row, column or diagonal. solve(0) == 1 (the empty placement).
// A negative n yields 0.
int solve(int n) {
  // A negative size has no placements. Return before constructing the board
  // so that n * n is never evaluated for a nonsensical size: it is wasted
  // work, and signed overflow (UB) for large |n|.
  if (n < 0) {
    return 0;
  }
  board board(n);
  return find_solutions(board, n, 0, 0);
}
