#include <cgreen/cgreen.h>
#include "../assert_octave.h"
#include <octave/oct.h>
#include "../../Cost_auto_fcn.h"

// holds utility functions: scramble_matrix() and get_2random_idx()
// and holds the data: ts_data[] -- Gaussian process
#include "util.h"

using namespace cgreen;

//! Checks that cost_full() reports a zero cost when the series is left
//! unchanged after cost_transform().
//! NOTE: this test is currently not registered in cost_auto_tests().
Ensure (cost_is_zero_when_no_changes_to_series)
{
  // Construct and load matrix with ts_data.  The row count is derived from
  // the array itself so the memcpy below can never overrun (or under-fill)
  // the matrix if the length of ts_data changes.
  const octave_idx_type n_samples
    = static_cast<octave_idx_type> (sizeof (ts_data) / sizeof (ts_data[0]));
  Matrix ts_series (n_samples, 1);
  memcpy (ts_series.fortran_vec (), ts_data, sizeof (ts_data));

  // Prepare and construct cost function (Cost_auto_fcn)
  octave_idx_type no_lags = 5;
  octave_idx_type average_type = 0;
  Cost_auto_fcn cost_f (&ts_series, no_lags, average_type);
  cost_f.cost_transform ();

  // Perform test
  cost_f.cost_full ();

  // Adding and then subtracting 1 rounds any magnitude below ~1e-16 to
  // exactly 0.0 -- presumably to absorb floating-point noise, since a
  // significant-figure comparison against 0 would otherwise fail on it.
  assert_that_double (cost_f.get_cost () + 1 -1, is_equal_to_double(0));
}

//! Checks that after cost_transform() and a subsequent scramble of the
//! series, cost_full() reports a non-zero cost.
Ensure (cost_is_non_zero_when_series_is_scrambled)
{
  // Construct and load matrix with ts_data.  The row count is derived from
  // the array itself so the memcpy below can never overrun (or under-fill)
  // the matrix if the length of ts_data changes.
  const octave_idx_type n_samples
    = static_cast<octave_idx_type> (sizeof (ts_data) / sizeof (ts_data[0]));
  Matrix ts_series (n_samples, 1);
  memcpy (ts_series.fortran_vec (), ts_data, sizeof (ts_data));

  // Prepare and construct cost function (Cost_auto_fcn)
  octave_idx_type no_lags = 5;
  octave_idx_type average_type = 0;
  Cost_auto_fcn cost_f (&ts_series, no_lags, average_type);
  cost_f.cost_transform ();

  // Perform test: cost_f holds a pointer to ts_series, so scrambling the
  // matrix in place changes what cost_full() evaluates.
  scramble_matrix (ts_series);

  cost_f.cost_full ();

  // The +1 -1 round trip flushes magnitudes below ~1e-16 to exactly 0.0, so
  // a cost that is only rounding noise is treated as zero here.
  assert_that_double (cost_f.get_cost() +1 -1, is_not_equal_to_double (0));
}

//! Checks that cost_update() accepts a swap of two random elements when the
//! acceptance threshold cmax is very large.
Ensure (cost_update_is_accepted_if_cmax_is_large)
{
  // Construct and load matrix with ts_data.  The row count is derived from
  // the array itself so the memcpy below can never overrun (or under-fill)
  // the matrix if the length of ts_data changes.
  const octave_idx_type n_samples
    = static_cast<octave_idx_type> (sizeof (ts_data) / sizeof (ts_data[0]));
  Matrix ts_series (n_samples, 1);
  memcpy (ts_series.fortran_vec (), ts_data, sizeof (ts_data));

  // Prepare and construct cost function (Cost_auto_fcn)
  octave_idx_type no_lags = 5;
  octave_idx_type average_type = 0;
  Cost_auto_fcn cost_f (&ts_series, no_lags, average_type);
  cost_f.cost_transform ();

  // Establish a base cost from which to update, so cost_update() does not
  // depend on whatever cost state the constructor/transform left behind.
  cost_f.cost_full ();

  // Perform test
  octave_idx_type n1, n2;
  get_2random_idx (ts_series, n1, n2);

  // Start from false so the assertion can only pass if cost_update() actually
  // reports acceptance (an uninitialised bool could pass by accident).
  bool accept = false;
  const double cmax = 1e10;
  cost_f.cost_update (n1, n2, cmax, accept);
  assert_that (accept, is_true);
}

//! Applies many random incremental cost updates (mirroring each accepted swap
//! in the series), then checks that the incrementally tracked cost is
//! >= the cost recomputed from scratch by cost_full().
Ensure (cost_updates_have_larger_cost_than_cost_full)
{
  // Construct and load matrix with ts_data.  The row count is derived from
  // the array itself so the memcpy below can never overrun (or under-fill)
  // the matrix if the length of ts_data changes.
  const octave_idx_type n_samples
    = static_cast<octave_idx_type> (sizeof (ts_data) / sizeof (ts_data[0]));
  Matrix ts_series (n_samples, 1);
  memcpy (ts_series.fortran_vec (), ts_data, sizeof (ts_data));

  // Prepare and construct cost function (Cost_auto_fcn)
  octave_idx_type no_lags = 5;
  octave_idx_type average_type = 0;
  Cost_auto_fcn cost_f (&ts_series, no_lags, average_type);
  cost_f.cost_transform ();

  // Prepare matrix
  scramble_matrix (ts_series);

  // Establish a base cost from which to update
  cost_f.cost_full ();

  // cmax is set high so that every cost update is expected to be accepted
  const double cmax = 1e10;
  const octave_idx_type no_updates = 1000;

  // Perform cost updates multiple times
  for (octave_idx_type i = 0; i < no_updates; i++)
    {
      octave_idx_type n1, n2;
      get_2random_idx (ts_series, n1, n2);

      bool accept = false;
      cost_f.cost_update (n1, n2, cmax, accept);

      // Mirror the swap in the series only when the update was accepted;
      // swapping after a rejected update would leave the series out of sync
      // with the cost the update tracked.
      if (accept)
        {
          double tmp = ts_series (n1);
          ts_series(n1) = ts_series (n2);
          ts_series(n2) = tmp;
        }
    }
  // The resulting cost of the updates
  double cost_update = cost_f.get_cost ();

  // Calculate full cost
  cost_f.cost_full ();
  double cost_full = cost_f.get_cost ();

  // Compare both
  assert_true_with_message (cost_update >= cost_full,
                            "expected [cost_update] >= [cost_full] but got:\n"
                            "\t\t[cost_update]: [%g]\n"
                            "\t\t[cost_full]: [%g]", cost_update, cost_full);

}

//! Setup routine for the Cost_auto_fcn suite: initialises the helpers from
//! util.h and fixes the precision used by cgreen's double assertions.
void prepare_auto ()
{
  // Initialise the shared test utilities declared in util.h.
  init_util ();

  // assert_that_double comparisons use this many significant figures.
  const int sig_figs = 8;
  significant_figures_for_assert_double_are (sig_figs);
}

TestSuite *normal_tests ();
TestSuite *cost_auto_tests ();
TestSuite *cost_spikespec_event_tests ();

int main()
{
  TestSuite *suite = create_test_suite ();
  add_suite (suite, cost_auto_tests ());
  add_suite (suite, normal_tests ());
  add_suite (suite, cost_spikespec_event_tests ());
  return run_test_suite (suite, create_text_reporter ());
}

//! Builds the Cost_auto_fcn test suite; prepare_auto() is registered as the
//! suite's setup routine.
TestSuite *cost_auto_tests ()
{
  TestSuite *auto_suite = create_test_suite ();
  set_setup (auto_suite, prepare_auto);

  // FIXME: cost_transform() appears to scramble the series by itself, so the
  // "unchanged series gives zero cost" premise does not hold; the test is
  // disabled until that is resolved.
  //  add_test (auto_suite, cost_is_zero_when_no_changes_to_series);
  add_test (auto_suite, cost_is_non_zero_when_series_is_scrambled);
  add_test (auto_suite, cost_update_is_accepted_if_cmax_is_large);
  add_test (auto_suite, cost_updates_have_larger_cost_than_cost_full);

  return auto_suite;
}

