#include "solver.hpp"
#include <cstdint>


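// n even => only half of the first-row columns (n/2 of them) are tried; every solution found counts twice
// n odd  => floor(n/2) first-row columns count twice, and the middle column counts once (but the symmetry is then applied in the second row)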

// If n is even, we have a mirror symmetry on the first row => we only have to consider the left half and double the number of solutions
// If n is odd, we can also use this symmetry, except when the first-row queen is in the center column (those solutions are only counted once)
// But if n is odd and the queen in the first row is in the center, we can use the even-like symmetry in the second row

// Tells a find_solutions instantiation which kind of row it is filling, and
// therefore whether the mirror-symmetry restriction applies to that row.
enum class stage {
    first_even,         // first row; restricted to one half of the columns
    first_odd,          // first row; like first_even, but the middle column is handled specially
    second_odd_center,  // row following a first_odd row whose queen sits in the middle column
    normal,             // any other row: every free column is tried
};

// Counts the ways to finish a partially filled board, placing one queen per
// call (i.e. per row) and recursing for the rows below.
//
// Template parameters:
//   Stage              - the kind of row handled by this call (see `stage`).
//   SolutionMultiplier - amount added to the result for every completed board
//                        reached from here. It is doubled when recursing out
//                        of a symmetry-restricted row into a normal one.
//
// Parameters:
//   n                 - board size; only used to locate the middle column.
//   used_columns      - bit i set => column i is unavailable. A board counts
//                       as complete when all 32 bits are set, so the bits at
//                       and above n have to be set by the caller already.
//   used_diagonals_lb - diagonal attacks as seen from the previous row; moved
//                       one bit to the right to obtain this row's view.
//   used_diagonals_rt - same for the other diagonal direction; moved one bit
//                       to the left.
//
// Returns the accumulated (multiplier-weighted) number of completed boards.
template <stage Stage, int SolutionMultiplier>
static int find_solutions(int n, uint32_t used_columns, uint32_t used_diagonals_lb, uint32_t used_diagonals_rt) {
  // Every stage except `normal` only explores one half of the row.
  constexpr bool half_row = Stage != stage::normal;
  constexpr uint32_t every_column = ~uint32_t{0};

  // The incoming diagonal masks describe the previous row; advance them by one row.
  const uint32_t diag_lb = used_diagonals_lb >> 1;
  const uint32_t diag_rt = used_diagonals_rt << 1;

  // A set bit marks a column that is attacked neither vertically nor diagonally.
  uint32_t candidates = ~(used_columns | diag_lb | diag_rt);
  if constexpr (half_row) {
    // Drop the low n/2 columns: queens there would only produce mirror images
    // of the boards found through the remaining columns.
    const uint32_t low_half = (uint32_t{1} << (n / 2)) - 1;
    candidates &= ~low_half;
  }

  int total = 0;
  // Visit the candidate columns from the lowest bit upwards; the loop step
  // clears the bit that was just handled.
  for (; candidates != 0; candidates &= candidates - 1) {
    const uint32_t queen = candidates & (0u - candidates);  // isolate the lowest set bit
    const uint32_t columns_taken = used_columns | queen;    // queen's column was free, so OR just adds it

    if (columns_taken == every_column) {
      // That was the last free column, so the board is complete. This can
      // only happen when `queen` was the sole candidate, hence at most once
      // per call and with `total` still at zero.
      total += SolutionMultiplier;
      continue;
    }

    // The new queen also attacks along both diagonals from its own square.
    const uint32_t next_lb = diag_lb | queen;
    const uint32_t next_rt = diag_rt | queen;

    if constexpr (!half_row) {
      total += find_solutions<stage::normal, SolutionMultiplier>(n, columns_taken, next_lb, next_rt);
    } else {
      // A first-row queen in the middle column of an odd board is its own
      // mirror image: keep the multiplier and apply the half-row restriction
      // to the next row instead. Every other placement here stands for two
      // mirrored boards, so the multiplier doubles.
      const bool odd_center = Stage == stage::first_odd && queen == (uint32_t{1} << (n / 2));
      total += odd_center
        ? find_solutions<stage::second_odd_center, SolutionMultiplier>(n, columns_taken, next_lb, next_rt)
        : find_solutions<stage::normal, SolutionMultiplier * 2>(n, columns_taken, next_lb, next_rt);
    }
  }

  return total;
}

// Returns the number of ways to place n non-attacking queens on an n x n board.
//
// The board is tracked in 32-bit column masks, so only sizes 1..32 can be
// represented; any other n (including 0 and negative values) yields 0 instead
// of performing an out-of-range shift, which would be undefined behaviour.
//
// NOTE: the count is returned as `int`; for large boards the true number of
// solutions does not fit into that type.
int solve(int n) {
  constexpr int max_board_size = 32;  // one bit per column in a uint32_t
  if (n < 1 || n > max_board_size) {
    return 0;
  }

  // Mark every column that lies outside the board as already used. Shifting
  // by the full width of the type is undefined, so n == 32 is handled apart.
  const uint32_t used_columns = (n == max_board_size) ? uint32_t{0} : (~uint32_t{0} << n);

  // Even and odd boards need a different symmetry treatment of the first row.
  if (n % 2 == 0) {
    return find_solutions<stage::first_even, 1>(n, used_columns, 0, 0);
  }
  return find_solutions<stage::first_odd, 1>(n, used_columns, 0, 0);
}
