Aktualizr
C++ SOTA Client
All Classes Namespaces Files Functions Variables Enumerations Enumerator Pages
executor.h
1 #ifndef LT_EXECUTOR_H
2 #define LT_EXECUTOR_H
3 
#ifdef BUILD_OSTREE
#include <glib.h>
#endif
#include <atomic>
#include <boost/thread/latch.hpp>
#include <chrono>
#include <csignal>
#include <functional>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
#include "logging/logging.h"
#include "stats.h"
17 
18 namespace timer = std::chrono;
19 
21  public:
22  virtual ~ExecutionController() = default;
23  virtual void stop() = 0;
24 
25  virtual bool claim() = 0;
26 };
27 
29  private:
30  std::atomic_bool stopped;
31 
32  public:
33  UnboundedExecutionController() : stopped{false} {}
34 
35  void stop() override { stopped = true; }
36 
37  bool claim() override { return !stopped; }
38 };
39 
41  private:
42  std::atomic_uint iterations;
43 
44  public:
45  FixedExecutionController(const unsigned int i) : iterations{i} {}
46 
47  void stop() override { iterations.store(0); }
48 
49  bool claim() override {
50  while (true) {
51  auto i = iterations.load();
52  if (i == 0) {
53  return false;
54  } else if (iterations.compare_exchange_strong(i, i - 1)) {
55  return true;
56  }
57  }
58  }
59 };
60 
62  private:
63  std::atomic_bool stopped;
64  static std::atomic_bool interrupted;
65  static void handleSignal(int) {
66  LOG_INFO << "SIGINT received";
67  interrupted = true;
68  }
69 
70  public:
71  InterruptableExecutionController() : stopped{false} {
72  std::signal(SIGINT, InterruptableExecutionController::handleSignal);
73  };
74 
75  bool claim() override { return !(interrupted || stopped); }
76 
77  void stop() override { stopped = true; }
78 };
79 
80 typedef timer::steady_clock::time_point TimePoint;
82  TimePoint startTime;
83  const timer::duration<int, std::milli> taskInterval;
84  std::atomic_ulong taskIndex;
85 
86  public:
87  TaskStartTimeCalculator(const unsigned rate) : startTime{}, taskInterval{std::milli::den / rate}, taskIndex{0} {}
88 
89  void start() { startTime = timer::steady_clock::now(); }
90 
91  TimePoint operator()() {
92  auto i = ++taskIndex;
93  return startTime + taskInterval * i;
94  }
95 };
96 
97 template <typename TaskStream>
98 class Executor {
99  std::unique_ptr<ExecutionController> controller;
100  std::vector<std::thread> workers;
101  std::vector<Statistics> statistics;
102  TaskStartTimeCalculator calculateTaskStartTime;
103  boost::latch threadCountDown;
104  boost::latch starter;
105  const std::string label;
106 
107  void runWorker(TaskStream &tasks, Statistics &stats) {
108 #ifdef BUILD_OSTREE
109  GMainContext *thread_context = g_main_context_new();
110  g_main_context_push_thread_default(thread_context);
111 #endif
112  using clock = std::chrono::steady_clock;
113  LOG_DEBUG << label << ": Worker created: " << std::this_thread::get_id();
114  threadCountDown.count_down();
115  starter.wait();
116  while (controller->claim()) {
117  auto task = tasks.nextTask();
118  const auto intendedStartTime = calculateTaskStartTime();
119  if (timer::steady_clock::now() < intendedStartTime) {
120  std::this_thread::sleep_until(intendedStartTime);
121  }
122  const clock::time_point start = clock::now();
123  task();
124  const clock::time_point end = clock::now();
125  std::chrono::milliseconds executionTime = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
126  stats.recordSuccess(executionTime);
127  }
128  LOG_DEBUG << label << ": Worker finished execution: " << std::this_thread::get_id();
129 #ifdef BUILD_OSTREE
130  g_main_context_pop_thread_default(thread_context);
131  g_main_context_unref(thread_context);
132 #endif
133  }
134 
135  public:
136  Executor(std::vector<TaskStream> &feeds, const unsigned rate, std::unique_ptr<ExecutionController> ctrl,
137  const std::string lbl)
138  : controller{std::move(ctrl)},
139  workers{},
140  statistics(feeds.size()),
141  calculateTaskStartTime{rate},
142  threadCountDown{feeds.size()},
143  starter{1},
144  label{lbl} {
145  workers.reserve(feeds.size());
146  try {
147  for (size_t i = 0; i < feeds.size(); i++) {
148  workers.push_back(std::thread(&Executor::runWorker, this, std::ref(feeds[i]), std::ref(statistics[i])));
149  }
150  } catch (...) {
151  controller->stop();
152  throw;
153  }
154  };
155 
156  Statistics run() {
157  Statistics summary{};
158  // wait till all threads are crerated and ready to go
159  LOG_INFO << label << ": Waiting for threads to start";
160  threadCountDown.wait();
161  calculateTaskStartTime.start();
162  summary.start();
163  LOG_INFO << label << ": Starting tests";
164  // start execution
165  starter.count_down();
166  // wait till all threads finished execution
167  for (size_t i = 0; i < workers.size(); i++) {
168  if (workers[i].joinable()) {
169  workers[i].join();
170  }
171  }
172 
173  summary.stop();
174  for (size_t i = 0; i < statistics.size(); i++) {
175  summary += statistics[i];
176  }
177  std::cout << "Results for: " << label << std::endl;
178  summary.print();
179  return summary;
180  };
181 };
182 
183 #endif