import itertools
import logging
import os
import sqlite3
import sys
import threading
import xml.etree.ElementTree as e3
from typing import List, Tuple, Any, Union, Dict

import PyLamarr
from PyLamarr import GenericWrapper, RemoteRes

# SQLamarr provides the compiled algorithms; if unavailable, a pipeline
# can still be configured and serialized, but not executed.
try:
    import SQLamarr
except (ImportError, OSError):
    SQLamarr = None


class Pipeline:
    def __init__(self,
                 sequence: Union[List[Tuple[Any]], None] = None,
                 loader: str = "HepMC2DataLoader",
                 batch: int = 1,
                 dbfile_fmt: str = "file:/tmp/lamarr.{thread:016x}.db",
                 clean_before_loading: bool = True,
                 clean_after_finishing: bool = True,
                 ):
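        """
        Configure a Lamarr pipeline.

        sequence: ordered list of (name, wrapper) pairs; when None, the
            default sequence is used.
        loader: name of the SQLamarr data-loader class, resolved lazily
            via getattr(SQLamarr, ...).
        batch: number of load arguments processed together per batch.
        dbfile_fmt: SQLite URI template; {thread} is replaced with the
            worker thread id at execution time.
        clean_before_loading: clean the event store before each batch.
        clean_after_finishing: remove the temporary database file once
            the run completes.
        """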
        self.logger = logging.getLogger(self.__class__.__name__)
        PyLamarr.configure_logger()
        self.logger.info(f"Python {sys.version}".replace("\n", " "))
        try:
            with sqlite3.connect(":memory:") as c:
                sqlite_version = c.execute("SELECT sqlite_version()").fetchall()[0][0]
            self.logger.info(f"Running with SQLite version {sqlite_version} "
                             f"(bindings: {sqlite3.version})")
        except (ImportError, OSError):
            self.logger.warning("SQLite not found. "
                                "You can still build a configuration, but not run it.")

        if SQLamarr is not None:
            self.logger.info(f"Running with SQLamarr version {SQLamarr.version}")

        # Store the configuration; attribute names follow the private
        # `_loader` convention used by the properties below.
        self._loader = loader
        self._batch = batch
        self._dbfile_fmt = dbfile_fmt
        self._clean_before_loading = clean_before_loading
        self._clean_after_finishing = clean_after_finishing
        self._sequence = sequence if sequence is not None else self.default_sequence
    @property
    def default_sequence(self):
        ...

    @property
    def loader(self):
        # Resolve the loader class lazily, so a missing SQLamarr only
        # fails when the loader is actually needed.
        return getattr(SQLamarr, self._loader)

    @loader.setter
    def loader(self, new_loader):
        self._loader = new_loader

    @property
    def batch(self):
        return self._batch

    @batch.setter
    def batch(self, new_batch):
        self._batch = new_batch
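    # _batched chunks an iterable into fixed-size tuples,
    # e.g. _batched(range(5), 2) yields (0, 1), (2, 3), (4,).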
    @staticmethod
    def _batched(batch, batch_size):
        if batch_size < 1:
            raise ValueError("Batch size must be larger than 1")
        it = iter(batch)
        while True:
            batch = tuple(itertools.islice(it, batch_size))
            if not batch:
                return
            yield batch
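    # execute() takes one entry per input: a tuple or list is expanded to
    # positional arguments of loader.load, a dict to keyword arguments,
    # and any other value is passed through as a single argument.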
    def execute(self,
                load_args: List[Tuple[Dict]],
                thread_id: Union[int, None] = None,
                ):
        if SQLamarr is None:
            raise ImportError("SQLite is needed for pipeline.execute(). "
                              "Please reinstall as `pip install PyLamarr[SQLamarr]`")

        tid = thread_id if thread_id is not None else threading.get_ident()

        # Each worker connects to its own database file, keyed by thread id.
        parsed_fmt = self._dbfile_fmt.format(thread=tid)
        self.logger.info(f"Connecting to SQLite db: {parsed_fmt}")
        ...  # open the database, instantiate the loader and the algorithms

        self.logger.info("Algorithms:")
        for iAlg, (name, _) in enumerate(self._sequence):
            self.logger.info(f" {iAlg:>2d}. {name}")
        for batch in self._batched(load_args, self._batch):
            for load_arg in batch:
                self.logger.info(f"Loading {load_arg}")
                # Dispatch on the argument shape (see the note above).
                if isinstance(load_arg, (list, tuple)):
                    sub_batch_generator = loader.load(*load_arg)
                elif isinstance(load_arg, (dict,)):
                    sub_batch_generator = loader.load(**load_arg)
                else:
                    sub_batch_generator = loader.load(load_arg)

                for sub_batch in sub_batch_generator:
                    self.logger.info(f"Processing {sub_batch}")
                    if self._clean_before_loading:
                        self.logger.debug("Cleaning database for processing a new batch")
                        ...  # clean the event store
                    else:
                        self.logger.warning("Cleaning database was DISABLED")

                    self.logger.debug(f"Executing pipeline on a batch of {len(sub_batch)} events")
                    ...  # load the sub-batch into the database

                    self.logger.debug("Executing the pipeline")
                    ...  # run the configured algorithms on the loaded events

                    self.logger.debug("Completed processing of batch")
        if self._clean_after_finishing:
            # Remove the temporary database file, unless the URI points to
            # an in-memory database.
            if parsed_fmt.startswith("file:"):
                if "mode=memory" not in parsed_fmt:
                    if '?' in parsed_fmt:
                        os.remove(parsed_fmt[len('file:'):parsed_fmt.index('?')])
                    else:
                        os.remove(parsed_fmt[len('file:'):])
            else:
                os.remove(parsed_fmt)
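    # Usage sketch (file name and loader arguments are hypothetical):
    #
    #   pipeline = Pipeline(batch=10)
    #   pipeline.execute([{"filename": "minbias.hepmc2"}])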
    def to_xml(self, file_like) -> None:
        root = e3.Element('pipeline')
        root.attrib['batch'] = str(self._batch)
        for k, w in self._sequence:
            if hasattr(w, 'to_xml'):
                w.to_xml(root).attrib['step'] = k
            else:
                self.logger.warning(f"XML serialization unavailable for {k}. Skipped.")

        file_like.write(e3.tostring(root, encoding='unicode'))
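    # Sketch of the serialized document, as parsed back by read_xml below
    # (element and key names are illustrative):
    #
    #   <pipeline batch="10">
    #     <SomeAlgorithm step="my_step" name="SomeAlgorithm">
    #       <config key="option" type="str">value</config>
    #       <config key="columns" type="seq">px;py;pz</config>
    #       <config key="model" type="url">https://example.com/model.db</config>
    #     </SomeAlgorithm>
    #   </pipeline>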
    @classmethod
    def read_xml(cls, file_like):
        root = e3.fromstring(file_like.read())
        if root.tag.lower() not in ['pipeline']:
            raise IOError(f"Unexpected ROOT tag {root.tag}")

        batch_size = int(root.attrib.get("batch", 1))

        algs = []
        for child in root:
            # Each child element describes one algorithm: its tag is the
            # implementation name, its children carry the configuration.
            alg_type = child.tag
            alg_name = child.attrib.get("name", alg_type)
            step_name = child.attrib.get('step', alg_name)

            config = {}
            for cfg_node in child:
                if cfg_node.tag.lower() == "config":
                    if cfg_node.attrib['type'] == 'str':
                        config[cfg_node.attrib['key']] = cfg_node.text
                    elif cfg_node.attrib['type'] == 'seq':
                        config[cfg_node.attrib['key']] = cfg_node.text.split(";")
                    elif cfg_node.attrib['type'] == 'url':
                        config[cfg_node.attrib['key']] = RemoteRes(cfg_node.text)
                    else:
                        raise NotImplementedError(
                            f"Unexpected type {cfg_node.attrib['type']} for {cfg_node.attrib['key']}"
                        )

            algs.append((step_name,
                         GenericWrapper(implements=alg_type, config=config)))

        return cls(algs, batch=batch_size)
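# Round-trip sketch: serialize the configuration to XML and rebuild an
# equivalent pipeline from it (in-memory buffer for illustration).
#
#   import io
#   buf = io.StringIO()
#   pipeline.to_xml(buf)
#   buf.seek(0)
#   clone = Pipeline.read_xml(buf)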