1import logging
2import os
3import string
4import warnings
5from getpass import getuser
6from hashlib import sha256
7from logging import warning
8from urllib.parse import quote_plus
9
10import sqlalchemy
11from packaging.version import Version
12from sqlalchemy.exc import StatementError
13
14from omegaml.backends.basedata import BaseDataBackend
15from omegaml.util import ProcessLocal, KeepMissing, tqdm_if_interactive, signature
16
try:
    # snowflake-connector-python is an optional dependency
    import snowflake
except ImportError:
    # not installed - nothing to configure
    pass
else:
    # silence snowflake's verbose logging output
    sql_logger = logging.getLogger('snowflake')
    sql_logger.setLevel('CRITICAL')
24
try:
    # enable pandas >= 2.2 compatibility with sqlalchemy 1.4
    # -- workaround due to https://github.com/pandas-dev/pandas/issues/57049
    # -- this forces the use of pd.SQLDatabase instead of pd.SQLiteDatabase
    # -- see https://github.com/pandas-dev/pandas/issues/57049#issuecomment-3398561199
    import pandas as pd
    import sqlalchemy as sqa

    # -- the patch only applies to sqlalchemy < 2 (fix: previously < 2.2,
    #    which wrongly patched and warned for sqlalchemy 2.0/2.1)
    if (Version(pd.__version__) >= Version("2.2") and
            Version(sqa.__version__) < Version("2")):
        from pandas.compat._optional import VERSIONS

        VERSIONS['sqlalchemy'] = '1.4'
        warnings.warn(
            "Patching pandas > 2.2 to support sqlalchemy >=1.4,<2 due to https://github.com/pandas-dev/pandas/issues/57049. To avoid this warning upgrade to sqlalchemy 2.x")
except Exception:
    # best effort - if pandas/sqlalchemy are missing or incompatible,
    # leave pandas unpatched
    pass
42
#: default for engine caching; override via om.defaults.SQLALCHEMY_ALWAYS_CACHE
# -- enabled by default as this is the least-surprise option,
#    consistent with sqlalchemy connection pooling defaults
ALWAYS_CACHE = True

#: default kwargs passed to sqlalchemy.create_engine()
# -- echo=False: do not log sql statements to stdout
# -- pool_pre_ping=True: always check, re-establish connection if no longer working
# -- pool_recycle=N: do not reuse connections older than N seconds
ENGINE_KWARGS = {
    'echo': False,
    'pool_pre_ping': True,
    'pool_recycle': 3600,
}

#: module-level logger
logger = logging.getLogger(__name__)
54
55
class SQLAlchemyBackend(BaseDataBackend):
    """
    sqlalchemy plugin for omegaml

    Usage:

        Define your sqlalchemy connection::

            sqlalchemy_constr = f'sqlalchemy://{user}:{password}@{account}/'

        Store the connection in any of three ways::

            # -- just the connection
            om.datasets.put(sqlalchemy_constr, 'mysqlalchemy')
            om.datasets.get('mysqlalchemy', raw=True)
            => the sql connection object

            # -- store connection with a predefined sql
            om.datasets.put(sqlalchemy_constr, 'mysqlalchemy',
                            sql='select ....')
            om.datasets.get('mysqlalchemy')
            => will return a pandas dataframe using the specified sql to run.
               specify chunksize= to return an iterable of dataframes

            # -- predefined sqls can contain variables to be resolved at access time
            #    if you fail to specify required variables in sqlvars, a KeyError is raised
            om.datasets.put(sqlalchemy_constr, 'myview',
                            sql='select ... from T1 where col="{var}"')
            om.datasets.get('mysqlalchemy', sqlvars=dict(var="value"))

            # -- Variables are replaced by binding parameters, which is safe for
            #    untrusted inputs. To replace variables as strings, use double
            #    `{{variable}}` notation. A warning will be issued because this
            #    is considered an unsafe practice for untrusted input (REST API).
            #    It is your responsibility to sanitize the value of the `cols` variable.
            om.datasets.put(sqlalchemy_constr, 'myview',
                            sql='select {{cols}} from T1 where col="{var}"')
            om.datasets.get('mysqlalchemy', sqlvars=dict(cols='foo, bar',
                                                         var="value"))

        Query data from a connection and store into an omega-ml dataset::

            # -- copy the result of the sqlalchemy query to omegaml
            om.datasets.put(sqlalchemy_constr, 'mysqlalchemy',
                            sql='select ...', copy=True)
            om.datasets.get('mysqlalchemy')
            => will return a pandas dataframe (without executing any additional queries)
            => can also use with om.datasets.getl('mysqlalchemy') to return a MDataFrame

        Controlling the table used in the connection::

            # -- the default table is {bucket}_{name}, override using table='myname'
            om.datasets.put(sqlalchemy_constr, 'mysqlalchemy',
                            table='mytable',
                            sql='select ...',
                            copy=True)
            om.datasets.get('mysqlalchemy') # read from {bucket}_mytable

            # -- to use a specific table, without bucket information use table=':myname'
            om.datasets.put(sqlalchemy_constr, 'mysqlalchemy',
                            table=':mytable',
                            sql='select ...',
                            copy=True)
            om.datasets.get('mysqlalchemy') # read from mytable

        Inserting data via a previously stored connection::

            # -- store data back through the connection
            om.datasets.put(sqlalchemy_constr, 'mysqlalchemy')
            om.datasets.put(df, 'mysqlalchemy',
                            table='SOMETABLE')

        Using variables in connection strings:

            Connection strings may contain variables, e.g. userid and password.
            By default variables are resolved from the os environment. Can also
            specify using any dict::

                # -- use connection string with variables
                sqlalchemy_constr = 'sqlite:///{dbname}.db'
                om.datasets.put(sqlalchemy_constr, 'userdb')
                om.datasets.get('userdb', secrets=dict(dbname='chuckdb'))

                # -- alternatively, create a vault dataset:
                secrets = dict(userid='chuck', dbname='chuckdb')
                om.datasets.put(secrets, '.omega/vault')
                om.datasets.get('userdb')

            the '.omega/vault' dataset will be queried using the current userid as
            the secret name, and the dbname retrieved from the document. This is
            experimental and the vault is not encrypted.

        Advanced:

            ``om.datasets.put()`` supports the following additional keyword arguments

            * ``chunksize=int`` - specify the number of rows to read from sqlalchemy in one chunk.
              defaults to 10000

            * ``parse_dates=['col', ...]`` - list of column names to parse for date, time or datetime.
              see pd.read_sql for details

            * ``transform=callable`` - a callable, is passed the DataFrame of each chunk before it
              is inserted into the database. use to provide custom transformations.
              only works on copy=True

            * any other kwargs supported by ``pandas.read_sql``

    """
    #: storage kind identifier for this backend
    KIND = 'sqlalchemy.conx'
    #: promote strategy used when promoting datasets between buckets
    PROMOTE = 'metadata'

    #: sqlalchemy.Engine cache to enable pooled connections
    __CNX_CACHE = ProcessLocal()

    # -- https://docs.sqlalchemy.org/en/14/core/pooling.html#module-sqlalchemy.pool
    # -- create_engine() must be called per-process, hence using ProcessLocal
    # -- meaning when using a multiprocessing.Pool or other fork()-ed processes,
    #    the cache will be cleared in child processes, forcing the engine to be
    #    recreated automatically in _get_connection

    def __init__(self, model_store=None, data_store=None, tracking=None, **kwargs):
        # all stores and kwargs are passed through to BaseDataBackend unchanged
        super().__init__(model_store=model_store, data_store=data_store, tracking=tracking, **kwargs)
179
180 @classmethod
181 def supports(cls, obj, name, insert=False, data_store=None, model_store=None, *args, **kwargs):
182 valid = cls._is_valid_url(cls, obj)
183 support_via = cls._supports_via(cls, data_store, name, obj)
184 return valid or support_via
185
186 def drop(self, name, secrets=None, **kwargs):
187 # ensure cache is cleared
188 clear_cache = True if secrets is None else False
189 try:
190 self.get(name, secrets=secrets, raw=True, keep=False)
191 except KeyError as e:
192 warnings.warn(f'Connection cache was cleared, however secret {e} was missing.')
193 clear_cache = True
194 if clear_cache:
195 self.__CNX_CACHE.clear()
196 return super().drop(name, **kwargs)
197
198 def sign(self, values):
199 return signature(values)
200
201 def get(self, name, sql=None, chunksize=None, raw=False, sqlvars=None,
202 secrets=None, index=True, keep=None, lazy=False, table=None, trusted=False, *args, **kwargs):
203 """ retrieve a stored connection or query data from connection
204
205 Args:
206 name (str): the name of the connection
207 secrets (dict): dict to resolve variables in the connection string
208 keep (bool): if True connection is kept open, defaults to True (change
209 default by setting om.defaults.SQLALCHEMY_ALWAYS_CACHE = False)
210 table (str): the name of the table, will be prefixed with the
211 store's bucket name unless the table is specified as ':name'
212 trusted (bool|str): if passed must be the value for store.sign(sqlvars or kwargs),
213 otherwise a warning is issued for any remaining variables in the sql statement
214
215 Returns:
216 connection
217
218 To query data and return a DataFrame, specify ``sql='select ...'``:
219
220 Args:
221 sql (str): the sql query, defaults to the query specific on .put()
222 chunksize (int): the number of records for each chunk, if
223 specified returns an iterator
224 sqlvars (dict): optional, if specified will be used to format sql
225
226 Returns:
227 pd.DataFrame
228
229 To get the connection for a data query, instead of a DataFrame:
230
231 Args:
232
233 raw (bool): if True, returns the raw sql alchemy connection
234 keep (bool): option, if True keeps the connection open. Lazy=True
235 implies keep=True. This is potentially unsafe in a multi-user
236 environment where connection strings contain user-specific
237 secrets. To always keep connections open, set
238 ``om.datasets.defaults.SQLALCHEMY_ALWAYS_CACHE=True``
239
240 Returns:
241 connection
242
243 To get a cursor for a data query, instead of a DataFrame. Note this
244 implies keep=True.
245
246 Args:
247
248 lazy (bool): if True, returns a cursor instead of a DataFrame
249 sql (str): the sql query, defaults to the query specific on .put()
250
251 Returns:
252 cursor
253 """
254 meta = self.data_store.metadata(name)
255 connection_str = meta.kind_meta.get('sqlalchemy_connection')
256 valid_sql = lambda v: isinstance(v, str) or v is not None
257 sql = sql if valid_sql(sql) else meta.kind_meta.get('sql')
258 sqlvars = sqlvars or {}
259 table = self._default_table(table or meta.kind_meta.get('table') or name)
260 if not raw and not valid_sql(sql):
261 sql = f'select * from :sqltable'
262 chunksize = chunksize or meta.kind_meta.get('chunksize')
263 _default_keep = getattr(self.data_store.defaults,
264 'SQLALCHEMY_ALWAYS_CACHE',
265 ALWAYS_CACHE)
266 keep = keep if keep is not None else _default_keep
267 if connection_str:
268 secrets = self._get_secrets(meta, secrets)
269 connection = self._get_connection(name, connection_str, secrets=secrets, keep=keep)
270 else:
271 raise ValueError('no connection string')
272 if not raw and valid_sql(sql):
273 sql = sql.replace(':sqltable', table)
274 index_cols = _meta_to_indexcols(meta) if index else kwargs.get('index_col')
275 stmt = self._sanitize_statement(sql, sqlvars, trusted=trusted)
276 kwargs = meta.kind_meta.get('kwargs') or {}
277 kwargs.update(kwargs)
278 if not lazy:
279 logger.debug(f'executing sql {stmt} with parameters {sqlvars}')
280 pd_kwargs = {**dict(chunksize=chunksize, index_col=index_cols,
281 params=(sqlvars or {})), **kwargs}
282 result = pd.read_sql(stmt, connection, **pd_kwargs)
283 else:
284 # lazy returns a cursor
285 logger.debug(f'preparing a cursor for sql {sql} with parameters {sqlvars}')
286 result = connection.execute(stmt, sqlvars)
287 keep = True
288 if not keep:
289 connection.close()
290 return result
291 return connection
292
293 def put(self, obj, name, sql=None, copy=False, append=True, chunksize=None,
294 transform=None, table=None, attributes=None, insert=False,
295 secrets=None, *args, **kwargs):
296 """ store sqlalchemy connection or insert data into an existing connection
297
298 Args:
299 obj (str|pd.DataFrame): the sqlalchemy connection string or a dataframe object
300 name (str): the name of the object
301 table (str): optional, if specified is stored along connection
302 sql (str): optional, if specified is stored along connection
303 copy (bool): optional, if True the connection is queried using sql
304 and the resulting data is stored instead, see below
305 attributes (dict): optional, set or update metadata.attributes
306
307 Returns:
308 metadata of the stored connection
309
310 Instead of inserting the connection specify ``copy=True`` to query data
311 and store it as a DataFrame dataset given by ``name``:
312
313 Args:
314 sql (str): sql to query
315 append (bool): if True the data is appended if exists already
316 chunksize (int): number of records to query in each chunk
317 transform (callable): passed as DataFrame.to_sql(method=)
318
319 Returns:
320 metadata of the inserted dataframe
321
322 To insert data via a previously stored connection, specify ``insert=True``:
323
324 Args:
325 insert (bool): specify True to insert via the connection
326 table (str): the table name to use for inserting data
327 append (bool): if False will replace any existing table, defaults to True
328 index (bool): if False will not attempt to create an index in target, defaults to False
329 chunksize (int): number of records to insert in each chunk
330
331 Returns:
332 metadata of the connection
333 """
334 meta = self.data_store.metadata(name)
335 if not insert and self._is_valid_url(obj):
336 # store a connection object
337 url = obj
338 cnx_name = name if not copy else '_cnx_{}'.format(name)
339 table = self._default_table(table or name)
340 metadata = self._put_as_connection(url, cnx_name, sql=sql, chunksize=chunksize,
341 table=table, attributes=attributes, **kwargs)
342 if copy:
343 secrets = self._get_secrets(metadata, secrets)
344 metadata = self._put_as_data(url, name, cnx_name,
345 sql=sql, chunksize=chunksize,
346 append=append, transform=transform,
347 secrets=secrets,
348 **kwargs)
349 elif meta is not None:
350 table = self._default_table(table or meta.kind_meta.get('table') or name)
351 metadata = self._put_via(obj, name, append=append, table=table, chunksize=chunksize,
352 transform=transform, **kwargs)
353 else:
354 raise ValueError('type {} is not supported by {}'.format(type(obj), self.KIND))
355 metadata.attributes.update(attributes) if attributes else None
356 return metadata.save()
357
358 def _put_via(self, obj, name, append=True, table=None, chunksize=None, transform=None,
359 index_columns=None, index=True, **kwargs):
360 # write data back through the connection
361 # -- ensure we have a valid object
362 if not hasattr(obj, 'to_sql'):
363 warning('obj.to_sql() does not exist, trying pd.DataFrame(obj)')
364 obj = pd.DataFrame(obj)
365 # -- get the connection
366 connection = self.get(name, raw=True)
367 metadata = self.data_store.metadata(name)
368 if isinstance(obj, pd.DataFrame) and index:
369 index_cols = _dataframe_to_indexcols(obj, metadata, index_columns=index_columns)
370 else:
371 index_cols = index_columns
372 metadata.kind_meta['index_columns'] = index_cols
373 exists_action = 'append' if append else 'replace'
374 transform = transform
375 self._chunked_to_sql(obj, table, connection, chunksize=chunksize, method=transform,
376 if_exists=exists_action, index=index, index_label=index_cols, **kwargs)
377 connection.close()
378 return metadata
379
380 def _put_as_data(self, url, name, cnx_name, sql=None, chunksize=None, append=True,
381 transform=None, **kwargs):
382 # use the url to query the connection and store resulting data instead
383 if not sql:
384 raise ValueError('a valid SQL statement is required with copy=True')
385 metadata = self.copy_from_sql(sql, url, name, chunksize=chunksize,
386 append=append, transform=transform,
387 **kwargs)
388 metadata.attributes['created_from'] = cnx_name
389 return metadata
390
391 def _put_as_connection(self, url, name, sql=None, chunksize=None, attributes=None,
392 table=None, index_columns=None, secrets=None, **kwargs):
393 kind_meta = {
394 'sqlalchemy_connection': str(url),
395 'sql': sql,
396 'chunksize': chunksize,
397 'table': ':' + table,
398 'index_columns': index_columns,
399 'kwargs': kwargs,
400 }
401 if secrets is True or (secrets is None and '{' in url and '}' in url):
402 kind_meta['secrets'] = {
403 'dsname': '.omega/vault',
404 'query': {
405 'data_userid': '{user}'
406 }
407 }
408 else:
409 kind_meta['secrets'] = secrets
410 metadata = self.data_store.metadata(name)
411 if metadata is not None:
412 metadata.kind_meta.update(kind_meta)
413 else:
414 metadata = self.data_store.make_metadata(name, self.KIND,
415 kind_meta=kind_meta,
416 attributes=attributes)
417 return metadata.save()
418
    def _get_connection(self, name, connection_str, secrets=None, keep=False):
        """ open a connection, reusing a cached engine where possible

        Args:
            name (str): the dataset name, part of the engine cache key
            connection_str (str): the connection string, may contain {variables}
                resolved from secrets
            secrets (dict): optional, values to resolve {variables}; str/bytes
                values are url-quoted before insertion
            keep (bool): if True the engine is kept in the process-local cache
                for reuse, else any cached engine for this key is removed

        Returns:
            an open sqlalchemy connection

        Raises:
            KeyError: if the connection string contains a variable not found
                in secrets
        """
        import sqlalchemy as sqa
        # passwords should be encoded
        # https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
        # note only str/bytes values are kept; other secret values are dropped
        encoded = lambda d: {
            k: (quote_plus(v.decode('utf-8')) if isinstance(v, bytes)
                else quote_plus(v)) for k, v in d.items() if isinstance(v, (str, bytes))
        }
        connection = None
        cache_key = None
        try:
            # SECDEV: the cache key is a secret in order to avoid privilege escalation
            # -- if it is not secret, user A could create the connection (=> cache)
            # -- user B could reuse the connection by retrieving the dataset without secrets
            # -- this way the user needs to have the same secrets in order to reuse the connection
            enc_secrets = encoded(secrets or {})
            connection_str = connection_str.format(**enc_secrets)
            cache_key = sha256(f'{name}:{connection_str}'.encode('utf8')).hexdigest()
            # reuse a cached engine if present, else create a new one
            engine = self.__CNX_CACHE.get(cache_key) or sqa.create_engine(connection_str, **ENGINE_KWARGS)
            connection = engine.connect()
        except KeyError as e:
            # a {variable} in the connection string could not be resolved
            msg = ('{e}, ensure secrets are specified for connection '
                   '>{connection_str}<'.format(**locals()))
            raise KeyError(msg)
        except Exception as e:
            # on any other error close the connection and drop the (possibly
            # broken) engine from the cache before re-raising
            if connection is not None:
                connection.close()
            self.__CNX_CACHE.pop(cache_key, None)
            raise
        else:
            # on success, keep or discard the engine as requested
            if keep:
                self.__CNX_CACHE[cache_key] = engine
            else:
                self.__CNX_CACHE.pop(cache_key, None)
        return connection
454
455 def copy_from_sql(self, sql, connstr, name, chunksize=10000,
456 append=False, transform=None, secrets=None, **kwargs):
457 connection = self._get_connection(name, connstr, secrets=secrets)
458 chunksize = chunksize or 10000 # avoid None
459 pditer = pd.read_sql(sql, connection, chunksize=chunksize, **kwargs)
460 with tqdm_if_interactive().tqdm(unit='rows') as pbar:
461 meta = self._chunked_insert(pditer, name, append=append,
462 transform=transform, pbar=pbar)
463 connection.close()
464 return meta
465
466 def _chunked_to_sql(self, df, table, connection, if_exists='append', chunksize=None, pbar=True, **kwargs):
467 # insert large df in chunks and with a progress bar
468 # from https://stackoverflow.com/a/39495229
469 chunksize = chunksize if chunksize is not None else 10000
470
471 def chunker(seq, size):
472 return (seq.iloc[pos:pos + size] for pos in range(0, len(seq), size))
473
474 def to_sql(df, table, connection, pbar=None):
475 for i, cdf in enumerate(chunker(df, chunksize)):
476 exists_action = if_exists if i == 0 else "append"
477 cdf.to_sql(table, con=connection, if_exists=exists_action, **kwargs)
478 if pbar:
479 pbar.update(len(cdf))
480 else:
481 print("writing chunk {}".format(i))
482
483 with tqdm_if_interactive().tqdm(total=len(df), unit='rows') as pbar:
484 to_sql(df, table, connection, pbar=pbar)
485
486 def _chunked_insert(self, pditer, name, append=True, transform=None, pbar=None):
487 # insert into om dataset
488 for i, df in enumerate(pditer):
489 if pbar is not None:
490 pbar.update(len(df))
491 should_append = (i > 0) or append
492 if transform:
493 df = transform(df)
494 try:
495 meta = self.data_store.put(df, name, append=should_append)
496 except Exception as e:
497 rows = df.iloc[0:10].to_dict()
498 raise ValueError("{e}: {rows}".format(**locals()))
499 return meta
500
501 def _is_valid_url(self, url):
502 # enable subclass override
503 return _is_valid_url(url)
504
505 def _supports_via(self, data_store, name, obj):
506 obj_ok = isinstance(obj, (pd.Series, pd.DataFrame, dict))
507 if obj_ok and data_store:
508 meta = data_store.metadata(name)
509 same_kind = meta and meta.kind == self.KIND
510 return obj_ok and same_kind
511 return False
512
    def _get_secrets(self, meta, secrets):
        """ resolve secrets from env, store defaults and the vault

        Args:
            meta (Metadata): the connection metadata; meta.kind_meta['secrets']
                may specify a vault dataset name and query
            secrets (dict): optional, explicit secrets; if given the vault
                is not queried

        Returns:
            dict of resolved secrets (string values formatted against env
            values, defaults and the current user), or the input secrets
        """
        secrets_specs = meta.kind_meta.get('secrets')
        # -- candidate values: upper-case env vars (if allowed) plus store defaults
        values = ({k: v for k, v in os.environ.items() if k.isupper() and isinstance(v, (str, bytes))}
                  if self.data_store.defaults.OMEGA_ALLOW_ENV_CONFIG else dict())
        values.update(**self.data_store.defaults)
        if not secrets and secrets_specs:
            dsname = secrets_specs['dsname']
            query = secrets_specs['query']
            # -- format query values
            query = _format_dict(query, replace=('_', '.'), **values, user=self._getuser())
            # -- run query
            secrets = self.data_store.get(dsname, filter=query)
            # -- expect exactly one matching document, else fall back to empty
            secrets = secrets[0] if isinstance(secrets, list) and len(secrets) == 1 else {}
            secrets.update(values)
        # -- format secrets
        if secrets:
            secrets = _format_dict(secrets, **values, user=self._getuser())
        return secrets
531
532 def _getuser(self):
533 return getattr(self.data_store.defaults, 'OMEGA_USERNAME', getuser())
534
535 def _default_table(self, name):
536 if name is None:
537 return name
538 if not name.startswith(':'):
539 name = f'{self.data_store.bucket}_{name}'
540 else:
541 name = name[1:]
542 return name
543
544 def _sanitize_statement(self, sql, sqlvars, trusted=False):
545 # sanitize sql:string statement in two steps
546 # -- step 1: replace all {} variables by :notation
547 # -- step 2: replace all remaining {} variables from sqlvars
548 # and issue a warning. step 2 is considered unsafe if
549 # the sqlvars source cannot be trusted
550 # -- step 3: prepare a SQL statement with bound variables
551 # see https://realpython.com/prevent-python-sql-injection/#using-query-parameters-in-sql
552 if not isinstance(sql, str):
553 return sql
554 # replace all {...} variables with bound parameters
555 # sql = "select * from foo where user={username}"
556 # => "select * from foo where user=:username"
557 placeholders = list(string.Formatter().parse(sql))
558 vars = [spec[1] for spec in placeholders if spec[1]]
559 safe_replacements = {var: f':{var}' for var in vars}
560 sql = sql.format(**safe_replacements)
561 # build parameter list for tuples and lists
562 # -- sqlalchemy+pyodbc do not support lists of values
563 # -- list of values must be passed as single parameters
564 # -- e.g. sql=select * from x in :list
565 # => select * from x in (:x_1, :x_2, :x_3, ...)
566 for k in vars:
567 # note we are iterating vars, not sqlvars
568 # -- sqlvars is not used in constructing sql text
569 v = sqlvars[k]
570 if isinstance(v, (list, tuple)):
571 bind_vars = {f'{k}_{i}': lv for i, lv in enumerate(v)}
572 placeholders = ','.join(f':{bk}' for bk in bind_vars)
573 sql = sql.replace(f':{k}', f'({placeholders})')
574 sqlvars.update(bind_vars)
575 try:
576 # format remaining {{}} for selection
577 # sql = "select {{cols}} from foo where user=:username
578 # => "select a, b from foo where user=:username
579 placeholders = list(string.Formatter().parse(sql))
580 vars = [spec[1] for spec in placeholders if spec[1]]
581 if vars and trusted != self.sign(sqlvars):
582 warnings.warn(f'Statement >{sql}< contains unsafe variables {vars}. Use :notation or sanitize input.')
583 sql = sql.format(**{**sqlvars, **safe_replacements})
584 except KeyError as e:
585 raise KeyError('{e}, specify sqlvars= to build query >{sql}<'.format(**locals()))
586 # prepare sql statement with bound variables
587 try:
588 stmt = sqlalchemy.sql.text(sql)
589 except StatementError as exc:
590 raise
591 return stmt
592
593
def _is_valid_url(url):
    """ return True if url is a sqlalchemy connection string with a registered dialect """
    # check if we have a valid url with a registered backend
    import sqlalchemy

    try:
        url = sqlalchemy.engine.url.make_url(url)
        drivername = url.drivername.split('+')[0]  # e.g. mssql+pyodbc => mssql
        valid = url.drivername in sqlalchemy.dialects.__all__
        valid |= sqlalchemy.dialects.registry.load(drivername) is not None
    except Exception:
        # fix: narrowed from a bare except; not parseable as a url,
        # or no dialect is registered for it
        valid = False
    return valid
606
607
608def _dataframe_to_indexcols(df, metadata, index_columns=None):
609 # from a dataframe get index column names
610 # works like pd.DataFrame.to_sql except for creating default index_i cols
611 # for any missing (None) index labels in a MultiIndex.
612 index_cols = metadata.kind_meta.get('index_columns') or index_columns or list(df.index.names)
613 multi = isinstance(df.index, pd.MultiIndex)
614 if index_cols is not None:
615 for i, col in enumerate(index_cols):
616 if col is None:
617 index_cols[i] = 'index' if not multi else 'index_{}'.format(i)
618 return index_cols
619
620
621def _meta_to_indexcols(meta):
622 index_cols = meta.kind_meta.get('index_columns')
623 multi = isinstance(index_cols, (list, tuple)) and len(index_cols) > 1
624 if index_cols is not None and not isinstance(index_cols, str):
625 for i, col in enumerate(index_cols):
626 if col is None:
627 index_cols[i] = 'index' if not multi else 'index_{}'.format(i)
628 return index_cols
629
630
631def _format_dict(d, replace=None, **kwargs):
632 for k, v in dict(d).items():
633 if replace:
634 del d[k]
635 k = k.replace(*replace) if replace else k
636 d[k] = v.format_map(KeepMissing(kwargs)) if isinstance(v, str) else v
637 return d
638
639
def load_sql(om=None, kind=SQLAlchemyBackend.KIND):
    """
    load ipython sql magic, loading all sql alchemy connections

    Usage:
        !pip install ipython-sql

        # prepare some connection, insert some data
        df = pd.DataFrame(...)
        om.datasets.put('sqlite:///test.db', 'testdb')
        om.datasets.put(df, 'testdb', table='foobar', insert=True)

        from omegaml.backends.sqlalchemy import load_sql
        load_sql()

        # list registered connections
        %sql
        omsql://testdb

        # run queries
        %sql omsql://testdb select * from foobar

    See Also:
        https://pypi.org/project/ipython-sql/

    Args:
        om (Omega): optional, specify omega instance, defaults to om.setup()
        kind (str): the backend's kind, used to find connections to register
           with the sql magic

    Returns:
        None
    """
    from unittest.mock import MagicMock
    from IPython import get_ipython
    from sql.connection import Connection  # noqa

    class ConnectionShim:
        # this is required to trick sql magic into accepting existing connection objects
        # (by default sql magic expects connection strings, not connection objects)
        def __init__(self, url, conn):
            self.session = conn
            self.metadata = MagicMock()
            self.metadata.bind.url = url
            self.dialect = getattr(conn, 'dialect', 'omsql')

    # load sql magic
    ipython = get_ipython()
    ipython.magic('load_ext sql')
    # resolve the omega instance
    # -- fix: `import omegaml as om` previously shadowed the om argument,
    #    silently ignoring any instance passed by the caller
    if om is None:
        import omegaml
        om = omegaml.setup()
    # register every stored sqlalchemy connection with the sql magic
    for ds in om.datasets.list(kind=kind, raw=True):
        cnxstr = 'omsql://{ds.name}'.format(**locals())
        conn = om.datasets.get(ds.name, raw=True)
        Connection.connections[cnxstr] = ConnectionShim(cnxstr, conn)
    ipython.magic('sql')