Aktualizr
C++ SOTA Client
All Classes Namespaces Files Functions Variables Enumerations Enumerator Pages
sqlstorage_base.cc
1 #include "sqlstorage_base.h"
2 #include "storage_exception.h"
3 
4 #include <sys/stat.h>
5 
6 boost::filesystem::path SQLStorageBase::dbPath() const { return sqldb_path_; }
7 
8 StorageLock::StorageLock(boost::filesystem::path path) : lock_path(std::move(path)) {
9  {
10  std::fstream fs;
11  fs.open(lock_path.c_str(), std::fstream::in | std::fstream::out | std::fstream::app);
12  }
13  fl_ = lock_path.c_str();
14  if (!fl_.try_lock()) {
16  }
17 }
18 
19 StorageLock::~StorageLock() {
20  try {
21  if (!lock_path.empty()) {
22  fl_.unlock();
23  std::remove(lock_path.c_str());
24  }
25  } catch (std::exception& e) {
26  }
27 }
28 
29 SQLStorageBase::SQLStorageBase(boost::filesystem::path sqldb_path, bool readonly,
30  std::vector<std::string> schema_migrations,
31  std::vector<std::string> schema_rollback_migrations, std::string current_schema,
32  int current_schema_version)
33  : sqldb_path_(std::move(sqldb_path)),
34  readonly_(readonly),
35  mutex_(new std::mutex()),
36  schema_migrations_(std::move(schema_migrations)),
37  schema_rollback_migrations_(std::move(schema_rollback_migrations)),
38  current_schema_(std::move(current_schema)),
39  current_schema_version_(current_schema_version) {
40  boost::filesystem::path db_parent_path = dbPath().parent_path();
41  if (!boost::filesystem::is_directory(db_parent_path)) {
42  Utils::createDirectories(db_parent_path, S_IRWXU);
43  } else {
44  struct stat st {};
45  if (stat(db_parent_path.c_str(), &st) < 0) {
46  throw StorageException(std::string("Could not check storage directory permissions: ") + std::strerror(errno));
47  }
48  if ((st.st_mode & (S_IWGRP | S_IWOTH)) != 0) {
49  throw StorageException("Storage directory has unsafe permissions");
50  }
51  if ((st.st_mode & (S_IRGRP | S_IROTH)) != 0) {
52  // Remove read permissions for group and others
53  if (chmod(db_parent_path.c_str(), S_IRWXU) < 0) {
54  throw StorageException("Storage directory has unsafe permissions");
55  }
56  }
57  }
58 
59  if (!readonly) {
60  try {
61  lock = StorageLock(db_parent_path / "storage.lock");
62  } catch (StorageLock::locked_exception& e) {
63  LOG_WARNING << "\033[31m"
64  << "Storage in " << db_parent_path
65  << " is already in use, running several instances concurrently may result in data corruption!"
66  << "\033[0m";
67  }
68  }
69 
70  if (!dbMigrate()) {
71  throw StorageException("SQLite database migration failed");
72  }
73 }
74 
75 SQLite3Guard SQLStorageBase::dbConnection() const {
76  SQLite3Guard db(dbPath(), readonly_, mutex_);
77  if (db.get_rc() != SQLITE_OK) {
78  throw SQLException(std::string("Can't open database: ") + db.errmsg());
79  }
80  return db;
81 }
82 
83 std::string SQLStorageBase::getTableSchemaFromDb(const std::string& tablename) {
84  SQLite3Guard db = dbConnection();
85 
86  auto statement = db.prepareStatement<std::string>(
87  "SELECT sql FROM sqlite_master WHERE type='table' AND tbl_name=? LIMIT 1;", tablename);
88 
89  if (statement.step() != SQLITE_ROW) {
90  LOG_ERROR << "Can't get schema of " << tablename << ": " << db.errmsg();
91  return "";
92  }
93 
94  auto schema = statement.get_result_col_str(0);
95  if (schema == boost::none) {
96  return "";
97  }
98 
99  return schema.value() + ";";
100 }
101 
102 bool SQLStorageBase::dbInsertBackMigrations(SQLite3Guard& db, int version_latest) {
103  if (schema_rollback_migrations_.empty()) {
104  LOG_TRACE << "No backward migrations defined";
105  return true;
106  }
107 
108  if (schema_rollback_migrations_.size() < static_cast<size_t>(version_latest) + 1) {
109  LOG_ERROR << "Backward migrations from " << schema_rollback_migrations_.size() << " to " << version_latest
110  << " are missing";
111  return false;
112  }
113 
114  for (int k = 1; k <= version_latest; k++) {
115  if (schema_rollback_migrations_.at(static_cast<size_t>(k)).empty()) {
116  continue;
117  }
118  auto statement = db.prepareStatement("INSERT OR REPLACE INTO rollback_migrations VALUES (?,?);", k,
119  schema_rollback_migrations_.at(static_cast<uint32_t>(k)));
120  if (statement.step() != SQLITE_DONE) {
121  LOG_ERROR << "Can't insert rollback migration script: " << db.errmsg();
122  return false;
123  }
124  }
125 
126  return true;
127 }
128 
129 bool SQLStorageBase::dbMigrateForward(int version_from, int version_to) {
130  if (version_to <= 0) {
131  version_to = current_schema_version_;
132  }
133 
134  LOG_INFO << "Migrating DB from version " << version_from << " to version " << version_to;
135 
136  SQLite3Guard db = dbConnection();
137 
138  try {
139  db.beginTransaction();
140  } catch (const SQLException& e) {
141  return false;
142  }
143 
144  for (int32_t k = version_from + 1; k <= version_to; k++) {
145  auto result_code = db.exec(schema_migrations_.at(static_cast<size_t>(k)), nullptr, nullptr);
146  if (result_code != SQLITE_OK) {
147  LOG_ERROR << "Can't migrate DB from version " << (k - 1) << " to version " << k << ": " << db.errmsg();
148  return false;
149  }
150  }
151 
152  if (!dbInsertBackMigrations(db, version_to)) {
153  return false;
154  }
155 
156  db.commitTransaction();
157 
158  return true;
159 }
160 
161 bool SQLStorageBase::dbMigrateBackward(int version_from, int version_to) {
162  if (version_to <= 0) {
163  version_to = current_schema_version_;
164  }
165 
166  LOG_INFO << "Migrating DB backward from version " << version_from << " to version " << version_to;
167 
168  SQLite3Guard db = dbConnection();
169  for (int ver = version_from; ver > version_to; --ver) {
170  std::string migration;
171  {
172  // make sure the statement is destroyed before the next database operation
173  auto statement = db.prepareStatement("SELECT migration FROM rollback_migrations WHERE version_from=?;", ver);
174  if (statement.step() != SQLITE_ROW) {
175  LOG_ERROR << "Can't extract migration script: " << db.errmsg();
176  return false;
177  }
178  migration = *(statement.get_result_col_str(0));
179  }
180 
181  if (db.exec(migration, nullptr, nullptr) != SQLITE_OK) {
182  LOG_ERROR << "Can't migrate db from version " << (ver) << " to version " << ver - 1 << ": " << db.errmsg();
183  return false;
184  }
185  if (db.exec(std::string("DELETE FROM rollback_migrations WHERE version_from=") + std::to_string(ver) + ";", nullptr,
186  nullptr) != SQLITE_OK) {
187  LOG_ERROR << "Can't clear old migration script: " << db.errmsg();
188  return false;
189  }
190  }
191  return true;
192 }
193 
194 bool SQLStorageBase::dbMigrate() {
195  DbVersion schema_version = getVersion();
196 
197  if (schema_version == DbVersion::kInvalid) {
198  LOG_ERROR << "Sqlite database file is invalid.";
199  return false;
200  }
201 
202  auto schema_num_version = static_cast<int32_t>(schema_version);
203  if (schema_num_version == current_schema_version_) {
204  return true;
205  }
206 
207  if (readonly_) {
208  LOG_ERROR << "Database is opened in readonly mode and cannot be migrated to latest version";
209  return false;
210  }
211 
212  if (schema_version == DbVersion::kEmpty) {
213  LOG_INFO << "Bootstraping DB to version " << current_schema_version_;
214  SQLite3Guard db = dbConnection();
215 
216  try {
217  db.beginTransaction();
218  } catch (const SQLException& e) {
219  return false;
220  }
221 
222  auto result_code = db.exec(current_schema_, nullptr, nullptr);
223  if (result_code != SQLITE_OK) {
224  LOG_ERROR << "Can't bootstrap DB to version " << current_schema_version_ << ": " << db.errmsg();
225  return false;
226  }
227 
228  if (!dbInsertBackMigrations(db, current_schema_version_)) {
229  return false;
230  }
231 
232  db.commitTransaction();
233  } else if (schema_num_version > current_schema_version_) {
234  dbMigrateBackward(schema_num_version);
235  } else {
236  dbMigrateForward(schema_num_version);
237  }
238 
239  return true;
240 }
241 
242 DbVersion SQLStorageBase::getVersion() {
243  SQLite3Guard db = dbConnection();
244 
245  try {
246  auto statement = db.prepareStatement("SELECT count(*) FROM sqlite_master WHERE type='table';");
247  if (statement.step() != SQLITE_ROW) {
248  LOG_ERROR << "Can't get tables count: " << db.errmsg();
249  return DbVersion::kInvalid;
250  }
251 
252  if (statement.get_result_col_int(0) == 0) {
253  return DbVersion::kEmpty;
254  }
255 
256  statement = db.prepareStatement("SELECT version FROM version LIMIT 1;");
257 
258  if (statement.step() != SQLITE_ROW) {
259  LOG_ERROR << "Can't get database version: " << db.errmsg();
260  return DbVersion::kInvalid;
261  }
262 
263  try {
264  return DbVersion(statement.get_result_col_int(0));
265  } catch (const std::exception&) {
266  return DbVersion::kInvalid;
267  }
268  } catch (const SQLException&) {
269  return DbVersion::kInvalid;
270  }
271 }
STL namespace.