LCOV - code coverage report
Current view: top level - dtwc - Problem.hpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 5 5 100.0 %
Date: 2024-09-07 20:53:22 Functions: 2 2 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file Problem.hpp
       3             :  * @brief Encapsulates the DTWC (Dynamic Time Warping Clustering) problem in a class.
       4             :  *
       5             :  * @details This file contains the definition of the Problem class used in DTWC applications.
       6             :  * It includes various methods for manipulating and analyzing clusters.
       7             :  *
       8             :  * @date 19 Oct 2022
       9             :  * @author Volkan Kumtepeli
      10             :  * @author Becky Perriment
      11             :  */
      12             : 
      13             : #pragma once
      14             : 
      15             : #include "Data.hpp"           // for Data
      16             : #include "DataLoader.hpp"     // for DataLoader
      17             : #include "fileOperations.hpp" // for writeMatrix, readMatrix
      18             : #include "settings.hpp"       // for data_t, resultsPath
      19             : #include "enums/enums.hpp"    // for using Enum types.
      20             : #include "initialisation.hpp" // for init functions
      21             : 
      22             : #include <cstddef>     // for size_t
      23             : #include <filesystem>  // for operator/, path
      24             : #include <ostream>     // for operator<<, basic_ostream, ofstream
      25             : #include <string>      // for char_traits, operator+, operator<<
      26             : #include <string_view> // for string_view
      27             : #include <utility>     // for pair
      28             : #include <vector>      // for vector, allocator
      29             : #include <type_traits> // std::decay_t
      30             : #include <functional>  // std::function
      31             : #include <iostream>
      32             : 
      33             : #include <armadillo>
      34             : 
      35             : namespace dtwc {
      36             : 
      37             : /**
      38             :  * @class Problem
      39             :  * @brief Class representing a problem in DTWC.
      40             :  *
      41             :  * @details This class encapsulates all the functionalities and data structures required to solve
      42             :  * a dynamic time warping clustering problem. It includes methods for initialising clusters,
      43             :  * calculating distances, clustering, and writing results.
      44             :  */
      45             : class Problem
      46             : {
      47             : public:
      48             :   using distMat_t = arma::Mat<double>;
      49             :   using path_t = std::decay_t<decltype(settings::resultsPath)>;
      50             : 
      51             : private:
      52             :   int Nc{ 1 };                                      /*!< Number of clusters. */
      53             :   distMat_t distMat;                                /*!< Distance matrix. */
      54             :   Solver mipSolver{ settings::DEFAULT_MIP_SOLVER }; /*!< Solver for MIP. */
      55             : 
      56             :   bool is_distMat_filled{ false }; /*!< Flag indicating if the distance matrix is filled. */
      57             : 
      58             :   // Private functions:
      59             :   std::pair<int, double> cluster_by_kMedoidsPAM_single(int rep);
      60             : 
      61             :   void writeBestRep(int best_rep);
      62             :   void writeMedoids(std::vector<std::vector<int>> &centroids_all, int rep, double total_cost);
      63             :   void distanceInClusters();
      64             : 
      65             : public:
      66             :   Method method{ Method::Kmedoids };         /*!< Clustering method. */
      67             :   int maxIter{ 100 };                        /*!< Maximum number of iterations for iterative methods. */
      68             :   int N_repetition{ 1 };                     /*!< Repetition for iterative-methods. */
      69             :   int band{ settings::DEFAULT_BAND_LENGTH }; /*!< Band length for Sakoe-Chiba band, -1 for full DTW. */
      70             : 
      71             :   std::function<void(Problem &)> init_fun{ init::random }; /*!< Initialisation function. */
      72             : 
      73             :   path_t output_folder{ settings::resultsPath }; /*!< Output folder for results. */
      74             :   std::string name{};                            /*!< Problem name. */
      75             :   Data data;                                     /*!< Data associated with the problem. */
      76             : 
      77             :   std::vector<int> clusters_ind;  //!< Cluster index assigned to each point, each in [0, Nc).
      78             :   std::vector<int> centroids_ind; //!< Indices of the cluster centroids, each in [0, Np).
      79             : 
      80             :   // Constructors:
      81             :   Problem() = default;
      82             :   Problem(std::string_view name_) : name{ name_ } {}
      83           1 :   Problem(std::string_view name_, DataLoader &loader_)
      84           1 :     : name{ name_ }, data{ loader_.load() }
      85             :   {
      86           1 :     refreshDistanceMatrix();
      87           1 :   }
      88             : 
      89             :   auto size() const { return data.size(); }
      90           2 :   auto cluster_size() const { return Nc; }
      91             :   auto &get_name(size_t i) { return data.p_names[i]; }
      92             :   auto const &get_name(size_t i) const { return data.p_names[i]; }
      93             : 
      94             :   auto &p_vec(size_t i) { return data.p_vec[i]; }
      95             :   auto const &p_vec(size_t i) const { return data.p_vec[i]; }
      96             : 
      97             :   void refreshDistanceMatrix();
      98             :   void resize();
      99             : 
     100             :   // Getters and setters:
     101             :   int centroid_of(int i_p) const { return centroids_ind[clusters_ind[i_p]]; } // [0, Np) Get the centroid of the cluster of i_p
     102             : 
     103             :   void readDistanceMatrix(const fs::path &distMat_path);
     104             :   void set_numberOfClusters(int Nc_);
     105             :   void set_clusters(std::vector<int> &candidate_centroids);
     106             :   bool set_solver(dtwc::Solver solver_);
     107             : 
     108             :   void set_data(dtwc::Data data_)
     109             :   {
     110             :     data = data_;
     111             :     refreshDistanceMatrix();
     112             :   }
     113             : 
     114             :   data_t maxDistance() const { return distMat.max(); }
     115             :   data_t distByInd(int i, int j);
     116             :   bool isDistanceMatrixFilled() const { return is_distMat_filled; }
     117             : 
     118             :   void fillDistanceMatrix();
     119             :   void printDistanceMatrix() const;
     120             : 
     121             :   void writeDistanceMatrix(const std::string &name_) const;
     122             :   void writeDistanceMatrix() const { writeDistanceMatrix(name + "_distanceMatrix.csv"); }
     123             : 
     124             :   void printClusters() const;
     125             :   void writeClusters();
     126             : 
     127             :   void writeMedoidMembers(int iter, int rep = 0) const;
     128             :   void writeSilhouettes();
     129             : 
     130             :   // Initialisation of clusters:
     131             :   void init() { init_fun(*this); }
     132             : 
     133             :   // Clustering functions:
     134             :   void cluster();
     135             :   void cluster_by_MIP();
     136             :   void cluster_by_kMedoidsPAM();
     137             : 
     138             :   void cluster_and_process();
     139             : 
     140             :   // Auxiliary
     141             :   double findTotalCost();
     142             :   void assignClusters();
     143             : 
     144             :   void calculateMedoids();
     145             : };
     146             : 
     147             : 
     148             : } // namespace dtwc

Generated by: LCOV version 1.14