//===-- rosa/agent/CrossCombinator.h ----------------------------*- C++ -*-===//
//
//                                 The RoSA Framework
//
// Distributed under the terms and conditions of the Boost Software License 1.0.
// See accompanying file LICENSE.
//
// If you did not receive a copy of the license file, see
// http://www.boost.org/LICENSE_1_0.txt.
//
//===----------------------------------------------------------------------===//
///
/// \file rosa/agent/CrossCombinator.h
///
/// \author Daniel Schnoell
///
/// \date 2019
/// \note based on Maximilian Goetzinger(maxgot @utu.fi) code in
/// CAM_Dirty_include SA-EWS2_Version... inside Agent.cpp
///
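/// \brief Definition of the \c rosa::agent::CrossCombinator class template,
/// which combines the identifiers and reliabilities of several agents.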
///
/// \todo there is 1 exception that needs to be handled correctly.
/// \note the default search function is extremely slow maybe this could be done
/// via template for storage class and the functions/methods to efficiently find
/// the correct LinearFunction
//===----------------------------------------------------------------------===//
#ifndef ROSA_AGENT_CROSSCOMBINATOR_H
#define ROSA_AGENT_CROSSCOMBINATOR_H

#include "rosa/agent/Abstraction.hpp"
#include "rosa/agent/Functionality.h"
#include "rosa/agent/ReliabilityConfidenceCombinator.h"
#include "rosa/core/forward_declarations.h" // needed for id_t
#include "rosa/support/log.h"               // needed for error "handling"

// needed headers

#include <string>
#include <type_traits> //assert
#include <vector>
// for static methods
#include <algorithm>
#include <numeric>

namespace rosa {
namespace agent {

template <typename id, typename IdentifierType, typename ReliabilityType>
std::vector<std::pair<id_t, IdentifierType>> &operator<<(
    std::vector<std::pair<id_t, IdentifierType>> &me,
    std::vector<std::tuple<id, IdentifierType, ReliabilityType>> Values) {
  for (auto tmp : Values) {
    std::pair<id, IdentifierType> tmp2;
    tmp2.first = std::get<0>(tmp);
    tmp2.second = std::get<1>(tmp);
    me.push_back(tmp2);
  }
  return me;
}

/// This is the Combinator class for cross Reliabilities. It has many functions
/// with different purposes
/// \brief It takes the Identifiers and Reliabilities of all given ids and
/// calculates the Reliability of them together. Also it can creates the
/// feedback that is needed by the \c ReliabilityAndConfidenceCombinator, which
/// is a kind of confidence.
///
/// \tparam IdentifierType Data type of the Identifier ( Typically double
/// or float) \tparam ReliabilityType Data type of the Reliability ( Typically
/// long	or int)
///
/// \note This class is commonly in a master slave relationship as master with
/// \c ReliabilityAndConfidenceCombinator. The \c operator()() combines the
/// Reliability of all connected Slaves and uses that as its own Reliability
/// also creates the feedback for the Slaves.
///
/// \note more information about how the Reliability and feedback is
/// created at \c operator()() , \c getCombinedCrossReliability() , \c
/// getCombinedInputReliability() , \c getOutputReliability() [ this is the
///  used Reliability ], \c getCrossConfidence() [ this is the feedback
/// for all Slaves ]
///
/// a bit more special Methods \c CrossConfidence() ,\c CrossReliability()
template <typename IdentifierType, typename ReliabilityType>
class CrossCombinator {
public:
  static_assert(std::is_arithmetic<IdentifierType>::value,
                "HighLevel: IdentifierType has to be an arithmetic type\n");
  static_assert(std::is_arithmetic<ReliabilityType>::value,
                "HighLevel: ReliabilityType has to be an arithmetic type\n");

  // ---------------------------------------------------------------------------
  //			useful definitions
  // ---------------------------------------------------------------------------
  /// typedef To shorten the writing.
  /// \c ConfOrRel
  using ConfOrRel = ConfOrRel<IdentifierType, ReliabilityType>;

  /// To shorten the writing.
  using Abstraction =
      typename rosa::agent::Abstraction<IdentifierType, ReliabilityType>;

  /// The return type for the \c operator()() Method
  struct returnType {
    ReliabilityType CrossReliability;
    std::map<id_t, std::vector<ConfOrRel>> CrossConfidence;
  };

  // -------------------------------------------------------------------------
  //			Relevant Methods
  // -------------------------------------------------------------------------
  /// Calculates the Reliability and the CrossConfidences for each id for all
  /// of there Identifiers.
  ///
  /// \param Values It gets the Identifiers and Reliabilities of
  /// all connected Slaves inside a vector.
  ///
  /// \return it returns a struct \c returnType containing the \c
  /// getCombinedCrossReliability() and \c getCrossConfidence()
  returnType operator()(
      std::vector<std::tuple<id_t, IdentifierType, ReliabilityType>> Values) {
    return {getOutputReliability(Values), getCrossConfidence(Values)};
  }

  /// returns the combined Cross Reliability via \c
  /// CombinedCrossRelCombinationMethod \c
  /// setCombinedCrossRelCombinationMethod() for all ids \c
  /// CrossReliability() \param Values the used Values
  ReliabilityType getCombinedCrossReliability(
      const std::vector<std::tuple<id_t, IdentifierType, ReliabilityType>>
          &Values) noexcept {

    ReliabilityType combinedCrossRel = -1;

    std::vector<std::pair<id_t, IdentifierType>> Agents;

    Agents << Values;

    for (auto Value : Values) {
      id_t id = std::get<0>(Value);
      IdentifierType sc = std::get<1>(Value);

      // calculate the cross reliability for this slave agent
      ReliabilityType realCrossReliabilityOfSlaveAgent =
          CrossReliability({id, sc}, Agents);

      if (combinedCrossRel != -1)
        combinedCrossRel = CombinedCrossRelCombinationMethod(
            combinedCrossRel, realCrossReliabilityOfSlaveAgent);
      else
        combinedCrossRel = realCrossReliabilityOfSlaveAgent;
    }
    return combinedCrossRel;
  }

  /// returns the combined via \c CombinedInputRelCombinationMethod \c
  /// setCombinedInputRelCombinationMethod()  input reliability \param Values
  /// the used Values
  ReliabilityType getCombinedInputReliability(
      const std::vector<std::tuple<id_t, IdentifierType, ReliabilityType>>
          &Values) noexcept {
    ReliabilityType combinedInputRel = -1;

    std::vector<std::pair<id_t, IdentifierType>> Agents;

    Agents << Values;

    for (auto Value : Values) {
      ReliabilityType rel = std::get<2>(Value);

      if (combinedInputRel != -1)
        combinedInputRel =
            CombinedInputRelCombinationMethod(combinedInputRel, rel);
      else
        combinedInputRel = rel;
    }
    return combinedInputRel;
  }

  /// returns the combination via \c OutputReliabilityCombinationMethod \c
  /// setOutputReliabilityCombinationMethod() of the Cross reliability and
  /// input reliability \param Values the used Values
  ReliabilityType getOutputReliability(
      const std::vector<std::tuple<id_t, IdentifierType, ReliabilityType>>
          &Values) noexcept {
    return OutputReliabilityCombinationMethod(
        getCombinedInputReliability(Values),
        getCombinedCrossReliability(Values));
  }

  /// returns the crossConfidence for all ids \c CrossConfidence()
  /// \param Values the used Values
  std::map<id_t, std::vector<ConfOrRel>> getCrossConfidence(
      const std::vector<std::tuple<id_t, IdentifierType, ReliabilityType>>
          &Values) noexcept {

    std::vector<std::pair<id_t, IdentifierType>> Agents;
    std::map<id_t, std::vector<ConfOrRel>> output;
    std::vector<ConfOrRel> output_temporary;

    Agents << Values;

    for (auto Value : Values) {
      id_t id = std::get<0>(Value);

      output_temporary.clear();
      for (IdentifierType thoIdentifier : Identifiers[id]) {
        ConfOrRel data;
        data.Identifier = thoIdentifier;
        data.Reliability = CrossConfidence(id, thoIdentifier, Agents);
        output_temporary.push_back(data);
      }

      output.insert({id, output_temporary});
    }

    return output;
  }

  /// Calculates the Cross Confidence
  /// \brief it uses the Identifier value and calculates
  /// the Confidence of a given agent( represented by their id ) for a given
  /// Identifiers in connection to all other given agents
  ///
  /// \note all combination of agents and there corresponding Cross Reliability
  /// function have to be specified
  ReliabilityType
  CrossConfidence(const id_t &MainAgent, const IdentifierType &TheoreticalValue,
                  const std::vector<std::pair<id_t, IdentifierType>>
                      &SlaveAgents) noexcept {

    ReliabilityType crossReliabiability;

    std::vector<ReliabilityType> values;

    for (std::pair<id_t, IdentifierType> SlaveAgent : SlaveAgents) {

      if (SlaveAgent.first == MainAgent)
        continue;

      if (TheoreticalValue == SlaveAgent.second)
        crossReliabiability = 1;
      else
        crossReliabiability =
            1 / (crossReliabilityParameter *
                 std::abs(TheoreticalValue - SlaveAgent.second));

      // profile reliability
      ReliabilityType crossReliabilityFromProfile =
          getCrossReliabilityFromProfile(
              MainAgent, SlaveAgent.first,
              std::abs(TheoreticalValue - SlaveAgent.second));
      values.push_back(
          std::max(crossReliabiability, crossReliabilityFromProfile));
    }
    return Method(values);
  }

  /// Calculates the Cross Reliability
  /// \brief it uses the Identifier value and calculates
  /// the Reliability of a given agent( represented by their id ) in connection
  /// to all other given agents
  ///
  /// \note all combination of agents and there corresponding Cross Reliability
  /// function have to be specified
  ReliabilityType
  CrossReliability(const std::pair<id_t, IdentifierType> &MainAgent,
                   const std::vector<std::pair<id_t, IdentifierType>>
                       &SlaveAgents) noexcept {

    ReliabilityType crossReliabiability;
    std::vector<ReliabilityType> values;

    for (std::pair<id_t, IdentifierType> SlaveAgent : SlaveAgents) {

      if (SlaveAgent.first == MainAgent.first)
        continue;

      if (MainAgent.second == SlaveAgent.second)
        crossReliabiability = 1;
      else
        crossReliabiability =
            1 / (crossReliabilityParameter *
                 std::abs(MainAgent.second - SlaveAgent.second));

      // profile reliability
      ReliabilityType crossReliabilityFromProfile =
          getCrossReliabilityFromProfile(
              MainAgent.first, SlaveAgent.first,
              std::abs(MainAgent.second - SlaveAgent.second));
      values.push_back(
          std::max(crossReliabiability, crossReliabilityFromProfile));
    }
    return Method(values);
  }

  // --------------------------------------------------------------------------
  //			Defining the class
  // --------------------------------------------------------------------------

  /// adds a Cross Reliability Profile used to get the Reliability of the
  /// Identifier difference
  ///
  /// \param idA The id of the one \c Agent ( ideally the id of \c Unit to make
  /// it absolutely unique )
  ///
  /// \param idB The id of the other \c Agent
  ///
  /// \param Function A shared pointer to an \c Abstraction it would use the
  /// difference in Identifier for its input
  void addCrossReliabilityProfile(
      const id_t &idA, const id_t &idB,
      const std::shared_ptr<Abstraction> &Function) noexcept {
    Functions.push_back({true, idA, idB, Function});
  }

  /// sets the cross reliability parameter
  void setCrossReliabilityParameter(const ReliabilityType &val) noexcept {
    crossReliabilityParameter = val;
  }

  /// This is the adder for the Identifiers
  /// \param id The id of the Agent of the Identifiers
  /// \param _Identifiers id specific Identifiers. This will be copied So that if
  /// Slaves have different Identifiers they can be used correctly. \brief The
  /// Identifiers of all connected slave Agents has to be known to be able to
  /// iterate over them
  void
  addIdentifiers(const id_t &id,
                 const std::vector<IdentifierType> &_Identifiers) noexcept {
    Identifiers.insert({id, _Identifiers});
  }

  // -------------------------------------------------------------------------
  //			Combinator Settings
  // -------------------------------------------------------------------------

  /// sets the used method to combine the values
  /// \param Meth the method which should be used. predefined functions in the
  /// struct \c predefinedMethods  \c
  /// CONJUNCTION() \c AVERAGE() \c DISJUNCTION()
  void setCrossReliabilityCombinatorMethod(
      const std::function<ReliabilityType(std::vector<ReliabilityType> values)>
          &Meth) noexcept {
    Method = Meth;
  }

  /// sets the combination method for the combined cross reliability
  /// \param Meth the method which should be used. predefined functions in the
  /// struct \c predefinedMethods CombinedCrossRelCombinationMethod<method>()
  void setCombinedCrossRelCombinationMethod(
      const std::function<ReliabilityType(ReliabilityType, ReliabilityType)>
          &Meth) noexcept {
    CombinedCrossRelCombinationMethod = Meth;
  }

  /// sets the combined input rel method
  /// \param Meth the method which should be used. predefined functions in the
  /// struct \c predefinedMethods CombinedInputRelCombinationMethod<method>()
  void setCombinedInputRelCombinationMethod(
      const std::function<ReliabilityType(ReliabilityType, ReliabilityType)>
          &Meth) noexcept {
    CombinedInputRelCombinationMethod = Meth;
  }

  /// sets the used OutputReliabilityCombinationMethod
  /// \param Meth the method which should be used. predefined functions in the
  /// struct \c predefinedMethods OutputReliabilityCombinationMethod<method>()
  void setOutputReliabilityCombinationMethod(
      const std::function<ReliabilityType(ReliabilityType, ReliabilityType)>
          &Meth) noexcept {
    OutputReliabilityCombinationMethod = Meth;
  }

  // -------------------------------------------------------------------------
  //			Predefined Functions
  // -------------------------------------------------------------------------
  /// This struct is a pseudo name space to have easier access to all predefined
  /// methods while still not overcrowding the class it self
  struct predefinedMethods {
    /// predefined combination method
    static ReliabilityType CONJUNCTION(std::vector<ReliabilityType> values) {
      return *std::min_element(values.begin(), values.end());
    }

    /// predefined combination method
    static ReliabilityType AVERAGE(std::vector<ReliabilityType> values) {
      return std::accumulate(values.begin(), values.end(), 0.0) / values.size();
    }

    /// predefined combination method
    static ReliabilityType DISJUNCTION(std::vector<ReliabilityType> values) {
      return *std::max_element(values.begin(), values.end());
    }

    /// predefined combination Method
    static ReliabilityType
    CombinedCrossRelCombinationMethodMin(ReliabilityType A, ReliabilityType B) {
      return std::min(A, B);
    }

    /// predefined combination Method
    static ReliabilityType
    CombinedCrossRelCombinationMethodMax(ReliabilityType A, ReliabilityType B) {
      return std::max(A, B);
    }

    /// predefined combination Method
    static ReliabilityType
    CombinedCrossRelCombinationMethodMult(ReliabilityType A,
                                          ReliabilityType B) {
      return A * B;
    }

    /// predefined combination Method
    static ReliabilityType
    CombinedCrossRelCombinationMethodAverage(ReliabilityType A,
                                             ReliabilityType B) {
      return (A + B) / 2;
    }

    /// predefined combination Method
    static ReliabilityType
    CombinedInputRelCombinationMethodMin(ReliabilityType A, ReliabilityType B) {
      return std::min(A, B);
    }

    /// predefined combination Method
    static ReliabilityType
    CombinedInputRelCombinationMethodMax(ReliabilityType A, ReliabilityType B) {
      return std::max(A, B);
    }

    /// predefined combination Method
    static ReliabilityType
    CombinedInputRelCombinationMethodMult(ReliabilityType A,
                                          ReliabilityType B) {
      return A * B;
    }

    /// predefined combination Method
    static ReliabilityType
    CombinedInputRelCombinationMethodAverage(ReliabilityType A,
                                             ReliabilityType B) {
      return (A + B) / 2;
    }

    /// predefined combination method
    static ReliabilityType
    OutputReliabilityCombinationMethodMin(ReliabilityType A,
                                          ReliabilityType B) {
      return std::min(A, B);
    }

    /// predefined combination method
    static ReliabilityType
    OutputReliabilityCombinationMethodMax(ReliabilityType A,
                                          ReliabilityType B) {
      return std::max(A, B);
    }

    /// predefined combination method
    static ReliabilityType
    OutputReliabilityCombinationMethodMult(ReliabilityType A,
                                           ReliabilityType B) {
      return A * B;
    }

    /// predefined combination method
    static ReliabilityType
    OutputReliabilityCombinationMethodAverage(ReliabilityType A,
                                              ReliabilityType B) {
      return (A + B) / 2;
    }
  };

  // -------------------------------------------------------------------------
  //				Cleanup
  // -------------------------------------------------------------------------

  ~CrossCombinator() { Functions.clear(); }

  // --------------------------------------------------------------------------
  //				Parameters
  // --------------------------------------------------------------------------
private:
  struct Functionblock {
    bool exists = false;
    id_t A;
    id_t B;
    std::shared_ptr<Abstraction> Funct;
  };

  std::map<id_t, std::vector<IdentifierType>> Identifiers;

  /// From Maxi in his code defined as 1 can be changed by set
  ReliabilityType crossReliabilityParameter = 1;

  /// Stored Cross Reliability Functions
  std::vector<Functionblock> Functions;

  /// Method which is used to combine the generated values
  std::function<ReliabilityType(std::vector<ReliabilityType>)> Method =
      predefinedMethods::AVERAGE;

  std::function<ReliabilityType(ReliabilityType, ReliabilityType)>
      CombinedCrossRelCombinationMethod =
          predefinedMethods::CombinedCrossRelCombinationMethodMin;

  std::function<ReliabilityType(ReliabilityType, ReliabilityType)>
      CombinedInputRelCombinationMethod =
          predefinedMethods::CombinedInputRelCombinationMethodMin;

  std::function<ReliabilityType(ReliabilityType, ReliabilityType)>
      OutputReliabilityCombinationMethod =
          predefinedMethods::OutputReliabilityCombinationMethodMin;

  //--------------------------------------------------------------------------------
  // helper function

  /// very inefficient searchFunction
  Functionblock (*searchFunction)(std::vector<Functionblock> vect,
                                  const id_t nameA, const id_t nameB) =
      [](std::vector<Functionblock> vect, const id_t nameA,
         const id_t nameB) -> Functionblock {
    for (Functionblock tmp : vect) {
      if (tmp.A == nameA && tmp.B == nameB)
        return tmp;
      if (tmp.A == nameB && tmp.B == nameA)
        return tmp;
    }
    return Functionblock();
  };

  /// evaluates the corresponding LinearFunction with the Identifier difference
  /// \param nameA these two parameters are the unique identifiers
  /// \param nameB these two parameters are the unique identifiers
  /// for the LinerFunction
  ///
  /// \note it doesn't matter if they are swapped
  ReliabilityType getCrossReliabilityFromProfile(
      const id_t &nameA, const id_t &nameB,
      const IdentifierType &IdentifierDifference) noexcept {
    Functionblock block = searchFunction(Functions, nameA, nameB);
    if (!block.exists) {
      LOG_ERROR(("CrossReliability: Block:" + std::to_string(nameA) + "," +
                 std::to_string(nameB) + "doesn't exist returning 0"));
      return 0;
    }
    return block.Funct->operator()(IdentifierDifference);
  }
};

} // End namespace agent
} // End namespace rosa

#endif // ROSA_AGENT_CROSSCOMBINATOR_H