#ifndef _HASHSET_H_
#define _HASHSET_H_

#include <assert.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>

/* A set of NUL-terminated strings, stored as a chained hash table.
   A zero-initialized struct hashset is a valid empty set: it owns no table
   until the first hs_put allocates one, and hs_free accepts it as-is.  */
struct hashset
{
  /* size:   number of buckets in tab.
     nmemb:  number of keys currently stored.
     thresh: the nmemb value at which the next hs_put grows the table.  */
  size_t size, nmemb, thresh;
  /* Bucket array; each entry heads a NULL-terminated chain of nodes.  */
  struct node **tab;
};
/* Invariants:
   - size is either 0 (no table yet; tab must then be NULL) or a power of
     two, starting at HS_MINSIZE and doubling on each growth.  Bucket
     indices are computed as hash & (size - 1), which relies on this.
   - thresh = LOADFACTOR * size, where LOADFACTOR = HS_MINTHRESH / HS_MINSIZE
     (thresh is 0 while size is 0).
   - 0 <= nmemb <= thresh (hs_put grows the table once nmemb reaches thresh).
   - tab is an array of size elements, and every node chained from tab[i]
     satisfies (node->hash & (size - 1)) == i.  */

/* One member of a hashset bucket chain.
   NOTE(review): "node" is a very generic struct tag for a header and may
   clash with other code that includes it.  */
struct node
{
  size_t hash;        /* cached hash (key); compared before strcmp and reused
                         when the table is resized */
  char *key;          /* heap copy of the key, owned by the set (released by
                         hs_free) */
  struct node *next;  /* next node in the same bucket, or NULL */
};

/* Bucket count and growth threshold of a set's first table.  HS_MINSIZE
   must be a power of two, because bucket indices are computed as
   hash & (size - 1).  The threshold is derived from it (load factor 3/4)
   so the two values cannot drift apart if HS_MINSIZE is changed.  */
#define HS_MINSIZE   16
#define HS_MINTHRESH (HS_MINSIZE * 3 / 4)

/* String hash used by the set.  It is only declared here, so it must be
   defined in some translation unit.  Bucket selection uses only the low
   bits of the result (hash & (size - 1)), so those bits should be well
   mixed.  */
size_t hash (const char *str);

static inline bool __attribute__((const))
hs_has (struct hashset *hs, const char *key)
{
  if (!hs->size)
    return false;
  size_t hsh = hash (key);
  size_t index = (hs->size - 1) & hsh;

  for (struct node *ptr = hs->tab[index]; ptr; ptr = ptr->next)
    if (hsh == ptr->hash && !strcmp (ptr->key, key))
      return true;
  return false;
}

// Disabled.  If re-enabled: unlinks KEY's node from its bucket chain, frees
// it, and returns whether KEY was present.
//static inline bool
//hs_remove (struct hashset *hs, const char *key)
//{
//  if (!hs->size)
//    return false;
//  size_t hsh = hash (key);
//  size_t index = (hs->size - 1) & hsh;
//
//  struct node *prev = NULL;
//  for (struct node *ptr = hs->tab[index]; ptr; prev = ptr, ptr = ptr->next)
//    if (hsh == ptr->hash && !strcmp (ptr->key, key))
//    {
//      if (prev)
//        prev->next = ptr->next;
//      else
//        hs->tab[index] = ptr->next;
//      free (ptr->key);
//      free (ptr);
//      hs->nmemb--;
//      return true;
//    }
//  return false;
//}

/* Double the bucket table of HS, or allocate the initial HS_MINSIZE-bucket
   table if HS has none yet, and move every node into its new bucket.

   Because the size is a power of two, doubling adds exactly one bit
   (OLDSIZE) to the index mask.  Each node of old bucket I therefore goes to
   new bucket I or to new bucket I + OLDSIZE, according to that bit of its
   cached hash.  Nodes keep their relative order within a chain.

   Aborts if the new table cannot be allocated.  */
static inline void
hs_grow (struct hashset *hs)
{
  size_t oldsize = hs->size;
  size_t newsize = oldsize ? oldsize * 2 : HS_MINSIZE;
  size_t newthresh = oldsize ? hs->thresh * 2 : HS_MINTHRESH;

  struct node **newtab = calloc (newsize, sizeof *newtab);
  if (!newtab)
    abort ();

  for (size_t i = 0; i < oldsize; i++)
  {
    /* LO and HI point at the link to fill when appending to new bucket I
       and new bucket I + OLDSIZE, respectively.  */
    struct node **lo = &newtab[i];
    struct node **hi = &newtab[i + oldsize];
    for (struct node *ptr = hs->tab[i], *next; ptr; ptr = next)
    {
      next = ptr->next;
      if (ptr->hash & oldsize)
      {
        *hi = ptr;
        hi = &ptr->next;
      }
      else
      {
        *lo = ptr;
        lo = &ptr->next;
      }
    }
    *lo = NULL;
    *hi = NULL;
  }

  free (hs->tab);
  hs->tab = newtab;
  hs->size = newsize;
  hs->thresh = newthresh;
}

/* Insert a private heap copy of KEY into HS.

   Returns true if KEY was added.  Returns false if KEY was already present;
   HS is then unchanged apart from a possible resize.  When HS already holds
   thresh members, the table is grown before the lookup.

   Aborts if memory cannot be allocated.  These checks are explicit rather
   than assert()s so they still run when NDEBUG is defined.  */
static inline bool
hs_put (struct hashset *hs, const char *key)
{
  if (hs->nmemb == hs->thresh)
    hs_grow (hs);

  size_t hsh = hash (key);

  /* Walk the bucket through a pointer to each link.  If KEY is absent, LINK
     ends at the chain's terminating NULL, which is where the new node goes
     (insertion order within a bucket is preserved).  */
  struct node **link = &hs->tab[(hs->size - 1) & hsh];
  for (; *link; link = &(*link)->next)
    if (hsh == (*link)->hash && !strcmp ((*link)->key, key))
      return false;

  /* Copy KEY with malloc + memcpy rather than strdup.  strdup is POSIX and
     is not declared by <string.h> in strict ISO C modes before C23.  */
  size_t len = strlen (key) + 1;
  struct node *fresh = malloc (sizeof *fresh);
  char *copy = malloc (len);
  if (!fresh || !copy)
    abort ();
  memcpy (copy, key, len);

  fresh->hash = hsh;
  fresh->key = copy;
  fresh->next = NULL;
  *link = fresh;
  hs->nmemb++;
  return true;
}

/* Release every key string, every node and the bucket array owned by HS.

   HS itself is only read.  Afterwards it still records the old size and a
   now-dangling tab pointer, so reset it (for example by zero-initializing
   it) before passing it to any other hs_* function.  */
static inline void
hs_free (const struct hashset *hs)
{
  size_t bucket = hs->size;
  while (bucket-- > 0)
  {
    struct node *victim = hs->tab[bucket];
    while (victim)
    {
      struct node *following = victim->next;
      free (victim->key);
      free (victim);
      victim = following;
    }
  }
  free (hs->tab);
}

#endif /* _HASHSET_H_ */
