//===-- rosa/delux/CrossReliability.h ---------------------------*- C++ -*-===//
//
//                                 The RoSA Framework
//
//===----------------------------------------------------------------------===//
///
/// \file rosa/delux/CrossReliability.h
///
/// \author Daniel Schnoell
///
/// \date 2019
///
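/// \brief Definition of \c rosa::agent::CrossCombinator, which combines the
/// cross reliabilities of several agents.
///
/// \todo There is one exception that still needs to be handled correctly.
/// \note The default profile search is slow (a linear scan over all stored
/// profiles). This could perhaps be improved by templating the storage class
/// and the lookup functions so that the correct profile can be found
/// efficiently.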
//===----------------------------------------------------------------------===//
#ifndef ROSA_AGENT_CROSSRELIABILITY_H
#define ROSA_AGENT_CROSSRELIABILITY_H

#include "rosa/agent/Abstraction.hpp"
#include "rosa/agent/Functionality.h"
#include "rosa/agent/ReliabilityConfidenceCombinator.h"
#include "rosa/core/forward_declarations.h" // needed for id_t
#include "rosa/support/log.h"               // needed for error "handling"

// needed standard headers
#include <algorithm> // std::find_if, std::max, std::min_element, ...
#include <map>
#include <memory>  // std::shared_ptr, std::unique_ptr
#include <numeric> // std::accumulate
#include <string>
#include <tuple>
#include <type_traits> // std::is_arithmetic for the static_asserts
#include <utility>     // std::pair, std::move
#include <vector>

namespace rosa {
namespace agent {

template <typename id, typename StateType, typename ReliabilityType>
std::vector<std::pair<id_t, StateType>> &
operator<<(std::vector<std::pair<id_t, StateType>> &me,
           std::vector<std::tuple<id, StateType, ReliabilityType>> Values) {
  for (auto tmp : Values) {
    std::pair<id, StateType> tmp2;
    tmp2.first = std::get<0>(tmp);
    tmp2.second = std::get<1>(tmp);
    me.push_back(tmp2);
  }
  return me;
}

/// This is the Combinator class for cross reliabilities; it provides several
/// functions with different purposes.
/// \brief It takes the states and reliabilities of all given ids and
/// calculates their combined Reliability. It can also create the feedback
/// that is needed by the \c ReliabilityAndConfidenceCombinator, which is a
/// kind of confidence.
///
/// \tparam StateType Datatype of the State (typically double or float)
/// \tparam ReliabilityType Datatype of the Reliability (typically long or
/// int)
///
/// \note This class is commonly used as the master in a master/slave
/// relationship with \c ReliabilityAndConfidenceCombinator. The
/// \c operator()() combines the Reliability of all connected Slaves and uses
/// that as its own Reliability; it also creates the feedback for the Slaves.
///
/// \note More information about how the Reliability and the feedback are
/// created can be found at \c operator()(), \c getCombinedCrossReliability(),
/// \c getCombinedInputReliability(), \c getOutputReliability() [this is the
/// commonly used Reliability] and \c getCrossConfidence() [this is the
/// feedback for all Slaves].
template <typename StateType, typename ReliabilityType> class CrossCombinator {
public:
  static_assert(std::is_arithmetic<StateType>::value,
                "HighLevel: StateType has to be an arithmetic type\n");
  static_assert(std::is_arithmetic<ReliabilityType>::value,
                "HighLevel: ReliabilityType has to be an arithmetic type\n");

  // ---------------------------------------------------------------------------
  //			useful definitions
  // ---------------------------------------------------------------------------
  /// typedef To shorten the writing.
  /// \c ConfOrRel
  typedef ConfOrRel<StateType, ReliabilityType> ConfOrRel;

  /// To shorten the writing.
  using Abstraction =
      typename rosa::agent::Abstraction<StateType, ReliabilityType>;

  /// The return type for the \c operator()() Method
  struct returnType {
    ReliabilityType CrossReliability;
    std::map<id_t, std::vector<ConfOrRel>> CrossConfidence;
  };

  // -------------------------------------------------------------------------
  //			Relevant Methods
  // -------------------------------------------------------------------------
  /// Calculates the Reliability and the CrossConfidences for each id for all
  /// of there states.
  ///
  /// \param Values It gets the States and Reliabilities of
  /// all connected Slaves inside a vector.
  ///
  /// \return it returns a struct \c returnType containing the \c
  /// getCombinedCrossReliability() and \c getCrossConfidence()
  returnType
  operator()(std::vector<std::tuple<id_t, StateType, ReliabilityType>> Values) {
    return {getOutputReliability(Values), getCrossConfidence(Values)};
  }

  /// returns the combined via \c CombinedCrossRelCombinationMethod \c
  /// setCombinedCrossRelCombinationMethod()  Cross Reliability for all ids \c
  /// CrossReliability() \param Values the used Values
  ReliabilityType getCombinedCrossReliability(
      std::vector<std::tuple<id_t, StateType, ReliabilityType>> Values) {

    ReliabilityType combinedCrossRel = -1;

    std::vector<std::pair<id_t, StateType>> Agents;

    Agents << Values;

    for (auto Value : Values) {
      id_t id = std::get<0>(Value);
      StateType sc = std::get<1>(Value);

      // calculate the cross reliability for this slave agent
      ReliabilityType realCrossReliabilityOfSlaveAgent = CrossReliability(
          {id, sc},
          Agents); // AVERAGE, MULTIPLICATION, CONJUNCTION (best to worst:
                   // AVERAGE = CONJUNCTION > MULTIPLICATION >> )

      if (combinedCrossRel != -1)
        combinedCrossRel = CombinedCrossRelCombinationMethod(
            combinedCrossRel, realCrossReliabilityOfSlaveAgent);
      else
        combinedCrossRel = realCrossReliabilityOfSlaveAgent;
    }
    return combinedCrossRel;
  }

  /// returns the combined via \c CombinedInputRelCombinationMethod \c
  /// setCombinedInputRelCombinationMethod()  input relibility \param Values the
  /// used Values
  ReliabilityType getCombinedInputReliability(
      std::vector<std::tuple<id_t, StateType, ReliabilityType>> Values) {
    ReliabilityType combinedInputRel = -1;

    std::vector<std::pair<id_t, StateType>> Agents;

    Agents << Values;

    for (auto Value : Values) {
      ReliabilityType rel = std::get<2>(Value);

      if (combinedInputRel != -1)
        combinedInputRel =
            CombinedInputRelCombinationMethod(combinedInputRel, rel);
      else
        combinedInputRel = rel;
    }
    return combinedInputRel;
  }

  /// returns the combination via  \c OutputReliabilityCombinationMethod \c
  /// setOutputReliabilityCombinationMethod()   of the Cross reliability and
  /// input reliability \param Values the used Values
  ReliabilityType getOutputReliability(
      std::vector<std::tuple<id_t, StateType, ReliabilityType>> Values) {
    return OutputReliabilityCombinationMethod(
        getCombinedInputReliability(Values),
        getCombinedCrossReliability(Values));
  }

  /// retruns the crossConfidence for all ids \c CrossConfidence()
  /// \param Values the used Values
  std::map<id_t, std::vector<ConfOrRel>> getCrossConfidence(
      std::vector<std::tuple<id_t, StateType, ReliabilityType>> Values) {

    std::vector<std::pair<id_t, StateType>> Agents;
    std::map<id_t, std::vector<ConfOrRel>> output;
    std::vector<ConfOrRel> output_temporary;

    Agents << Values;

    for (auto Value : Values) {
      id_t id = std::get<0>(Value);

      output_temporary.clear();
      for (StateType thoScore : States[id]) {
        ConfOrRel data;
        data.score = thoScore;
        data.Reliability = CrossConfidence(id, thoScore, Agents);
        output_temporary.push_back(data);
      }

      output.insert({id, output_temporary});
    }

    return output;
  }

  /// Calculates the Cross Confidence
  /// \brief it uses the state represented by a numerical value and calculates
  /// the Confidence of a given agent( represented by there id ) for a given
  /// state in connection to all other given agents
  ///
  /// \note all combination of agents and there corresponding Cross Reliability
  /// function have to be specified
  ReliabilityType
  CrossConfidence(id_t MainAgent, StateType TheoreticalValue,
                  std::vector<std::pair<id_t, StateType>> &SlaveAgents) {

    ReliabilityType crossReliabiability;

    std::vector<ReliabilityType> values;

    for (std::pair<id_t, StateType> SlaveAgent : SlaveAgents) {

      if (SlaveAgent.first == MainAgent)
        continue;

      if (TheoreticalValue == SlaveAgent.second)
        crossReliabiability = 1;
      else
        crossReliabiability =
            1 / (crossReliabilityParameter *
                 AbsuluteValue(TheoreticalValue, SlaveAgent.second));

      // profile reliability
      ReliabilityType crossReliabilityFromProfile =
          getCrossReliabilityFromProfile(
              MainAgent, SlaveAgent.first,
              AbsuluteValue(TheoreticalValue, SlaveAgent.second));
      values.push_back(
          std::max(crossReliabiability, crossReliabilityFromProfile));
    }
    return Method(values);
  }

  /// Calculates the Cross Reliability
  /// \brief it uses the state represented by a numerical value and calculates
  /// the Reliability of a given agent( represented by there id ) in connection
  /// to all other given agents
  ///
  /// \note all combination of agents and there corresponding Cross Reliability
  /// function have to be specified
  ReliabilityType
  CrossReliability(std::pair<id_t, StateType> &&MainAgent,
                   std::vector<std::pair<id_t, StateType>> &SlaveAgents) {

    ReliabilityType crossReliabiability;
    std::vector<ReliabilityType> values;

    for (std::pair<id_t, StateType> SlaveAgent : SlaveAgents) {

      if (SlaveAgent.first == MainAgent.first)
        continue;

      if (MainAgent.second == SlaveAgent.second)
        crossReliabiability = 1;
      else
        crossReliabiability =
            1 / (crossReliabilityParameter *
                 AbsuluteValue(MainAgent.second, SlaveAgent.second));

      // profile reliability
      ReliabilityType crossReliabilityFromProfile =
          getCrossReliabilityFromProfile(
              MainAgent.first, SlaveAgent.first,
              AbsuluteValue(MainAgent.second, SlaveAgent.second));
      values.push_back(
          std::max(crossReliabiability, crossReliabilityFromProfile));
    }
    return Method(values);
  }

  // --------------------------------------------------------------------------
  //			Defining the class
  // --------------------------------------------------------------------------

  /// adds a Cross Reliability Profile used to get the Reliability of the state
  /// difference
  /// \param idA The id of the one \c Agent ( idealy the id of \c Unit to make
  /// it absolutly unique )
  ///
  /// \param idB The id of the other \c Agent
  ///
  /// \param Function A unique pointer to an \c Abstraction it would use the
  /// difference in score for its input
  void addCrossReliabilityProfile(id_t idA, id_t idB,
                                  std::unique_ptr<Abstraction> &Function) {
    Abstraction *ptr = Function.release();
    Functions.push_back({true, idA, idB, ptr});
  }

  /// sets the cross reliability parameter
  void setCrossReliabilityParameter(ReliabilityType val) {
    crossReliabilityParameter = val;
  }

  /// This is the adder for the states
  /// \param id The id of the Agent of the states
  /// \param States id specific states. this will be copied So that if Slaves
  /// have different States they can be used correctly.
  /// \brief The States of all connected lowlevel Agents has to be known to be
  /// able to iterate over them
  void addStates(id_t id, std::vector<StateType> States) {
    this->States.insert({id, States});
  }

  // -------------------------------------------------------------------------
  //			Combinator Settings
  // -------------------------------------------------------------------------

  /// sets the used method to combine the values
  /// \param Meth The Function which defines the combination method. predef: \c
  /// CONJUNCTION() \c AVERAGE() \c DISJUNCTION()
  void setCrossReliabilityCombinatorMethod(
      ReliabilityType (*Meth)(std::vector<ReliabilityType> values)) {
    Method = Meth;
  }

  /// sets the combination method for the combined cross reliability
  /// \param Meth the method which should be used. predef: \c
  /// CombinedCrossRelCombinationMethodMin() \c
  /// CombinedCrossRelCombinationMethodMax() \c
  /// CombinedCrossRelCombinationMethodMult() \c
  /// CombinedCrossRelCombinationMethodAverage()
  void setCombinedCrossRelCombinationMethod(
      ReliabilityType (*Meth)(ReliabilityType, ReliabilityType)) {
    CombinedCrossRelCombinationMethod = Meth;
  }

  /// sets the combined input rel method
  /// \param Meth the method which should be used predef: \c
  /// CombinedInputRelCombinationMethodMin() \c
  /// CombinedInputRelCombinationMethodMax() \c
  /// CombinedInputRelCombinationMethodMult() \c
  /// CombinedInputRelCombinationMethodAverage()
  void setCombinedInputRelCombinationMethod(
      ReliabilityType (*Meth)(ReliabilityType, ReliabilityType)) {
    CombinedInputRelCombinationMethod = Meth;
  }

  /// sets the used OutputReliabilityCombinationMethod
  /// \param Meth the used Method. predef: \c
  /// OutputReliabilityCombinationMethodMin() \c
  /// OutputReliabilityCombinationMethodMax() \c
  /// OutputReliabilityCombinationMethodMult() \c
  /// OutputReliabilityCombinationMethodAverage()
  void setOutputReliabilityCombinationMethod(
      ReliabilityType (*Meth)(ReliabilityType, ReliabilityType)) {
    OutputReliabilityCombinationMethod = Meth;
  }

  // -------------------------------------------------------------------------
  //			Predefined Functions
  // -------------------------------------------------------------------------
  /// predefined combination method
  static ReliabilityType CONJUNCTION(std::vector<ReliabilityType> values) {
    return *std::min_element(values.begin(), values.end());
  }

  /// predefined combination method
  static ReliabilityType AVERAGE(std::vector<ReliabilityType> values) {
    return std::accumulate(values.begin(), values.end(), 0.0) / values.size();
  }

  /// predefined combination method
  static ReliabilityType DISJUNCTION(std::vector<ReliabilityType> values) {
    return *std::max_element(values.begin(), values.end());
  }

  /// predefined combination Method
  static ReliabilityType
  CombinedCrossRelCombinationMethodMin(ReliabilityType A, ReliabilityType B) {
    return std::min(A, B);
  }

  /// predefined combination Method
  static ReliabilityType
  CombinedCrossRelCombinationMethodMax(ReliabilityType A, ReliabilityType B) {
    return std::max(A, B);
  }

  /// predefined combination Method
  static ReliabilityType
  CombinedCrossRelCombinationMethodMult(ReliabilityType A, ReliabilityType B) {
    return A * B;
  }

  /// predefined combination Method
  static ReliabilityType
  CombinedCrossRelCombinationMethodAverage(ReliabilityType A,
                                           ReliabilityType B) {
    return (A + B) / 2;
  }

  /// predefined combination Method
  static ReliabilityType
  CombinedInputRelCombinationMethodMin(ReliabilityType A, ReliabilityType B) {
    return std::min(A, B);
  }

  /// predefined combination Method
  static ReliabilityType
  CombinedInputRelCombinationMethodMax(ReliabilityType A, ReliabilityType B) {
    return std::max(A, B);
  }

  /// predefined combination Method
  static ReliabilityType
  CombinedInputRelCombinationMethodMult(ReliabilityType A, ReliabilityType B) {
    return A * B;
  }

  /// predefined combination Method
  static ReliabilityType
  CombinedInputRelCombinationMethodAverage(ReliabilityType A,
                                           ReliabilityType B) {
    return (A + B) / 2;
  }

  /// predefined combination method
  static ReliabilityType
  OutputReliabilityCombinationMethodMin(ReliabilityType A, ReliabilityType B) {
    return std::min(A, B);
  }

  /// predefined combination method
  static ReliabilityType
  OutputReliabilityCombinationMethodMax(ReliabilityType A, ReliabilityType B) {
    return std::max(A, B);
  }

  /// predefined combination method
  static ReliabilityType
  OutputReliabilityCombinationMethodMult(ReliabilityType A, ReliabilityType B) {
    return A * B;
  }

  /// predefined combination method
  static ReliabilityType
  OutputReliabilityCombinationMethodAverage(ReliabilityType A,
                                            ReliabilityType B) {
    return (A + B) / 2;
  }

  // -------------------------------------------------------------------------
  //				Cleanup
  // -------------------------------------------------------------------------

  ~CrossCombinator() {
    for (auto tmp : Functions)
      delete tmp.Funct;
    Functions.clear();
  }

  // --------------------------------------------------------------------------
  //			Needed stuff and stored stuff
  // --------------------------------------------------------------------------
private:
  struct Functionblock {
    bool exists = false;
    id_t A;
    id_t B;
    Abstraction *Funct;
  };

  std::map<id_t, std::vector<StateType>> States;

  /// From Maxi in his code defined as 1 can be changed by set
  ReliabilityType crossReliabilityParameter = 1;

  /// Stored Cross Reliability Functions
  std::vector<Functionblock> Functions;

  /// Method which is used to combine the generated values
  ReliabilityType (*Method)(std::vector<ReliabilityType> values) = AVERAGE;

  ReliabilityType (*CombinedCrossRelCombinationMethod)(
      ReliabilityType, ReliabilityType) = CombinedCrossRelCombinationMethodMin;

  ReliabilityType (*CombinedInputRelCombinationMethod)(
      ReliabilityType, ReliabilityType) = CombinedInputRelCombinationMethodMin;

  ReliabilityType (*OutputReliabilityCombinationMethod)(
      ReliabilityType, ReliabilityType) = OutputReliabilityCombinationMethodMin;

  //--------------------------------------------------------------------------------
  // helper function
  /// evaluates the absolute Value of two values
  /// \note this is actually the absolute distance but to keep it somewhat
  /// conform with maxis code
  template <typename Type_t> Type_t AbsuluteValue(Type_t A, Type_t B) {
    return ((A - B) < 0) ? B - A : A - B;
  }

  /// very inefficient searchFunction
  Functionblock (*searchFunction)(std::vector<Functionblock> vect,
                                  const id_t nameA, const id_t nameB) =
      [](std::vector<Functionblock> vect, const id_t nameA,
         const id_t nameB) -> Functionblock {
    for (Functionblock tmp : vect) {
      if (tmp.A == nameA && tmp.B == nameB)
        return tmp;
      if (tmp.A == nameB && tmp.B == nameA)
        return tmp;
    }
    return Functionblock();
  };

  /// evaluates the corresponding LinearFunction with the score difference
  /// \param nameA these two parameters are the unique identifiers
  /// \param nameB these two parameters are the unique identifiers
  /// for the LinerFunction
  ///
  /// \note it doesn't matter if they are swapped
  ReliabilityType getCrossReliabilityFromProfile(id_t nameA, id_t nameB,
                                                 StateType scoreDifference) {
    Functionblock block = searchFunction(Functions, nameA, nameB);
    if (!block.exists) {
      LOG_ERROR(("CrossReliability: Block:" + std::to_string(nameA) + "," +
                 std::to_string(nameB) + "doesn't exist returning 0"));
      return 0;
    }
    return block.Funct->operator()(scoreDifference);
  }
};

} // End namespace agent
} // End namespace rosa

#endif // ROSA_AGENT_CROSSRELIABILITY_H