SQLamarr
The stand-alone ultra-fast simulation option for the LHCb experiment
BasePlugin.cpp
1 // (c) Copyright 2022 CERN for the benefit of the LHCb Collaboration.
2 //
3 // This software is distributed under the terms of the GNU General Public
4 // Licence version 3 (GPL Version 3), copied verbatim in the file "LICENCE".
5 //
6 // In applying this licence, CERN does not waive the privileges and immunities
7 // granted to it by virtue of its status as an Intergovernmental Organization
8 // or submit itself to any jurisdiction.
9 
10 
11 // Standard
12 #include <iostream>
13 #include <cmath>
14 #include <sstream>
15 #include <algorithm>
16 #include <iterator>
17 
18 // SQLite3
19 #include "sqlite3.h"
20 
21 // SQLamarr
22 #include "SQLamarr/BasePlugin.h"
23 
24 namespace SQLamarr
25 {
26  //============================================================================
27  // Constructor
28  //============================================================================
30  SQLite3DB& db,
31  const std::string& library,
32  const std::string& function_name,
33  const std::string& select_query,
34  const std::string& output_table,
35  const std::vector<std::string> outputs,
36  const std::vector<std::string> reference_keys
37  )
38  : BaseSqlInterface(db)
39  , m_library (library)
40  , m_function_name (function_name)
41  , m_select_query (select_query)
42  , m_output_table (output_table)
43  , m_outputs (outputs)
44  , m_refkeys (reference_keys)
45  , m_handle (dlopen(library.c_str(), RTLD_LAZY))
46  {
47  if (!m_handle)
48  {
49  std::cerr << "Failure while loading " << m_library << std::endl;
50  throw std::runtime_error("Failed loading library");
51  }
52 
53  // Throw an error if tokens are not alphanumeric (possible SQL injection)
54  validate_token(m_output_table);
55  for (const std::string& t: m_outputs) validate_token(t);
56  for (const std::string& t: m_refkeys) validate_token(t);
57  }
58 
59  //============================================================================
60  // get_column_names. Internal.
61  //============================================================================
62  std::vector<std::string> BasePlugin::get_column_names() const
63  {
64  std::vector<std::string> ret;
65  ret.insert(ret.end(), m_refkeys.begin(), m_refkeys.end());
66  ret.insert(ret.end(), m_outputs.begin(), m_outputs.end());
67 
68  return ret;
69  }
70 
71  //============================================================================
72  // compose_delete_query. Internal.
73  //============================================================================
74  std::string BasePlugin::compose_delete_query()
75  {
76  std::stringstream s;
77  s << "DELETE FROM " << m_output_table << ";";
78  return s.str();
79  }
80 
81  //============================================================================
82  // compose_create_query. Internal.
83  //============================================================================
84  std::string BasePlugin::compose_create_query()
85  {
86  std::stringstream s;
87  s << "CREATE TEMPORARY TABLE IF NOT EXISTS "
88  << m_output_table << "(";
89 
90  for (auto c: m_refkeys)
91  s << c << " INTEGER" << ", ";
92 
93  for (auto c: m_outputs)
94  s << c << " REAL" << (c != m_outputs.back() ? ", ": "");
95 
96  s << ");";
97 
98  return s.str();
99  }
100 
101  //============================================================================
102  // compose_create_query. Internal.
103  //============================================================================
104  std::string BasePlugin::compose_insert_query()
105  {
106  std::stringstream s;
107  s << "INSERT INTO " << m_output_table << " (";
108 
109  std::vector<std::string> col_names = get_column_names();
110  for (auto c: col_names)
111  s << c << (c != col_names.back() ? ", ": "");
112  s << ") VALUES ( ";
113 
114  for (auto c: col_names)
115  s << "?" << (c != col_names.back() ? ", ": "");
116  s << ");";
117 
118  return s.str();
119  }
120 
121  //============================================================================
122  // execute
123  //============================================================================
125  {
127 
128  // Prepare the queries and initialize the database
129  // CREATE TEMPORARY TABLE IF NOT EXISTS
130  sqlite3_stmt* create_output_table = get_statement(
131  "create_output_table", compose_create_query().c_str()
132  );
133  exec_stmt(create_output_table);
134 
135  // DELETE FROM table
136  sqlite3_stmt* delete_output_table = get_statement(
137  "delete_output_table", compose_delete_query().c_str()
138  );
139  exec_stmt(delete_output_table);
140 
141  // INSERT INTO TABLE
142  sqlite3_stmt* insert_in_output_table = get_statement(
143  "insert_in_output_table", compose_insert_query().c_str()
144  );
145 
146  // SELECT ... FROM
147  sqlite3_stmt* select_input = get_statement(
148  "select_input",
149  m_select_query.c_str()
150  );
151 
152 
153  // Main loop on selected rows
154  while (exec_stmt(select_input))
155  {
156  sqlite3_reset(insert_in_output_table);
157  // Buffers for parametrization input and output
158  std::vector<float> input;
159  std::vector<float> output(m_outputs.size());
160 
161  // Loop on the columns of each row
162  const int nCols = sqlite3_column_count(select_input);
163  for (int iCol=0; iCol < nCols; ++iCol)
164  {
165  // Check for reserved column (external indices)
166  const std::string column(sqlite3_column_name(select_input, iCol));
167  std::vector<std::string> col_names = get_column_names();
168  auto col_iterator = std::find(m_refkeys.begin(), m_refkeys.end(), column);
169 
170  // if an index, wraps it to the insert query
171  if (col_iterator != m_refkeys.end())
172  {
173 
174  sqlite3_bind_int(
175  insert_in_output_table,
176  1 + col_iterator - m_refkeys.begin(),
177  sqlite3_column_int(select_input, iCol)
178  );
179  }
180  else // otherwise, it is an input for the parametrization
181  input.push_back(read_as_float(select_input, iCol));
182  }
183 
184  // Execute the external function defining the parametrization
185 
186  eval_parametrization(output.data(), input.data());
187 
188  // Fill the output table with the parametrization output
189  for (size_t iOutput = 0; iOutput < m_outputs.size(); ++iOutput)
190  sqlite3_bind_double(
191  insert_in_output_table,
192  m_refkeys.size() + iOutput + 1,
193  output[iOutput]
194  );
195 
196  exec_stmt(insert_in_output_table);
197  }
198 
199  end_transaction();
200  }
201 }
void execute() override
Executes the external function and copies the output into a new table.
Definition: BasePlugin.cpp:124
BasePlugin(SQLite3DB &db, const std::string &library, const std::string &function_name, const std::string &select_query, const std::string &output_table, const std::vector< std::string > outputs, const std::vector< std::string > reference_keys={"ref_id"})
Constructor.
Definition: BasePlugin.cpp:29
virtual void eval_parametrization(float *output, const float *input)=0
Evaluate the external parametrization.
Abstract interface with helper functions to access an SQLite DB.
sqlite3_stmt * get_statement(const std::string &name, const std::string &query)
Creates a statement, or retrieves it from the cache.
void begin_transaction()
Begin an SQL transaction, suspending updates to disk until end_transaction() is issued.
void end_transaction()
End an SQL transaction, re-enabling disk updates.
bool exec_stmt(sqlite3_stmt *)
Execute a statement, possibly throwing an exception on failure.
A database connection handler that eases sharing the DB between C++ and Python.
Definition: db_functions.py:24
void validate_token(const std::string &token)
Ensure a token is alphanumeric.
float read_as_float(sqlite3_stmt *, int)
Read a column field from a sqlite3 statement and convert it to float.