//===-- rosa/agent/CrossCombinator.h ----------------------------*- C++ -*-===//
//
//                                 The RoSA Framework
//
// Distributed under the terms and conditions of the Boost Software License 1.0.
// See accompanying file LICENSE.
//
// If you did not receive a copy of the license file, see
// http://www.boost.org/LICENSE_1_0.txt.
//
//===----------------------------------------------------------------------===//
///
/// \file rosa/agent/CrossCombinator.h
///
/// \author Daniel Schnoell
///
/// \date 2019-2020
///
/// \note based on Maximilian Goetzinger(maxgot @utu.fi) code in
/// CAM_Dirty_include SA-EWS2_Version... inside Agent.cpp
///
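/// \brief Declaration of \c rosa::agent::CrossCombinator, which combines the
/// Identifiers and Likelinesses of several agents into a cross Likeliness and
/// per-agent feedback.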
///
/// \todo There is one exception that still needs to be handled correctly.
/// \note The default search function is extremely slow. This could perhaps be
/// improved via a template for the storage class and for the functions/methods
/// used to efficiently find the correct LinearFunction.
//===----------------------------------------------------------------------===//
#ifndef ROSA_AGENT_CROSSCOMBINATOR_H
#define ROSA_AGENT_CROSSCOMBINATOR_H

#include "rosa/agent/Abstraction.hpp"
#include "rosa/agent/Functionality.h"
#include "rosa/agent/ReliabilityConfidenceCombination.h"
#include "rosa/core/forward_declarations.h" // needed for id_t
#include "rosa/support/log.h"               // needed for error "handling"

// needed standard headers

#include <functional> // std::function
#include <map>
#include <memory> // std::shared_ptr
#include <string>
#include <tuple>
#include <type_traits> // std::is_arithmetic
#include <utility>     // std::pair, std::move
#include <vector>
// for static methods
#include <algorithm>
#include <numeric>

namespace rosa {
namespace agent {

/// Appends the (id, Identifier) part of every (id, Identifier, Likeliness)
/// tuple in \p Values to \p me; the Likeliness component is dropped.
///
/// \param me vector the pairs are appended to
/// \param Values tuples to take the ids and Identifiers from; taken by const
/// reference so the whole vector is not copied on every call
///
/// \return \p me, so that calls can be chained
template <typename id, typename IdentifierType, typename LikelinessType>
std::vector<std::pair<id_t, IdentifierType>> &operator<<(
    std::vector<std::pair<id_t, IdentifierType>> &me,
    const std::vector<std::tuple<id, IdentifierType, LikelinessType>> &Values) {
  for (const auto &Value : Values)
    me.emplace_back(std::get<0>(Value), std::get<1>(Value));
  return me;
}

/// This is the Combinator class for cross Reliabilities. It has many functions
/// with different purposes
/// \brief It takes the Identifiers and Reliabilities of all given ids and
/// calculates the Likeliness of them together. Also it can creates the
/// feedback that is needed by the \c ReliabilityConfidenceCombination, which
/// is a kind of confidence.
///
/// \tparam IdentifierType Data type of the Identifier ( Typically double
/// or float) \tparam LikelinessType Data type of the Likeliness ( Typically
/// long	or int) // this might be swapped
///
/// \note This class is commonly in a master slave relationship as master with
/// \c ReliabilityConfidenceCombination. The \c operator()() combines the
/// Likeliness of all connected Slaves and uses that as its own Likeliness
/// also creates the feedback for the Slaves.
///
/// \note more information about how the Likeliness and feedback is
/// created at \c operator()() , \c getCombinedCrossLikeliness() , \c
/// getCombinedInputLikeliness() , \c getOutputLikeliness() [ this is the
///  used Likeliness ], \c getCrossLikeliness() [ this is the feedback
/// for all Compares ]
///
/// a bit more special Methods \c CrossConfidence() ,\c CrossLikeliness()
template <typename IdentifierType, typename LikelinessType>
class CrossCombinator {
public:
  static_assert(std::is_arithmetic<IdentifierType>::value,
                "HighLevel: IdentifierType has to be an arithmetic type\n");
  static_assert(std::is_arithmetic<LikelinessType>::value,
                "HighLevel: LikelinessType has to be an arithmetic type\n");

  // ---------------------------------------------------------------------------
  //			useful definitions
  // ---------------------------------------------------------------------------
  /// typedef To shorten the writing.
  /// \c Symbol
  using Symbol = Symbol<IdentifierType, LikelinessType>;

  /// To shorten the writing.
  using Abstraction =
      typename rosa::agent::Abstraction<IdentifierType, LikelinessType>;

  /// The return type for the \c operator()() Method
  struct returnType {
    LikelinessType CrossLikeliness;
    std::map<id_t, std::vector<Symbol>> Likeliness; 
  };

  // -------------------------------------------------------------------------
  //			Relevant Methods
  // -------------------------------------------------------------------------
  /// Calculates the CrossLikeliness and the Likeliness for each id for all
  /// of there Identifiers.
  ///
  /// \param Values It gets the Identifiers and Reliabilities of
  /// all connected Compare Agentss inside a vector.
  ///
  /// \return it returns a struct \c returnType containing the \c
  /// getCombinedCrossLikeliness() and \c getCrossLikeliness()
  returnType operator()(
      std::vector<std::tuple<id_t, IdentifierType, LikelinessType>> Values) {
    return {getOutputLikeliness(Values), getCrossLikeliness(Values)};
  }

  /// returns the combined Cross Likeliness via \c
  /// LikelinessCombinationMethod \c
  /// setLikelinessCombinationMethod() for all ids \c
  /// CrossLikeliness() \param Values the used Values
  LikelinessType getCombinedCrossLikeliness(
      const std::vector<std::tuple<id_t, IdentifierType, LikelinessType>>
          &Values) noexcept {

    LikelinessType Likeliness = -1;

    std::vector<std::pair<id_t, IdentifierType>> Agents;

    Agents << Values;

    for (auto Value : Values) {
      id_t id = std::get<0>(Value);
      IdentifierType sc = std::get<1>(Value);

      // calculate the cross Likeliness for this Compare agent
      LikelinessType realCrossLikelinessOfCompareInput =
          CrossLikeliness({id, sc}, Agents);

      if (Likeliness != -1)
        Likeliness = LikelinessCombinationMethod(
            Likeliness, realCrossLikelinessOfCompareInput);
      else
        Likeliness = realCrossLikelinessOfCompareInput;
    }
    return Likeliness;
  }

  /// returns the combined via \c CombinedInputLikelinessCombinationMethod \c
  /// setCombinedInputLikelinessCombinationMethod()  input Likeliness \param Values
  /// the used Values
  LikelinessType getCombinedInputLikeliness(
      const std::vector<std::tuple<id_t, IdentifierType, LikelinessType>>
          &Values) noexcept {
    LikelinessType combinedInputRel = -1;

    std::vector<std::pair<id_t, IdentifierType>> Agents;

    Agents << Values;

    for (auto Value : Values) {
      LikelinessType rel = std::get<2>(Value);

      if (combinedInputRel != -1)
        combinedInputRel =
            CombinedInputLikelinessCombinationMethod(combinedInputRel, rel);
      else
        combinedInputRel = rel;
    }
    return combinedInputRel;
  }

  /// returns the combination via \c OutputLikelinessCombinationMethod \c
  /// setOutputLikelinessCombinationMethod() of the Cross Likeliness and
  /// input Likeliness \param Values the used Values
  LikelinessType getOutputLikeliness(
      const std::vector<std::tuple<id_t, IdentifierType, LikelinessType>>
          &Values) noexcept {
    return OutputLikelinessCombinationMethod(
        getCombinedInputLikeliness(Values),
        getCombinedCrossLikeliness(Values));
  }

  /// returns the crossConfidence for all ids \c CrossConfidence()
  /// \param Values the used Values
  std::map<id_t, std::vector<Symbol>> getCrossLikeliness(
      const std::vector<std::tuple<id_t, IdentifierType, LikelinessType>>
          &Values) noexcept {

    std::vector<std::pair<id_t, IdentifierType>> Agents;
    std::map<id_t, std::vector<Symbol>> output;
    std::vector<Symbol> output_temporary;

    Agents << Values;

    for (auto Value : Values) {
      id_t id = std::get<0>(Value);

      output_temporary.clear();
      for (IdentifierType thoIdentifier : Identifiers[id]) {
        Symbol data;
        data.Identifier = thoIdentifier;
        data.Likeliness = CrossConfidence(id, thoIdentifier, Agents);
        output_temporary.push_back(data);
      }

      output.insert({id, output_temporary});
    }

    return output;
  }

  /// Calculates the Cross Confidence
  /// \brief it uses the Identifier value and calculates
  /// the Confidence of a given agent( represented by their id ) for a given
  /// Identifiers in connection to all other given agents
  ///
  /// \note all combination of agents and there corresponding Cross Likeliness
  /// function have to be specified
  LikelinessType
  CrossConfidence(const id_t &MainAgent, const IdentifierType &TheoreticalValue,
                  const std::vector<std::pair<id_t, IdentifierType>>
                      &CompareInputs) noexcept {

    LikelinessType crossReliabiability;

    std::vector<LikelinessType> values;

    for (std::pair<id_t, IdentifierType> CompareInput : CompareInputs) {

      if (CompareInput.first == MainAgent)
        continue;

      if (TheoreticalValue == CompareInput.second)
        crossReliabiability = 1;
      else
        crossReliabiability =
            1 / (crossLikelinessParameter *
                 std::abs(TheoreticalValue - CompareInput.second));

      // profile Likeliness
      LikelinessType crossLikelinessFromProfile =
          getCrossLikelinessFromProfile(
              MainAgent, CompareInput.first,
              std::abs(TheoreticalValue - CompareInput.second));
      values.push_back(
          std::max(crossReliabiability, crossLikelinessFromProfile));
    }
    return Method(values);
  }

  /// Calculates the Cross Likeliness
  /// \brief it uses the Identifier value and calculates
  /// the Likeliness of a given agent( represented by their id ) in connection
  /// to all other given agents
  ///
  /// \note all combination of agents and there corresponding Cross Likeliness
  /// function have to be specified
  LikelinessType
  CrossLikeliness(const std::pair<id_t, IdentifierType> &MainAgent,
                   const std::vector<std::pair<id_t, IdentifierType>>
                       &CompareInputs) noexcept {

    LikelinessType crossReliabiability;
    std::vector<LikelinessType> values;

    for (std::pair<id_t, IdentifierType> CompareInput : CompareInputs) {

      if (CompareInput.first == MainAgent.first)
        continue;

      if (MainAgent.second == CompareInput.second)
        crossReliabiability = 1;
      else
        crossReliabiability =
            1 / (crossLikelinessParameter *
                 std::abs(MainAgent.second - CompareInput.second));

      // profile Likeliness
      LikelinessType crossLikelinessFromProfile = getCrossLikelinessFromProfile(
          MainAgent.first, CompareInput.first,
          (IdentifierType)std::abs(MainAgent.second - CompareInput.second));
      values.push_back(
          std::max(crossReliabiability, crossLikelinessFromProfile));
    }
    return Method(values);
  }

  // --------------------------------------------------------------------------
  //			Defining the class
  // --------------------------------------------------------------------------

  /// adds a Cross Likeliness Profile used to get the Likeliness of the
  /// Identifier difference
  ///
  /// \param idA The id of the one \c Agent ( ideally the id of \c Unit to make
  /// it absolutely unique )
  ///
  /// \param idB The id of the other \c Agent
  ///
  /// \param Function A shared pointer to an \c Abstraction it would use the
  /// difference in Identifier for its input
  void addCrossLikelinessProfile(  //conf 
      const id_t &idA, const id_t &idB,
      const std::shared_ptr<Abstraction> &Function) noexcept {
    Functions.push_back({true, idA, idB, Function});  //confidence Profiles
  }

  /// sets the cross Likeliness parameter
  void setCrossLikelinessParameter(const LikelinessType &val) noexcept {
    crossLikelinessParameter = val;
  }

  /// This is the adder for the Identifiers
  /// \param id The id of the Agent of the Identifiers
  /// \param _Identifiers id specific Identifiers. This will be copied So that if
  /// Compares have different Identifiers they can be used correctly. \brief The
  /// Identifiers of all connected Compare Agents has to be known to be able to
  /// iterate over them
  void
  addIdentifiers(const id_t &id,		//add IdentifierIdentifiers
                 const std::vector<IdentifierType> &_Identifiers) noexcept {
    Identifiers.insert({id, _Identifiers});
  }

  // -------------------------------------------------------------------------
  //			Combinator Settings
  // -------------------------------------------------------------------------

  /// sets the used method to combine the values
  /// \param Meth the method which should be used. predefined functions in the
  /// struct \c predefinedMethods  \c
  /// CONJUNCTION() \c AVERAGE() \c DISJUNCTION()
  void setCrossLikelinessCombinatorMethod(
      const std::function<LikelinessType(std::vector<LikelinessType> values)>
          &Meth) noexcept {
    Method = Meth;
  }

  /// sets the combination method for the combined cross Likeliness
  /// \param Meth the method which should be used. predefined functions in the
  /// struct \c predefinedMethods LikelinessCombinationMethod<method>()
  void setLikelinessCombinationMethod(
      const std::function<LikelinessType(LikelinessType, LikelinessType)>
          &Meth) noexcept {
    LikelinessCombinationMethod = Meth;
  }

  /// sets the combined input rel method
  /// \param Meth the method which should be used. predefined functions in the
  /// struct \c predefinedMethods CombinedInputLikelinessCombinationMethod<method>()
  void setCombinedInputLikelinessCombinationMethod(
      const std::function<LikelinessType(LikelinessType, LikelinessType)>
          &Meth) noexcept {
    CombinedInputLikelinessCombinationMethod = Meth;
  }

  /// sets the used OutputLikelinessCombinationMethod
  /// \param Meth the method which should be used. predefined functions in the
  /// struct \c predefinedMethods OutputLikelinessCombinationMethod<method>()
  void setOutputLikelinessCombinationMethod(
      const std::function<LikelinessType(LikelinessType, LikelinessType)>
          &Meth) noexcept {
    OutputLikelinessCombinationMethod = Meth;
  }

  // -------------------------------------------------------------------------
  //			Predefined Functions
  // -------------------------------------------------------------------------
  /// This struct is a pseudo name space to have easier access to all predefined
  /// methods while still not overcrowding the class it self
  struct predefinedMethods {
    /// predefined combination method
    static LikelinessType CONJUNCTION(std::vector<LikelinessType> values) {
      return *std::min_element(values.begin(), values.end());
    }

    /// predefined combination method
    static LikelinessType AVERAGE(std::vector<LikelinessType> values) {
      return std::accumulate(values.begin(), values.end(), (LikelinessType)0) /
             values.size();
    }

    /// predefined combination method
    static LikelinessType DISJUNCTION(std::vector<LikelinessType> values) {
      return *std::max_element(values.begin(), values.end());
    }

    /// predefined combination Method
    static LikelinessType
    LikelinessCombinationMethodMin(LikelinessType A, LikelinessType B) {
      return std::min(A, B);
    }

    /// predefined combination Method
    static LikelinessType
    LikelinessCombinationMethodMax(LikelinessType A, LikelinessType B) {
      return std::max(A, B);
    }

    /// predefined combination Method
    static LikelinessType
    LikelinessCombinationMethodMult(LikelinessType A,
                                          LikelinessType B) {
      return A * B;
    }

    /// predefined combination Method
    static LikelinessType
    LikelinessCombinationMethodAverage(LikelinessType A,
                                             LikelinessType B) {
      return (A + B) / 2;
    }

    /// predefined combination Method
    static LikelinessType
    CombinedInputLikelinessCombinationMethodMin(LikelinessType A, LikelinessType B) {
      return std::min(A, B);
    }

    /// predefined combination Method
    static LikelinessType
    CombinedInputLikelinessCombinationMethodMax(LikelinessType A, LikelinessType B) {
      return std::max(A, B);
    }

    /// predefined combination Method
    static LikelinessType
    CombinedInputLikelinessCombinationMethodMult(LikelinessType A,
                                          LikelinessType B) {
      return A * B;
    }

    /// predefined combination Method
    static LikelinessType
    CombinedInputLikelinessCombinationMethodAverage(LikelinessType A,
                                             LikelinessType B) {
      return (A + B) / 2;
    }

    /// predefined combination method
    static LikelinessType
    OutputLikelinessCombinationMethodMin(LikelinessType A,
                                          LikelinessType B) {
      return std::min(A, B);
    }

    /// predefined combination method
    static LikelinessType
    OutputLikelinessCombinationMethodMax(LikelinessType A,
                                          LikelinessType B) {
      return std::max(A, B);
    }

    /// predefined combination method
    static LikelinessType
    OutputLikelinessCombinationMethodMult(LikelinessType A,
                                           LikelinessType B) {
      return A * B;
    }

    /// predefined combination method
    static LikelinessType
    OutputLikelinessCombinationMethodAverage(LikelinessType A,
                                              LikelinessType B) {
      return (A + B) / 2;
    }
  };

  // -------------------------------------------------------------------------
  //				Cleanup
  // -------------------------------------------------------------------------

  ~CrossCombinator() { Functions.clear(); }

  // --------------------------------------------------------------------------
  //				Parameters
  // --------------------------------------------------------------------------
private:
  struct Functionblock {
    bool exists = false;
    id_t A;
    id_t B;
    std::shared_ptr<Abstraction> Funct;
  };

  std::map<id_t, std::vector<IdentifierType>> Identifiers;

  /// From Maxi in his code defined as 1 can be changed by set
  LikelinessType crossLikelinessParameter = 1;

  /// Stored Cross Likeliness Functions
  std::vector<Functionblock> Functions;

  /// Method which is used to combine the generated values
  std::function<LikelinessType(std::vector<LikelinessType>)> Method =
      predefinedMethods::AVERAGE;

  std::function<LikelinessType(LikelinessType, LikelinessType)>
      LikelinessCombinationMethod =
          predefinedMethods::LikelinessCombinationMethodMin;

  std::function<LikelinessType(LikelinessType, LikelinessType)>
      CombinedInputLikelinessCombinationMethod =
          predefinedMethods::CombinedInputLikelinessCombinationMethodMin;

  std::function<LikelinessType(LikelinessType, LikelinessType)>
      OutputLikelinessCombinationMethod =
          predefinedMethods::OutputLikelinessCombinationMethodMin;

  //--------------------------------------------------------------------------------
  // helper function

  /// very inefficient searchFunction
  Functionblock (*searchFunction)(std::vector<Functionblock> vect,
                                  const id_t nameA, const id_t nameB) =
      [](std::vector<Functionblock> vect, const id_t nameA,
         const id_t nameB) -> Functionblock {
    for (Functionblock tmp : vect) {
      if (tmp.A == nameA && tmp.B == nameB)
        return tmp;
      if (tmp.A == nameB && tmp.B == nameA)
        return tmp;
    }
    return Functionblock();
  };

  /// evaluates the corresponding LinearFunction with the Identifier difference
  /// \param nameA these two parameters are the unique Identifiers
  /// \param nameB these two parameters are the unique Identifiers
  /// for the LinerFunction
  ///
  /// \note it doesn't matter if they are swapped
  LikelinessType getCrossLikelinessFromProfile(
      const id_t &nameA, const id_t &nameB,
      const IdentifierType &IdentifierDifference) noexcept {
    Functionblock block = searchFunction(Functions, nameA, nameB);
    if (!block.exists) {
      LOG_ERROR(("CrossLikeliness: Block:" + std::to_string(nameA) + "," +
                 std::to_string(nameB) + "doesn't exist returning 0"));
      return 0;
    }
    return block.Funct->operator()(IdentifierDifference);
  }
};

} // End namespace agent
} // End namespace rosa

#endif // ROSA_AGENT_CROSSCOMBINATOR_H
