Line data Source code
1 : /** 2 : * @file Problem.hpp 3 : * @brief Encapsulates the DTWC (Dynamic Time Warping Clustering) problem in a class. 4 : * 5 : * @details This file contains the definition of the Problem class used in DTWC applications. 6 : * It includes various methods for manipulating and analyzing clusters. 7 : * 8 : * @date 19 Oct 2022 9 : * @author Volkan Kumtepeli 10 : * @author Becky Perriment 11 : */ 12 : 13 : #pragma once 14 : 15 : #include "Data.hpp" // for Data 16 : #include "DataLoader.hpp" // for DataLoader 17 : #include "fileOperations.hpp" // for writeMatrix, readMatrix 18 : #include "settings.hpp" // for data_t, resultsPath 19 : #include "enums/enums.hpp" // for using Enum types. 20 : #include "initialisation.hpp" // for init functions 21 : 22 : #include <cstddef> // for size_t 23 : #include <filesystem> // for operator/, path 24 : #include <ostream> // for operator<<, basic_ostream, ofstream 25 : #include <string> // for char_traits, operator+, operator<< 26 : #include <string_view> // for string_view 27 : #include <utility> // for pair 28 : #include <vector> // for vector, allocator 29 : #include <type_traits> // std::decay_t 30 : #include <functional> // std::function 31 : #include <iostream> 32 : 33 : #include <armadillo> 34 : 35 : namespace dtwc { 36 : 37 : /** 38 : * @class Problem 39 : * @brief Class representing a problem in DTWC. 40 : * 41 : * @details This class encapsulates all the functionalities and data structures required to solve 42 : * a dynamic time warping clustering problem. It includes methods for initialising clusters, 43 : * calculating distances, clustering, and writing results. 44 : */ 45 : class Problem 46 : { 47 : public: 48 : using distMat_t = arma::Mat<double>; 49 : using path_t = std::decay_t<decltype(settings::resultsPath)>; 50 : 51 : private: 52 : int Nc{ 1 }; /*!< Number of clusters. */ 53 : distMat_t distMat; /*!< Distance matrix. */ 54 : Solver mipSolver{ settings::DEFAULT_MIP_SOLVER }; /*!< Solver for MIP. */ 55 : 56 : bool is_distMat_filled{ false }; /*!< Flag indicating if the distance matrix is filled. */ 57 : 58 : // Private functions: 59 : std::pair<int, double> cluster_by_kMedoidsPAM_single(int rep); 60 : 61 : void writeBestRep(int best_rep); 62 : void writeMedoids(std::vector<std::vector<int>> ¢roids_all, int rep, double total_cost); 63 : void distanceInClusters(); 64 : 65 : public: 66 : Method method{ Method::Kmedoids }; /*!< Clustering method. */ 67 : int maxIter{ 100 }; /*!< Maximum number of iteration for iterative-methods. */ 68 : int N_repetition{ 1 }; /*!< Repetition for iterative-methods. */ 69 : int band{ settings::DEFAULT_BAND_LENGTH }; /*!< Band length for Sakoe-Chiba band, -1 for full DTW. */ 70 : 71 : std::function<void(Problem &)> init_fun{ init::random }; /*!< Initialisation function. */ 72 : 73 : path_t output_folder{ settings::resultsPath }; /*!< Output folder for results. */ 74 : std::string name{}; /*!< Problem name. */ 75 : Data data; /*!< Data associated with the problem. */ 76 : 77 : std::vector<int> clusters_ind; //!< Indices of which point belongs to which cluster. [0,Nc) 78 : std::vector<int> centroids_ind; //!< indices of cluster centroids. [0, Np) 79 : 80 : // Constructors: 81 : Problem() = default; 82 : Problem(std::string_view name_) : name{ name_ } {} 83 1 : Problem(std::string_view name_, DataLoader &loader_) 84 1 : : name{ name_ }, data{ loader_.load() } 85 : { 86 1 : refreshDistanceMatrix(); 87 1 : } 88 : 89 : auto size() const { return data.size(); } 90 2 : auto cluster_size() const { return Nc; } 91 : auto &get_name(size_t i) { return data.p_names[i]; } 92 : auto const &get_name(size_t i) const { return data.p_names[i]; } 93 : 94 : auto &p_vec(size_t i) { return data.p_vec[i]; } 95 : auto const &p_vec(size_t i) const { return data.p_vec[i]; } 96 : 97 : void refreshDistanceMatrix(); 98 : void resize(); 99 : 100 : // Getters and setters: 101 : int centroid_of(int i_p) const { return centroids_ind[clusters_ind[i_p]]; } // [0, Np) Get the centroid of the cluster of i_p 102 : 103 : void readDistanceMatrix(const fs::path &distMat_path); 104 : void set_numberOfClusters(int Nc_); 105 : void set_clusters(std::vector<int> &candidate_centroids); 106 : bool set_solver(dtwc::Solver solver_); 107 : 108 : void set_data(dtwc::Data data_) 109 : { 110 : data = data_; 111 : refreshDistanceMatrix(); 112 : } 113 : 114 : data_t maxDistance() const { return distMat.max(); } 115 : data_t distByInd(int i, int j); 116 : bool isDistanceMatrixFilled() const { return is_distMat_filled; } 117 : 118 : void fillDistanceMatrix(); 119 : void printDistanceMatrix() const; 120 : 121 : void writeDistanceMatrix(const std::string &name_) const; 122 : void writeDistanceMatrix() const { writeDistanceMatrix(name + "_distanceMatrix.csv"); } 123 : 124 : void printClusters() const; 125 : void writeClusters(); 126 : 127 : void writeMedoidMembers(int iter, int rep = 0) const; 128 : void writeSilhouettes(); 129 : 130 : // Initialisation of clusters: 131 : void init() { init_fun(*this); } 132 : 133 : // Clustering functions: 134 : void cluster(); 135 : void cluster_by_MIP(); 136 : void cluster_by_kMedoidsPAM(); 137 : 138 : void cluster_and_process(); 139 : 140 : // Auxillary 141 : double findTotalCost(); 142 : void assignClusters(); 143 : 144 : void calculateMedoids(); 145 : }; 146 : 147 : 148 : } // namespace dtwc