1 #ifndef CGAL_SURFACE_MESH_SEGMENTATION_EXPECTATION_MAXIMIZATION_H
2 #define CGAL_SURFACE_MESH_SEGMENTATION_EXPECTATION_MAXIMIZATION_H
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <limits>
#include <vector>

#include <CGAL/internal/Surface_mesh_segmentation/K_means_clustering.h>
#include <CGAL/assertions.h>
// Default tuning values used by the Expectation_maximization constructor.
// All four are #undef'd again at the end of this header.

// Default `number_of_runs`: how many EM runs are started from different initial centers.
#define CGAL_DEFAULT_NUMBER_OF_RUN 15
// Default `maximum_iteration`: upper bound on EM iterations within one run.
#define CGAL_DEFAULT_MAXIMUM_ITERATION 10
// Default `threshold`: a run stops once its log-likelihood gain drops below
// threshold * |log-likelihood|.
#define CGAL_DEFAULT_THRESHOLD 1e-3
// Fixed seed handed to srand() by the constructor so results are reproducible.
#define CGAL_DEFAULT_SEED 1340818006
26 class Expectation_maximization
38 double mixing_coefficient;
40 Gaussian_center(): mean(0), deviation(0), mixing_coefficient(1.0)
42 Gaussian_center(
double mean,
double deviation,
double mixing_coefficient)
43 : mean(mean), deviation(deviation), mixing_coefficient(mixing_coefficient)
51 double probability(
double x)
const
53 double e_over = -0.5 * std::pow((x - mean) / deviation, 2);
54 return exp(e_over) / deviation;
61 double probability_with_coef(
double x)
const
63 return probability(x) * mixing_coefficient;
66 bool operator < (
const Gaussian_center& center)
const
68 return mean < center.mean;
73 enum Initialization_types
75 RANDOM_INITIALIZATION,
77 K_MEANS_INITIALIZATION
80 double final_likelihood;
82 std::vector<Gaussian_center> centers;
83 std::vector<double> points;
84 std::vector<std::vector<double> > responsibility_matrix;
87 int maximum_iteration;
89 Initialization_types init_type;
105 Expectation_maximization(
int number_of_centers,
106 const std::vector<double>& data,
107 Initialization_types init_type = PLUS_INITIALIZATION,
108 int number_of_runs = CGAL_DEFAULT_NUMBER_OF_RUN,
109 double threshold = CGAL_DEFAULT_THRESHOLD,
110 int maximum_iteration = CGAL_DEFAULT_MAXIMUM_ITERATION )
112 final_likelihood(-(std::numeric_limits<double>::max)()), points(data),
113 responsibility_matrix(std::vector<std::vector<double> >(number_of_centers, std::vector<double>(points.size()))),
114 threshold(threshold), maximum_iteration(maximum_iteration), init_type(init_type)
117 if(init_type == K_MEANS_INITIALIZATION)
119 K_means_clustering k_means(number_of_centers, data, K_means_clustering::PLUS_INITIALIZATION,
120 number_of_runs, maximum_iteration);
121 std::vector<int> initial_center_ids;
122 k_means.fill_with_center_ids(initial_center_ids);
124 initiate_centers_from_memberships(number_of_centers, initial_center_ids);
125 calculate_clustering();
130 srand(CGAL_DEFAULT_SEED);
131 calculate_clustering_with_multiple_run(number_of_centers, number_of_runs);
133 sort(centers.begin(), centers.end());
140 void fill_with_center_ids(std::vector<int>& data_centers)
142 data_centers.reserve(points.size());
143 for(std::vector<double>::iterator point_it = points.begin();
144 point_it != points.end(); ++point_it)
146 double max_likelihood = 0.0;
147 int max_center = -1, center_counter = 0;
148 for(std::vector<Gaussian_center>::iterator center_it = centers.begin();
149 center_it != centers.end(); ++center_it, ++center_counter)
151 double likelihood = center_it->probability_with_coef(*point_it);
152 if(max_likelihood < likelihood)
154 max_likelihood = likelihood;
155 max_center = center_counter;
158 data_centers.push_back(max_center);
166 void fill_with_probabilities(std::vector<std::vector<double> >& probabilities)
168 probabilities = std::vector<std::vector<double> >
169 (centers.size(), std::vector<double>(points.size()));
170 for(std::size_t point_i = 0; point_i < points.size(); ++point_i)
172 double total_probability = 0.0;
173 for(std::size_t center_i = 0; center_i < centers.size(); ++center_i)
175 double probability = centers[center_i].probability_with_coef(points[point_i]);
176 total_probability += probability;
177 probabilities[center_i][point_i] = probability;
179 for(std::size_t center_i = 0; center_i < centers.size(); ++center_i)
181 probabilities[center_i][point_i] /= total_probability;
191 void calculate_initial_deviations()
193 std::vector<int> member_count(centers.size(), 0);
194 for(std::vector<double>::iterator it = points.begin(); it!= points.end(); ++it)
196 int closest_center = 0;
197 double min_distance = std::abs(centers[0].mean - *it);
198 for(std::size_t i = 1; i < centers.size(); ++i)
200 double distance = std::abs(centers[i].mean - *it);
201 if(distance < min_distance)
203 min_distance = distance;
207 member_count[closest_center]++;
208 centers[closest_center].deviation += min_distance * min_distance;
210 for(std::size_t i = 0; i < centers.size(); ++i)
213 CGAL_assertion(member_count[i] != 0);
214 centers[i].deviation = std::sqrt(centers[i].deviation / member_count[i]);
222 void initiate_centers_randomly(
int number_of_centers)
225 double initial_mixing_coefficient = 1.0 / number_of_centers;
226 double initial_deviation = 0.0;
227 for(
int i = 0; i < number_of_centers; ++i)
229 int random_index = rand() % points.size();
230 double initial_mean = points[random_index];
232 if(!make_center(initial_mean, initial_deviation, initial_mixing_coefficient))
235 calculate_initial_deviations();
243 void initiate_centers_plus_plus(
int number_of_centers)
246 double initial_deviation = 0.0;
247 double initial_mixing_coefficient = 1.0 / number_of_centers;
249 std::vector<double> distance_square_cumulative(points.size());
250 std::vector<double> distance_square(points.size(), (std::numeric_limits<double>::max)());
254 double initial_mean = points[rand() % points.size()];
255 make_center(initial_mean, initial_deviation, initial_mixing_coefficient);
257 for(
int i = 1; i < number_of_centers; ++i)
259 double cumulative_distance_square = 0.0;
260 for(std::size_t j = 0; j < points.size(); ++j)
262 double new_distance = std::pow(centers.back().mean - points[j], 2);
263 if(new_distance < distance_square[j]) { distance_square[j] = new_distance; }
264 cumulative_distance_square += distance_square[j];
265 distance_square_cumulative[j] = cumulative_distance_square;
268 double zero_one = rand() / (RAND_MAX + 1.0);
269 double random_ds = zero_one * (distance_square_cumulative.back());
270 int selection_index = std::upper_bound(distance_square_cumulative.begin(), distance_square_cumulative.end(), random_ds)
271 - distance_square_cumulative.begin();
272 double initial_mean = points[selection_index];
274 if(!make_center(initial_mean, initial_deviation, initial_mixing_coefficient))
277 calculate_initial_deviations();
285 void initiate_centers_from_memberships(
int number_of_centers,
const std::vector<int>& initial_center_ids)
288 int number_of_points = initial_center_ids.size();
289 centers = std::vector<Gaussian_center>(number_of_centers);
290 std::vector<int> member_count(number_of_centers, 0);
292 for(
int i = 0; i < number_of_points; ++i)
294 int center_id = initial_center_ids[i];
295 centers[center_id].mean += points[i];
296 member_count[center_id] += 1;
299 for(
int i = 0; i < number_of_centers; ++i)
301 centers[i].mean /= member_count[i];
302 centers[i].mixing_coefficient = member_count[i] /
static_cast<double>(number_of_points);
305 for(
int i = 0; i < number_of_points; ++i)
307 int center_id = initial_center_ids[i];
308 centers[center_id].deviation += std::pow(points[i] - centers[center_id].mean, 2);
310 for(
int i = 0; i < number_of_centers; ++i)
312 CGAL_assertion(member_count[i] != 0);
313 centers[i].deviation = std::sqrt(centers[i].deviation / member_count[i]);
322 bool is_already_center(
const Gaussian_center& center)
const
324 for(std::vector<Gaussian_center>::const_iterator it = centers.begin(); it != centers.end(); ++it)
326 if(it->mean == center.mean) {
return true; }
338 bool make_center(
double mean,
double deviation,
double mixing_coefficient)
340 Gaussian_center new_center(mean, deviation, mixing_coefficient);
341 if(is_already_center(new_center)) {
return false; }
342 centers.push_back(new_center);
351 void calculate_parameters()
353 for(std::size_t center_i = 0; center_i < centers.size(); ++center_i)
356 double new_mean = 0.0, total_membership = 0.0;
357 for(std::size_t point_i = 0; point_i < points.size(); ++point_i)
359 double membership = responsibility_matrix[center_i][point_i];
360 new_mean += membership * points[point_i];
361 total_membership += membership;
363 new_mean /= total_membership;
366 double new_deviation = 0.0;
367 for(std::size_t point_i = 0; point_i < points.size(); ++point_i)
369 double membership = responsibility_matrix[center_i][point_i];
370 new_deviation += membership * std::pow(points[point_i] - new_mean, 2);
372 new_deviation = std::sqrt(new_deviation/total_membership);
375 centers[center_i].mixing_coefficient = total_membership / points.size();
376 centers[center_i].deviation = new_deviation;
377 centers[center_i].mean = new_mean;
386 double calculate_likelihood()
391 double likelihood = 0.0;
392 for(std::size_t point_i = 0; point_i < points.size(); ++point_i)
394 double total_membership = 0.0;
395 for(std::size_t center_i = 0; center_i < centers.size(); ++center_i)
397 double membership = centers[center_i].probability_with_coef(points[point_i]);
398 total_membership += membership;
399 responsibility_matrix[center_i][point_i] = membership;
401 for(std::size_t center_i = 0; center_i < centers.size(); ++center_i)
403 responsibility_matrix[center_i][point_i] /= total_membership;
405 likelihood += log(total_membership);
416 double iterate(
bool first_iteration)
421 if(first_iteration) { calculate_likelihood(); }
424 calculate_parameters();
427 return calculate_likelihood();
437 double calculate_clustering()
439 double likelihood = -(std::numeric_limits<double>::max)(), prev_likelihood;
440 int iteration_count = 0;
441 double is_converged =
false;
442 while(!is_converged && iteration_count++ < maximum_iteration)
444 prev_likelihood = likelihood;
445 likelihood = iterate(iteration_count == 1);
446 double progress = likelihood - prev_likelihood;
447 is_converged = progress < threshold * std::abs(likelihood);
449 if(final_likelihood < likelihood) { final_likelihood = likelihood; }
460 void calculate_clustering_with_multiple_run(
int number_of_centers,
int number_of_run)
462 std::vector<Gaussian_center> max_centers;
464 while(number_of_run-- > 0)
466 init_type == RANDOM_INITIALIZATION ? initiate_centers_randomly(number_of_centers)
467 : initiate_centers_plus_plus(number_of_centers);
469 double likelihood = calculate_clustering();
470 if(likelihood == final_likelihood) { max_centers = centers; }
472 centers = max_centers;
// Drop the header-local default macros so they do not leak into includers.
#undef CGAL_DEFAULT_MAXIMUM_ITERATION
#undef CGAL_DEFAULT_NUMBER_OF_RUN
#undef CGAL_DEFAULT_THRESHOLD
#undef CGAL_DEFAULT_SEED
482 #endif //CGAL_SURFACE_MESH_SEGMENTATION_EXPECTATION_MAXIMIZATION_H