import TestLib;

// Section 1: drive a naive sorted set and a SplayTree through the same random
// sequence of set operations (insert / replace / delete / contains) and assert
// after every step that both hold the same objects in the same order.
StartTest("SplayTree_as_Set");

from "template/imports/wrapper"(T=int) access
    Wrapper_T as wrapped_int,
    wrap,
    alias;

// Orders wrapped ints by the int they carry; this is the ordering handed to
// both set implementations below.
bool operator < (wrapped_int a, wrapped_int b) {
  return b.t > a.t;
}

// Value equality on the carried int. Identity checks elsewhere in this file
// use alias() instead.
bool operator == (wrapped_int a, wrapped_int b) {
  return !(a.t != b.t);
}

from "template/imports/sortedset"(T=wrapped_int) access
    makeNaiveSortedSet,
    SortedSet_T as SortedSet_wrapped_int;

from "template/imports/splaytree"(T=wrapped_int) access
    SplayTree_T as SplayTree_wrapped_int,
    operator cast;

// Pseudo-enum of the actions exercised in SplayTree_as_Set. Each constant
// takes the next consecutive index from next() in declaration order, so the
// values are 0..numActions-1 and numActions ends up as the action count.
// `restricted` members are readable outside the struct but assignable only
// inside it.
struct ActionEnum {
  // Declared first so it is initialized before the next() calls below.
  static restricted int numActions = 0;
  // Returns the current count, then increments it.
  static private int next() {
    return ++numActions - 1;
  }
  static restricted int INSERT = next();
  static restricted int REPLACE = next();
  static restricted int DELETE = next();
  static restricted int CONTAINS = next();
  static restricted int DELETE_CONTAINS = next();
}

from mapArray(Src=wrapped_int, Dst=int) access map;
// Projection passed to map(): the int carried by a wrapper.
int get(wrapped_int a) {
  return a.t;
}

// Implicit wrapped_int[] -> int[] conversion. Every element must be non-null;
// the check runs over the whole array before any value is extracted.
int[] operator cast(wrapped_int[] a) {
  for (int k = 0; k < a.length; ++k) {
    assert(!alias(a[k], null), 'Null element in array');
  }
  return map(get, a);
}

// Compares two sets position by position, using object identity (alias).
// Returns '' when they match. Otherwise returns either a size-mismatch message
// or a two-column table of the values, with mismatching rows marked '<---'.
string differences(SortedSet_wrapped_int a, SortedSet_wrapped_int b) {
  if (a.size() != b.size())
    return 'Different sizes: ' + string(a.size()) + ' vs ' + string(b.size());

  wrapped_int[] left = a;
  int[] leftValues = left;
  wrapped_int[] right = b;
  int[] rightValues = right;

  bool mismatch = false;
  string table = '[\n';
  for (int k = 0; k < leftValues.length; ++k) {
    string row = '  [' + format('%5d', leftValues[k]) + ','
                 + format('%5d', rightValues[k]) + ']';
    if (!alias(left[k], right[k])) {
      row += '  <---';
      mismatch = true;
    }
    table += row + '\n';
  }
  table += ']';
  // write(table + '\n');
  return mismatch ? table : '';
}

// Formats an int array as '[a, b, c]' for assertion messages.
string string(int[] a) {
  string result = '[';
  for (int i = 0; i < a.length; ++i) {
    if (i > 0) {
      result += ', ';
    }
    result += string(a[i]);
  }
  result += ']';
  return result;
}

// Formats a wrapped_int array as '[a, b, null]' for assertion messages.
// Without this exact-match overload, string(wrapped_int[]) resolves through
// the implicit wrapped_int[] -> int[] cast. That cast asserts
// 'Null element in array', so a 'Different results' failure in which any set
// returned null would be reported as the null-element assertion instead.
string string(wrapped_int[] a) {
  string result = '[';
  for (int i = 0; i < a.length; ++i) {
    if (i > 0) {
      result += ', ';
    }
    result += alias(a[i], null) ? 'null' : string(a[i].t);
  }
  result += ']';
  return result;
}

// Formats a bool array as '[true, false, ...]' for assertion messages.
string string(bool[] a) {
  string body = '';
  for (int k = 0; k < a.length; ++k) {
    string item = a[k] ? 'true' : 'false';
    body += (k == 0) ? item : ', ' + item;
  }
  return '[' + body + ']';
}

// Signature shared by every randomized action: an exclusive upper bound for
// random item values, followed by the sets the same operation is applied to.
typedef void Action(int ...SortedSet_wrapped_int[]);

// Dispatch table indexed by ActionEnum values.
Action[] actions = new Action[ActionEnum.numActions];

// INSERT: insert the same freshly wrapped value rand() % maxItem into every set.
actions[ActionEnum.INSERT] =
    new void(int maxItem ...SortedSet_wrapped_int[] sets) {
      wrapped_int item = wrap(rand() % maxItem);
      // write('Inserting ' + string(item.t) + '\n');
      for (int k = 0; k < sets.length; ++k) {
        sets[k].insert(item);
      }
    };
// REPLACE: call replace() with the same new wrapper on every set. All sets
// must return the same object, compared by identity with alias().
actions[ActionEnum.REPLACE] =
    new void(int maxItem ...SortedSet_wrapped_int[] sets) {
      wrapped_int item = wrap(rand() % maxItem);
      // write('Replacing ' + string(item.t) + '\n');
      wrapped_int[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].replace(item));
      }
      for (int k = 1; k < answers.length; ++k) {
        if (!alias(answers[k], answers[0])) {
          assert(false, 'Different results: ' + string(answers));
        }
      }
    };
// DELETE: delete the same random value from every set. All sets must report
// the same bool result.
actions[ActionEnum.DELETE] =
    new void(int maxItem ...SortedSet_wrapped_int[] sets) {
      wrapped_int victim = wrap(rand() % maxItem);
      // write('Deleting ' + string(victim.t) + '\n');
      bool[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].delete(victim));
      }
      for (int k = 1; k < answers.length; ++k) {
        assert(answers[k] == answers[0],
               'Different results: ' + string(answers));
      }
    };
// CONTAINS: probe every set for the same random value, using a fresh wrapper
// per set. All answers must agree.
actions[ActionEnum.CONTAINS] =
    new void(int maxItem ...SortedSet_wrapped_int[] sets) {
      int probe = rand() % maxItem;
      // write('Checking ' + string(probe) + '\n');
      bool[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].contains(wrap(probe)));
      }
      for (int k = 1; k < answers.length; ++k) {
        assert(answers[k] == answers[0],
               'Different results: ' + string(answers));
      }
    };
// DELETE_CONTAINS: pick an element known to be present (a random position in
// sets[0]'s forEach order) and delete it from every set. Checks membership
// before and after the delete, and that the size drops by exactly one.
// The leading int argument (maxItem for other actions) is unused.
actions[ActionEnum.DELETE_CONTAINS] =
    new void(int ...SortedSet_wrapped_int[] sets) {
      if (sets.length == 0) {
        return;
      }
      int initialSize = sets[0].size();
      if (initialSize == 0) {
        return;
      }
      int indexToDelete = rand() % initialSize;
      // Position reached by the forEach walk. process() captures this
      // variable, so it must not be shadowed later in this function.
      int position = 0;
      wrapped_int toDelete = null;
      bool process(wrapped_int a) {
        if (position == indexToDelete) {
          // A fresh wrapper, so the sets must locate the element through
          // operator < rather than by identity.
          toDelete = wrap(a.t);
          return false;  // presumably stops forEach early -- TODO confirm
        }
        ++position;
        return true;
      }
      sets[0].forEach(process);
      assert(position < initialSize, 'Index out of range');
      // write('Deleting ' + string(toDelete.t) + '\n');
      // Index of the set being checked, reported in every failure message.
      int setIndex = 0;
      for (SortedSet_wrapped_int s : sets) {
        assert(s.contains(toDelete),
               'Contains failed before delete ' + string(setIndex));
        assert(s.delete(toDelete), 'Delete failed ' + string(setIndex));
        assert(!s.contains(toDelete),
               'Contains failed after delete ' + string(setIndex));
        assert(s.size() == initialSize - 1, 'Size failed ' + string(setIndex));
        ++setIndex;
      }
    };
// Action weights for the growth phase (the first 800 steps of the main loop):
// mostly inserts, so the sets get large.
real[] increasingProbs = new real[ActionEnum.numActions];
increasingProbs[ActionEnum.INSERT] = 0.7;
increasingProbs[ActionEnum.REPLACE] = 0.1;
increasingProbs[ActionEnum.DELETE] = 0.05;
increasingProbs[ActionEnum.CONTAINS] = 0.1;
increasingProbs[ActionEnum.DELETE_CONTAINS] = 0.05;
// Decimal weights such as 0.1 and 0.05 are not exactly representable in
// binary floating point, so the total is checked with a tolerance rather
// than ==. chooseAction() already tolerates a total slightly below 1.
assert(abs(sum(increasingProbs) - 1) < 1e-9, 'Probabilities do not sum to 1');

// Action weights for the shrink phase (remaining steps): mostly deletions.
real[] decreasingProbs = new real[ActionEnum.numActions];
decreasingProbs[ActionEnum.INSERT] = 0.1;
decreasingProbs[ActionEnum.REPLACE] = 0.1;
decreasingProbs[ActionEnum.DELETE] = 0.4;
decreasingProbs[ActionEnum.CONTAINS] = 0.1;
decreasingProbs[ActionEnum.DELETE_CONTAINS] = 0.3;
assert(abs(sum(decreasingProbs) - 1) < 1e-9, 'Probabilities do not sum to 1');

// Reference implementation and implementation under test, both ordered by
// operator < on wrapped_int. NOTE(review): the (wrapped_int)null second
// argument looks like the value returned when no element is found -- confirm
// against the sortedset/splaytree templates.
// splayset is later passed where SortedSet_wrapped_int is expected; presumably
// that goes through the `operator cast` accessed from splaytree.
SortedSet_wrapped_int sorted_set =
    makeNaiveSortedSet(operator <, (wrapped_int)null);
SplayTree_wrapped_int splayset =
    SplayTree_wrapped_int(operator <, (wrapped_int)null);

// Samples an index i with probability probs[i]. Draws one unitrand() value
// and walks the running cumulative sum. If rounding leaves the total at or
// below the draw, falls back to the last index.
int chooseAction(real[] probs) {
  real target = unitrand();
  real cumulative = 0;
  int idx = 0;
  while (idx < probs.length) {
    cumulative += probs[idx];
    if (target < cumulative) return idx;
    ++idx;
  }
  return probs.length - 1;
}

// True when each element is strictly less (by operator <) than its
// successor. Arrays of length 0 or 1 count as sorted.
bool isStrictlySorted(wrapped_int[] arr) {
  for (int k = 0; k + 1 < arr.length; ++k) {
    if (arr[k] < arr[k + 1]) continue;
    return false;
  }
  return true;
}

// Driver: 2000 random steps, using growth weights for the first 800 and
// shrink weights afterwards. Item values are drawn from [0, 100). After each
// step the two sets must match and the splay tree must be strictly sorted.
// maxSize records the largest tree size reached.
int maxSize = 0;
for (int step = 0; step < 2000; ++step) {
  real[] probs = (step >= 800) ? decreasingProbs : increasingProbs;
  actions[chooseAction(probs)](100, sorted_set, splayset);
  string diffs = differences(sorted_set, splayset);
  assert(diffs == '', 'Naive vs splayset: \n' + diffs);
  assert(isStrictlySorted(splayset), 'Not sorted');
  maxSize = max(maxSize, splayset.size());
}
EndTest();

// Section 2: same differential scheme as SplayTree_as_Set, extended to the
// ordered queries (after/before/firstGEQ/firstLEQ, min/max, pops, get).
StartTest("SplayTree_as_SortedSet");

// Redeclares ActionEnum for this section, shadowing the earlier one. Each
// constant takes the next consecutive index from next() in declaration order,
// so the values are 0..numActions-1 and numActions ends as the action count.
// `restricted` members are readable outside the struct but assignable only
// inside it.
struct ActionEnum {
  // Declared first so it is initialized before the next() calls below.
  static restricted int numActions = 0;
  // Returns the current count, then increments it.
  static private int next() {
    return ++numActions - 1;
  }
  static restricted int CONTAINS = next();
  static restricted int AFTER = next();
  static restricted int BEFORE = next();
  static restricted int FIRST_GEQ = next();
  static restricted int FIRST_LEQ = next();
  static restricted int MIN = next();
  static restricted int POP_MIN = next();
  static restricted int MAX = next();
  static restricted int POP_MAX = next();
  static restricted int INSERT = next();
  static restricted int REPLACE = next();
  static restricted int GET = next();
  static restricted int DELETE = next();
  static restricted int DELETE_CONTAINS = next();
}

// Dispatch table for this section, indexed by the redeclared ActionEnum.
Action[] actions = new Action[ActionEnum.numActions];

// CONTAINS: probe every set for the same random value, using a fresh wrapper
// per set. All answers must agree.
actions[ActionEnum.CONTAINS] =
    new void(int maxItem ...SortedSet_wrapped_int[] sets) {
      int probe = rand() % maxItem;
      // write('Checking ' + string(probe) + '\n');
      bool[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].contains(wrap(probe)));
      }
      for (int k = 1; k < answers.length; ++k) {
        assert(answers[k] == answers[0],
               'Different results: ' + string(answers));
      }
    };
// Ordered queries: each action below asks every set the same question about
// a random value in [0, maxItem). All sets must return the same object,
// compared by identity with alias(), so a null answer must be null everywhere.
actions[ActionEnum.AFTER] =
    new void(int maxItem ...SortedSet_wrapped_int[] sets) {
      int probe = rand() % maxItem;
      // write('After ' + string(probe) + '\n');
      wrapped_int[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].after(wrap(probe)));
      }
      for (int k = 1; k < answers.length; ++k) {
        if (!alias(answers[k], answers[0])) {
          assert(false, 'Different results: ' + string(answers));
        }
      }
    };
actions[ActionEnum.BEFORE] =
    new void(int maxItem ...SortedSet_wrapped_int[] sets) {
      int probe = rand() % maxItem;
      // write('Before ' + string(probe) + '\n');
      wrapped_int[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].before(wrap(probe)));
      }
      for (int k = 1; k < answers.length; ++k) {
        if (!alias(answers[k], answers[0])) {
          assert(false, 'Different results: ' + string(answers));
        }
      }
    };
actions[ActionEnum.FIRST_GEQ] =
    new void(int maxItem ...SortedSet_wrapped_int[] sets) {
      int probe = rand() % maxItem;
      // write('First greater or equal ' + string(probe) + '\n');
      wrapped_int[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].firstGEQ(wrap(probe)));
      }
      for (int k = 1; k < answers.length; ++k) {
        if (!alias(answers[k], answers[0])) {
          assert(false, 'Different results: ' + string(answers));
        }
      }
    };
actions[ActionEnum.FIRST_LEQ] =
    new void(int maxItem ...SortedSet_wrapped_int[] sets) {
      int probe = rand() % maxItem;
      // write('First less or equal ' + string(probe) + '\n');
      wrapped_int[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].firstLEQ(wrap(probe)));
      }
      for (int k = 1; k < answers.length; ++k) {
        if (!alias(answers[k], answers[0])) {
          assert(false, 'Different results: ' + string(answers));
        }
      }
    };
// Extremes: min()/max() and their popping variants, applied to every set.
// The returned objects must be identical across sets (alias()). These actions
// draw no random numbers; the leading int argument is unused.
actions[ActionEnum.MIN] =
    new void(int ...SortedSet_wrapped_int[] sets) {
      // write('Min\n');
      wrapped_int[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].min());
      }
      for (int k = 1; k < answers.length; ++k) {
        if (!alias(answers[k], answers[0])) {
          assert(false, 'Different results: ' + string(answers));
        }
      }
    };
actions[ActionEnum.POP_MIN] =
    new void(int ...SortedSet_wrapped_int[] sets) {
      // write('Pop min\n');
      wrapped_int[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].popMin());
      }
      for (int k = 1; k < answers.length; ++k) {
        if (!alias(answers[k], answers[0])) {
          assert(false, 'Different results: ' + string(answers));
        }
      }
    };
actions[ActionEnum.MAX] =
    new void(int ...SortedSet_wrapped_int[] sets) {
      // write('Max\n');
      wrapped_int[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].max());
      }
      for (int k = 1; k < answers.length; ++k) {
        if (!alias(answers[k], answers[0])) {
          assert(false, 'Different results: ' + string(answers));
        }
      }
    };
actions[ActionEnum.POP_MAX] =
    new void(int ...SortedSet_wrapped_int[] sets) {
      // write('Pop max\n');
      wrapped_int[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].popMax());
      }
      for (int k = 1; k < answers.length; ++k) {
        if (!alias(answers[k], answers[0])) {
          assert(false, 'Different results: ' + string(answers));
        }
      }
    };
// INSERT: insert the same freshly wrapped value rand() % maxItem into every set.
actions[ActionEnum.INSERT] =
    new void(int maxItem ...SortedSet_wrapped_int[] sets) {
      wrapped_int item = wrap(rand() % maxItem);
      // write('Inserting ' + string(item.t) + '\n');
      for (int k = 0; k < sets.length; ++k) {
        sets[k].insert(item);
      }
    };
// REPLACE: call replace() with the same new wrapper on every set. All sets
// must return the same object, compared by identity with alias().
actions[ActionEnum.REPLACE] =
    new void(int maxItem ...SortedSet_wrapped_int[] sets) {
      wrapped_int item = wrap(rand() % maxItem);
      // write('Replacing ' + string(item.t) + '\n');
      wrapped_int[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].replace(item));
      }
      for (int k = 1; k < answers.length; ++k) {
        if (!alias(answers[k], answers[0])) {
          assert(false, 'Different results: ' + string(answers));
        }
      }
    };
// GET: call get() with the same random wrapper on every set. The returned
// objects must be identical across sets (alias()).
actions[ActionEnum.GET] =
    new void(int maxItem ...SortedSet_wrapped_int[] sets) {
      wrapped_int key = wrap(rand() % maxItem);
      // write('Getting ' + string(key.t) + '\n');
      wrapped_int[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].get(key));
      }
      for (int k = 1; k < answers.length; ++k) {
        if (!alias(answers[k], answers[0])) {
          assert(false, 'Different results: ' + string(answers));
        }
      }
    };
// DELETE: delete the same random value from every set. All sets must report
// the same bool result.
actions[ActionEnum.DELETE] =
    new void(int maxItem ...SortedSet_wrapped_int[] sets) {
      wrapped_int victim = wrap(rand() % maxItem);
      // write('Deleting ' + string(victim.t) + '\n');
      bool[] answers;
      for (int k = 0; k < sets.length; ++k) {
        answers.push(sets[k].delete(victim));
      }
      for (int k = 1; k < answers.length; ++k) {
        assert(answers[k] == answers[0],
               'Different results: ' + string(answers));
      }
    };
// DELETE_CONTAINS: pick an element known to be present (a random position in
// sets[0]'s forEach order) and delete it from every set. Checks that it is
// gone afterwards and that the size drops by exactly one.
// The leading int argument (maxItem for other actions) is unused.
actions[ActionEnum.DELETE_CONTAINS] =
    new void(int ...SortedSet_wrapped_int[] sets) {
      if (sets.length == 0) {
        return;
      }
      int initialSize = sets[0].size();
      if (initialSize == 0) {
        return;
      }
      int indexToDelete = rand() % initialSize;
      // Position reached by the forEach walk. process() captures this
      // variable, so it must not be shadowed later in this function.
      int position = 0;
      wrapped_int toDelete = null;
      bool process(wrapped_int a) {
        if (position == indexToDelete) {
          // A fresh wrapper, so the sets must locate the element through
          // operator < rather than by identity.
          toDelete = wrap(a.t);
          return false;  // presumably stops forEach early -- TODO confirm
        }
        ++position;
        return true;
      }
      sets[0].forEach(process);
      assert(position < initialSize, 'Index out of range');
      // write('Deleting ' + string(toDelete.t) + '\n');
      // NOTE(review): unlike the SplayTree_as_Set variant, no contains()
      // precedes delete() here. Presumably this is intentional, so delete()
      // runs on a node that was not just splayed to the root -- confirm
      // before adding one.
      // Index of the set being checked, reported in every failure message.
      int setIndex = 0;
      for (SortedSet_wrapped_int s : sets) {
        assert(s.delete(toDelete), 'Delete failed ' + string(setIndex));
        assert(!s.contains(toDelete),
               'Contains failed after delete ' + string(setIndex));
        assert(s.size() == initialSize - 1, 'Size failed ' + string(setIndex));
        ++setIndex;
      }
    };

// Action weights for the growth phase. Every weight is a dyadic fraction, so
// the exact == checks on the running totals below are exact in floating point.
real[] increasingProbs = array(n=ActionEnum.numActions, value=0.0);
// Actions that don't modify the set (except for rebalancing), 1/32 each:
for (int action : new int[] {ActionEnum.CONTAINS, ActionEnum.AFTER,
                             ActionEnum.BEFORE, ActionEnum.FIRST_GEQ,
                             ActionEnum.FIRST_LEQ, ActionEnum.MIN,
                             ActionEnum.MAX, ActionEnum.GET}) {
  increasingProbs[action] = 1 / 2^5;
}
// 1/4 probability of this sort of action:
assert(sum(increasingProbs) == 8 / 2^5);
// Actions that might add an element, 1/4 each:
for (int action : new int[] {ActionEnum.INSERT, ActionEnum.REPLACE}) {
  increasingProbs[action] = 1 / 4;
}
assert(sum(increasingProbs) == 3/4);
// Actions that might remove an element, 1/16 each:
for (int action : new int[] {ActionEnum.POP_MIN, ActionEnum.POP_MAX,
                             ActionEnum.DELETE, ActionEnum.DELETE_CONTAINS}) {
  increasingProbs[action] = 1 / 16;
}
assert(sum(increasingProbs) == 1, 'Probabilities do not sum to 1');

// Shrink phase: same read-only weights. Growth actions drop from 1/4 to 1/8
// and removal actions rise from 1/16 to 1/8.
real[] decreasingProbs = copy(increasingProbs);
for (int action : new int[] {ActionEnum.INSERT, ActionEnum.REPLACE,
                             ActionEnum.POP_MIN, ActionEnum.POP_MAX,
                             ActionEnum.DELETE, ActionEnum.DELETE_CONTAINS}) {
  decreasingProbs[action] = 1 / 8;
}
assert(sum(decreasingProbs) == 1, 'Probabilities do not sum to 1');

// Newly constructed reference set and splay tree for this section; these
// declarations shadow the ones used by SplayTree_as_Set. NOTE(review): the
// (wrapped_int)null second argument looks like the value returned when no
// element is found -- confirm against the sortedset/splaytree templates.
SortedSet_wrapped_int sorted_set =
    makeNaiveSortedSet(operator <, (wrapped_int)null);
SplayTree_wrapped_int splayset =
    SplayTree_wrapped_int(operator <, (wrapped_int)null);


// Driver: 2000 random steps, using growth weights for the first 800 and
// shrink weights afterwards. Item values are drawn from [0, 100). After each
// step the two sets must match and the splay tree must be strictly sorted.
// maxSize records the largest tree size reached.
int maxSize = 0;
for (int step = 0; step < 2000; ++step) {
  real[] probs = (step >= 800) ? decreasingProbs : increasingProbs;
  actions[chooseAction(probs)](100, sorted_set, splayset);
  string diffs = differences(sorted_set, splayset);
  assert(diffs == '', 'Naive vs splayset: \n' + diffs);
  assert(isStrictlySorted(splayset), 'Not sorted');
  maxSize = max(maxSize, splayset.size());
}

EndTest();