//===-- rosa/delux/CrossReliability.h ---------------------------*- C++ -*-===//
//
//                                 The RoSA Framework
//
//===----------------------------------------------------------------------===//
///
/// \file rosa/delux/CrossReliability.h
///
/// \author Daniel Schnoell
///
/// \date 2019
///
/// \brief Definition of the \c CrossReliability and \c CrossConfidence
/// abstractions.
///
/// \todo There is one exception that needs to be handled correctly.
/// \note The default search function is a slow linear search; maybe this could
/// be done via a template for the storage class and the functions/methods to
/// efficiently find the correct LinearFunction.
//===----------------------------------------------------------------------===//
#ifndef ROSA_AGENT_CROSSRELIABILITY_H
#define ROSA_AGENT_CROSSRELIABILITY_H

#include "rosa/agent/Abstraction.hpp"
#include "rosa/agent/Functionality.h"
#include "rosa/core/forward_declarations.h" // needed for id_t
#include "rosa/support/log.h"               // needed for error "handling"

// needed standard library headers

#include <memory> // std::unique_ptr
#include <string>
#include <type_traits> // std::is_arithmetic
#include <utility>     // std::move, std::pair
#include <vector>
// for static methods
#include <algorithm>
#include <numeric>

namespace rosa {
namespace agent {

/// Calculates the Cross Reliability.
///
/// \brief Uses the state of each agent, represented by a numerical value, and
/// calculates the reliability of a given agent (represented by its id) in
/// relation to all other given agents.
///
/// \tparam StateType arithmetic type of the agents' states
/// \tparam Type arithmetic type of the calculated reliabilities
///
/// \note A cross reliability profile has to be specified for every
/// combination of agents; a missing profile is logged as an error and its
/// value is taken as 0.
///
/// \note The object owns the profiles added to it, hence it is not copyable.
template <typename StateType, typename Type>
class CrossReliability : public Abstraction<StateType, Type> {

  static_assert(
      std::is_arithmetic<Type>::value,
      "CrossReliability: <Type> has to be arithmetic type\n"); // sanity check
  static_assert(
      std::is_arithmetic<StateType>::value,
      "CrossReliability: <StateType> has to be arithmetic type\n"); // sanity
                                                                    // check

  using Abstraction = typename rosa::agent::Abstraction<StateType, Type>;

  /// A cross reliability profile registered for an unordered pair of agents.
  struct Functionblock {
    id_t A;
    id_t B;
    /// The profile; evaluated with the absolute state difference of A and B.
    std::unique_ptr<Abstraction> Funct;
  };

  /// Scales the state difference; defined as 1 in Maxi's code and can be
  /// changed via \c setCrossReliabilityParameter.
  Type crossReliabilityParameter = 1;

  /// Stored cross reliability profiles.
  std::vector<Functionblock> Functions;

  /// Method which is used to combine the generated values.
  Type (*Method)(std::vector<Type> values) = AVERAGE;

  //--------------------------------------------------------------------------------
  // helper functions

  /// Returns the absolute distance between \p A and \p B.
  /// \note Named "AbsuluteValue" to stay somewhat conform with Maxi's code.
  /// \note The operands are compared before subtracting so that unsigned
  /// types do not wrap around when \p A is smaller than \p B.
  template <typename Type_t> Type_t AbsuluteValue(Type_t A, Type_t B) {
    return (A < B) ? B - A : A - B;
  }

  /// Linear search for the profile of the unordered pair \p nameA, \p nameB.
  /// \return the first matching profile, or \c nullptr if there is none
  const Functionblock *searchFunction(const id_t nameA,
                                      const id_t nameB) const {
    const auto Found = std::find_if(
        Functions.begin(), Functions.end(),
        [nameA, nameB](const Functionblock &Block) {
          return (Block.A == nameA && Block.B == nameB) ||
                 (Block.A == nameB && Block.B == nameA);
        });
    return Found == Functions.end() ? nullptr : &*Found;
  }

  /// Evaluates the profile registered for \p nameA and \p nameB with
  /// \p scoreDifference.
  /// \param nameA unique identifier of one agent of the pair
  /// \param nameB unique identifier of the other agent of the pair
  /// \param scoreDifference the input of the profile
  /// \return the value of the profile, or 0 if no profile exists for the pair
  ///
  /// \note A missing profile is logged as an error.
  /// \note The order of \p nameA and \p nameB does not matter.
  Type getCrossReliabilityFromProfile(id_t nameA, id_t nameB,
                                      StateType scoreDifference) const {
    const Functionblock *Block = searchFunction(nameA, nameB);
    if (!Block) {
      LOG_ERROR(("CrossReliability: Block:" + std::to_string(nameA) + "," +
                 std::to_string(nameB) + " doesn't exist returning 0"));
      return 0;
    }
    return Block->Funct->operator()(scoreDifference);
  }

public:
  /// Adds a cross reliability profile used to get the reliability of the
  /// state difference of two agents.
  /// \param idA the id of one \c Agent (ideally the id of its \c Unit to make
  /// it absolutely unique)
  /// \param idB the id of the other \c Agent
  /// \param Function the profile; it is evaluated with the absolute
  /// difference in state. Ownership is taken over and \p Function is left
  /// empty.
  ///
  /// \note An empty \p Function is logged as an error and ignored.
  void addCrossReliabilityProfile(id_t idA, id_t idB,
                                  std::unique_ptr<Abstraction> &Function) {
    if (!Function) {
      LOG_ERROR(("CrossReliability: Block:" + std::to_string(idA) + "," +
                 std::to_string(idB) + " has no function, ignoring it"));
      return;
    }
    Functions.push_back(Functionblock{idA, idB, std::move(Function)});
  }

  /// Sets the cross reliability parameter.
  void setCrossReliabilityParameter(Type val) {
    crossReliabilityParameter = val;
  }

  /// Sets the method used to combine the values.
  /// \param Meth the function which defines the combination method
  /// \note Inside \c CrossReliability there are static methods defined which
  /// can be used.
  void setCrossReliabilityMethod(Type (*Meth)(std::vector<Type> values)) {
    Method = Meth;
  }

  CrossReliability() : Abstraction(0) {}

  // The profiles are owned exclusively, so copying is disabled.
  CrossReliability(const CrossReliability &) = delete;
  CrossReliability &operator=(const CrossReliability &) = delete;
  CrossReliability(CrossReliability &&) = default;
  CrossReliability &operator=(CrossReliability &&) = default;

  /// The owned profiles are released by their \c std::unique_ptr.
  /// NOTE(review): destroying them through \c Abstraction pointers assumes
  /// \c Abstraction has a virtual destructor -- confirm.
  ~CrossReliability() = default;

  /// Calculates the cross reliability of \p MainAgent.
  /// \note Both the main and the slave agents are represented by a unique
  /// identifier and their state.
  ///
  /// \param MainAgent the agent around which the cross reliability is
  /// calculated
  /// \param SlaveAgents all connected agents; it doesn't matter whether
  /// \p MainAgent is contained in this vector, as it is skipped
  /// \return the per-agent reliabilities combined by the selected method
  Type operator()(std::pair<id_t, StateType> &&MainAgent,
                  std::vector<std::pair<id_t, StateType>> &SlaveAgents);

  /// Predefined combination method: the minimum of \p values, or 0 if
  /// \p values is empty.
  static Type CONJUNCTION(std::vector<Type> values) {
    if (values.empty())
      return 0;
    return *std::min_element(values.begin(), values.end());
  }

  /// Predefined combination method: the arithmetic mean of \p values
  /// (accumulated as \c double), or 0 if \p values is empty.
  static Type AVERAGE(std::vector<Type> values) {
    if (values.empty())
      return 0;
    return std::accumulate(values.begin(), values.end(), 0.0) / values.size();
  }

  /// Predefined combination method: the maximum of \p values, or 0 if
  /// \p values is empty.
  static Type DISJUNCTION(std::vector<Type> values) {
    if (values.empty())
      return 0;
    return *std::max_element(values.begin(), values.end());
  }
};

// Calculates the cross reliability of MainAgent (see the declaration inside
// the class for the parameters). If SlaveAgents contains no agent other than
// MainAgent, the error is logged and 0 is returned instead of calling Method
// with an empty vector, which the predefined combination methods cannot
// handle safely.
template <typename StateType, typename Type>
inline Type CrossReliability<StateType, Type>::
operator()(std::pair<id_t, StateType> &&MainAgent,
           std::vector<std::pair<id_t, StateType>> &SlaveAgents) {

  std::vector<Type> values;
  values.reserve(SlaveAgents.size());

  for (const std::pair<id_t, StateType> &SlaveAgent : SlaveAgents) {

    // The main agent is not compared with itself.
    if (SlaveAgent.first == MainAgent.first)
      continue;

    const StateType Distance =
        AbsuluteValue(MainAgent.second, SlaveAgent.second);

    // Reliability derived from the state difference: 1 for equal states,
    // otherwise the reciprocal of the scaled difference.
    Type crossReliabiability = 1;
    if (MainAgent.second != SlaveAgent.second)
      crossReliabiability = 1 / (crossReliabilityParameter * Distance);

    // profile reliability
    const Type crossReliabilityFromProfile =
        getCrossReliabilityFromProfile(MainAgent.first, SlaveAgent.first,
                                       Distance);
    values.push_back(
        std::max(crossReliabiability, crossReliabilityFromProfile));
  }

  if (values.empty()) {
    LOG_ERROR("CrossReliability: no agent other than the main agent was "
              "given, returning 0");
    return 0;
  }
  return Method(std::move(values));
}

/// Calculates the \c CrossConfidence.
///
/// \brief Uses a theoretical state, represented by a numerical value, and
/// calculates the reliability of a given agent (represented by its id) in
/// relation to all other given agents. This can be used to get a confidence of
/// the current state.
///
/// \tparam StateType arithmetic type of the agents' states
/// \tparam Type arithmetic type of the calculated reliabilities
///
/// \note A cross reliability profile has to be specified for every
/// combination of agents; a missing profile is logged as an error and its
/// value is taken as 0.
///
/// \note The object owns the profiles added to it, hence it is not copyable.
template <typename StateType, typename Type>
class CrossConfidence : public Abstraction<StateType, Type> {

  static_assert(std::is_arithmetic<Type>::value,
                "CrossConfidence: <Type> has to be an arithmetic type\n");
  static_assert(std::is_arithmetic<StateType>::value,
                "CrossConfidence: <StateType> has to be an arithmetic type\n");

  using Abstraction = typename rosa::agent::Abstraction<StateType, Type>;

  /// A cross reliability profile registered for an unordered pair of agents.
  struct Functionblock {
    id_t A;
    id_t B;
    /// The profile; evaluated with the absolute state difference of A and B.
    std::unique_ptr<Abstraction> Funct;
  };

  /// Scales the state difference; defined as 1 in Maxi's code and can be
  /// changed via \c setCrossReliabilityParameter.
  Type crossReliabilityParameter = 1;

  /// Stored cross reliability profiles.
  std::vector<Functionblock> Functions;

  /// Method which is used to combine the generated values.
  Type (*Method)(std::vector<Type> values) = AVERAGE;

  //--------------------------------------------------------------------------------
  // helper functions

  /// Returns the absolute distance between \p A and \p B.
  /// \note Named "AbsuluteValue" to stay somewhat conform with Maxi's code.
  /// \note The operands are compared before subtracting so that unsigned
  /// types do not wrap around when \p A is smaller than \p B.
  template <typename Type_t> Type_t AbsuluteValue(Type_t A, Type_t B) {
    return (A < B) ? B - A : A - B;
  }

  /// Linear search for the profile of the unordered pair \p nameA, \p nameB.
  /// \return the first matching profile, or \c nullptr if there is none
  const Functionblock *searchFunction(const id_t nameA,
                                      const id_t nameB) const {
    const auto Found = std::find_if(
        Functions.begin(), Functions.end(),
        [nameA, nameB](const Functionblock &Block) {
          return (Block.A == nameA && Block.B == nameB) ||
                 (Block.A == nameB && Block.B == nameA);
        });
    return Found == Functions.end() ? nullptr : &*Found;
  }

  /// Evaluates the profile registered for \p nameA and \p nameB with
  /// \p scoreDifference.
  /// \param nameA unique identifier of one agent of the pair
  /// \param nameB unique identifier of the other agent of the pair
  /// \param scoreDifference the input of the profile
  /// \return the value of the profile, or 0 if no profile exists for the pair
  ///
  /// \note A missing profile is logged as an error.
  /// \note The order of \p nameA and \p nameB does not matter.
  Type getCrossReliabilityFromProfile(id_t nameA, id_t nameB,
                                      StateType scoreDifference) const {
    const Functionblock *Block = searchFunction(nameA, nameB);
    if (!Block) {
      LOG_ERROR(("CrossConfidence: Block:" + std::to_string(nameA) + "," +
                 std::to_string(nameB) + " doesn't exist returning 0"));
      return 0;
    }
    return Block->Funct->operator()(scoreDifference);
  }

public:
  /// Adds a cross reliability profile used to get the reliability of the
  /// state difference of two agents.
  /// \param idA the id of one \c Agent (ideally the id of its \c Unit to make
  /// it absolutely unique)
  /// \param idB the id of the other \c Agent
  /// \param Function the profile; it is evaluated with the absolute
  /// difference in state. Ownership is taken over and \p Function is left
  /// empty.
  ///
  /// \note An empty \p Function is logged as an error and ignored.
  void addCrossReliabilityProfile(id_t idA, id_t idB,
                                  std::unique_ptr<Abstraction> &Function) {
    if (!Function) {
      LOG_ERROR(("CrossConfidence: Block:" + std::to_string(idA) + "," +
                 std::to_string(idB) + " has no function, ignoring it"));
      return;
    }
    Functions.push_back(Functionblock{idA, idB, std::move(Function)});
  }

  /// Sets the cross reliability parameter.
  void setCrossReliabilityParameter(Type val) {
    crossReliabilityParameter = val;
  }

  /// Sets the method used to combine the values.
  /// \param Meth the function which defines the combination method
  /// \note Inside \c CrossConfidence there are static methods defined which
  /// can be used.
  void setCrossReliabilityMethod(Type (*Meth)(std::vector<Type> values)) {
    Method = Meth;
  }

  CrossConfidence() : Abstraction(0) {}

  // The profiles are owned exclusively, so copying is disabled.
  CrossConfidence(const CrossConfidence &) = delete;
  CrossConfidence &operator=(const CrossConfidence &) = delete;
  CrossConfidence(CrossConfidence &&) = default;
  CrossConfidence &operator=(CrossConfidence &&) = default;

  /// The owned profiles are released by their \c std::unique_ptr.
  /// NOTE(review): destroying them through \c Abstraction pointers assumes
  /// \c Abstraction has a virtual destructor -- confirm.
  ~CrossConfidence() = default;

  /// Calculates the confidence of \p TheoreticalValue being the state of
  /// \p MainAgent, compared to the states of \p SlaveAgents.
  Type operator()(id_t MainAgent, StateType TheoreticalValue,
                  std::vector<std::pair<id_t, StateType>> &SlaveAgents);

  /// Predefined combination method: the minimum of \p values, or 0 if
  /// \p values is empty.
  static Type CONJUNCTION(std::vector<Type> values) {
    if (values.empty())
      return 0;
    return *std::min_element(values.begin(), values.end());
  }

  /// Predefined combination method: the arithmetic mean of \p values
  /// (accumulated as \c double), or 0 if \p values is empty.
  static Type AVERAGE(std::vector<Type> values) {
    if (values.empty())
      return 0;
    return std::accumulate(values.begin(), values.end(), 0.0) / values.size();
  }

  /// Predefined combination method: the maximum of \p values, or 0 if
  /// \p values is empty.
  static Type DISJUNCTION(std::vector<Type> values) {
    if (values.empty())
      return 0;
    return *std::max_element(values.begin(), values.end());
  }
};

/// Calculates the CrossConfidence of the main agent compared to all other
/// agents.
/// \param MainAgent the id of the main agent
/// \param TheoreticalValue the theoretical state to use for the calculation
/// \param SlaveAgents the ids and states of all other (slave) agents; an entry
/// with the id \p MainAgent is skipped
/// \return the per-agent reliabilities combined by the selected method, or 0
/// (logged as an error) if \p SlaveAgents contains no agent other than
/// \p MainAgent
template <typename StateType, typename Type>
inline Type CrossConfidence<StateType, Type>::
operator()(id_t MainAgent, StateType TheoreticalValue,
           std::vector<std::pair<id_t, StateType>> &SlaveAgents) {

  std::vector<Type> values;
  values.reserve(SlaveAgents.size());

  for (const std::pair<id_t, StateType> &SlaveAgent : SlaveAgents) {

    // The main agent is not compared with itself.
    if (SlaveAgent.first == MainAgent)
      continue;

    const StateType Distance =
        AbsuluteValue(TheoreticalValue, SlaveAgent.second);

    // Reliability derived from the state difference: 1 for equal states,
    // otherwise the reciprocal of the scaled difference.
    Type crossReliabiability = 1;
    if (TheoreticalValue != SlaveAgent.second)
      crossReliabiability = 1 / (crossReliabilityParameter * Distance);

    // profile reliability
    const Type crossReliabilityFromProfile =
        getCrossReliabilityFromProfile(MainAgent, SlaveAgent.first, Distance);
    values.push_back(
        std::max(crossReliabiability, crossReliabilityFromProfile));
  }

  // Calling Method with an empty vector is not safe for the predefined
  // combination methods, so that case is reported here.
  if (values.empty()) {
    LOG_ERROR("CrossConfidence: no agent other than the main agent was "
              "given, returning 0");
    return 0;
  }
  return Method(std::move(values));
}

} // End namespace agent
} // End namespace rosa

#endif // ROSA_AGENT_CROSSRELIABILITY_H