Source code for omegaml.backends.sqlalchemy

import logging
import os
import string
import warnings
from getpass import getuser
from hashlib import sha256
from logging import warning
from urllib.parse import quote_plus

import sqlalchemy
from packaging.version import Version

from omegaml.backends.basedata import BaseDataBackend
from omegaml.util import ProcessLocal, KeepMissing, tqdm_if_interactive, signature

try:
    import snowflake

    sql_logger = logging.getLogger('snowflake')
    sql_logger.setLevel('CRITICAL')
except ImportError:
    pass

try:
    # enable pandas >= 2.2 compatibility with sqlalchemy 1.4
    # -- workaround due to https://github.com/pandas-dev/pandas/issues/57049
    # -- this forces the use of pd.SQLDatabase instead of pd.SQLiteDatabase
    # -- see https://github.com/pandas-dev/pandas/issues/57049#issuecomment-3398561199
    import pandas as pd
    import sqlalchemy as sqa

    if (Version(pd.__version__) >= Version("2.2") and
            Version(sqa.__version__) < Version("2")):
        from pandas.compat._optional import VERSIONS

        VERSIONS['sqlalchemy'] = '1.4'
        warnings.warn(
            "Patching pandas >= 2.2 to support sqlalchemy >=1.4,<2 due to "
            "https://github.com/pandas-dev/pandas/issues/57049. "
            "To avoid this warning upgrade to sqlalchemy 2.x")
except Exception:
    pass

#: override by setting om.defaults.SQLALCHEMY_ALWAYS_CACHE
ALWAYS_CACHE = True
# -- enabled by default as this is the least-surprise option
# -- consistent with sqlalchemy connection pooling defaults
#: kwargs for create_engine()
ENGINE_KWARGS = dict(echo=False, pool_pre_ping=True, pool_recycle=3600)
# -- echo=False - do not log to stdout
# -- pool_pre_ping=True - always check, re-establish connection if no longer working
# -- pool_recycle=N - do not reuse connections older than N seconds

logger = logging.getLogger(__name__)

class SQLAlchemyBackend(BaseDataBackend):
    """
    sqlalchemy plugin for omegaml

    Usage:

        Define your sqlalchemy connection::

            sqlalchemy_constr = f'sqlalchemy://{user}:{password}@{account}/'

        Store the connection in any of three ways::

            # -- just the connection
            om.datasets.put(sqlalchemy_constr, 'mysqlalchemy')
            om.datasets.get('mysqlalchemy', raw=True)
            => the sql connection object

            # -- store the connection with a predefined sql
            om.datasets.put(sqlalchemy_constr, 'mysqlalchemy',
                            sql='select ....')
            om.datasets.get('mysqlalchemy')
            => will return a pandas dataframe using the specified sql to run.
               specify chunksize= to return an iterable of dataframes

            # -- predefined sqls can contain variables to be resolved at access time
            #    if you fail to specify required variables in sqlvars, a KeyError is raised
            om.datasets.put(sqlalchemy_constr, 'myview',
                            sql='select ... from T1 where col="{var}"')
            om.datasets.get('mysqlalchemy', sqlvars=dict(var="value"))

            # -- variables are replaced by binding parameters, which is safe for
            #    untrusted inputs. To replace variables as strings, use double
            #    `{{variable}}` notation. A warning will be issued because this
            #    is considered an unsafe practice for untrusted input (REST API).
            #    It is your responsibility to sanitize the value of the `cols` variable.
            om.datasets.put(sqlalchemy_constr, 'myview',
                            sql='select {{cols}} from T1 where col="{var}"')
            om.datasets.get('mysqlalchemy', sqlvars=dict(cols='foo, bar',
                                                         var="value"))

        Query data from a connection and store into an omega-ml dataset::

            # -- copy the result of the sqlalchemy query to omegaml
            om.datasets.put(sqlalchemy_constr, 'mysqlalchemy',
                            sql='select ...', copy=True)
            om.datasets.get('mysqlalchemy')
            => will return a pandas dataframe (without executing any additional queries)
            => can also use om.datasets.getl('mysqlalchemy') to return a MDataFrame

        Controlling the table used in the connection::

            # -- the default table is {bucket}_{name}, override using table='myname'
            om.datasets.put(sqlalchemy_constr, 'mysqlalchemy',
                            table='mytable',
                            sql='select ...',
                            copy=True)
            om.datasets.get('mysqlalchemy')  # read from {bucket}_mytable

            # -- to use a specific table without bucket information, use table=':myname'
            om.datasets.put(sqlalchemy_constr, 'mysqlalchemy',
                            table=':mytable',
                            sql='select ...',
                            copy=True)
            om.datasets.get('mysqlalchemy')  # read from mytable

        Inserting data via a previously stored connection::

            # -- store data back through the connection
            om.datasets.put(sqlalchemy_constr, 'mysqlalchemy')
            om.datasets.put(df, 'mysqlalchemy',
                            table='SOMETABLE')

        Using variables in connection strings:

        Connection strings may contain variables, e.g. userid and password.
        By default variables are resolved from the os environment. They can
        also be specified using any dict::

            # -- use a connection string with variables
            sqlalchemy_constr = 'sqlite:///{dbname}.db'
            om.datasets.put(sqlalchemy_constr, 'userdb')
            om.datasets.get('userdb', secrets=dict(dbname='chuckdb'))

            # -- alternatively, create a vault dataset:
            secrets = dict(userid='chuck', dbname='chuckdb')
            om.datasets.put(secrets, '.omega/vault')
            om.datasets.get('userdb')

            the '.omega/vault' dataset will be queried using the current userid as
            the secret name, and the dbname retrieved from the document. This is
            experimental and the vault is not encrypted.

    Advanced:

        ``om.datasets.put()`` supports the following additional keyword arguments

        * ``chunksize=int`` - specify the number of rows to read from sqlalchemy in one chunk,
          defaults to 10000

        * ``parse_dates=['col', ...]`` - list of column names to parse for date, time or datetime,
          see pd.read_sql for details

        * ``transform=callable`` - a callable that is passed the DataFrame of each chunk before it
          is inserted into the database, use to provide custom transformations.
          only works on copy=True

        * any other kwargs supported by ``pandas.read_sql``

    """
    KIND = 'sqlalchemy.conx'
    PROMOTE = 'metadata'

    #: sqlalchemy.Engine cache to enable pooled connections
    __CNX_CACHE = ProcessLocal()

    # -- https://docs.sqlalchemy.org/en/14/core/pooling.html#module-sqlalchemy.pool
    # -- create_engine() must be called per-process, hence using ProcessLocal
    # -- meaning when using a multiprocessing.Pool or other fork()-ed processes,
    #    the cache will be cleared in child processes, forcing the engine to be
    #    recreated automatically in _get_connection

    def __init__(self, model_store=None, data_store=None, tracking=None, **kwargs):
        super().__init__(model_store=model_store, data_store=data_store, tracking=tracking, **kwargs)
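
    # a minimal sketch of the engine cache behavior described above
    # (illustrative, not part of the original source):
    #
    #   backend.get('mydb', raw=True, keep=True)  # parent process: engine cached
    #   # in a fork()-ed child the ProcessLocal cache starts out empty, so the
    #   # next access recreates the engine via sqa.create_engine() in
    #   # _get_connection(), as noted in the comments above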

    @classmethod
    def supports(cls, obj, name, insert=False, data_store=None, model_store=None, *args, **kwargs):
        valid = cls._is_valid_url(cls, obj)
        support_via = cls._supports_via(cls, data_store, name, obj)
        return valid or support_via

    def drop(self, name, secrets=None, **kwargs):
        # ensure cache is cleared
        clear_cache = True if secrets is None else False
        try:
            self.get(name, secrets=secrets, raw=True, keep=False)
        except KeyError as e:
            warnings.warn(f'Connection cache was cleared, however secret {e} was missing.')
            clear_cache = True
        if clear_cache:
            self.__CNX_CACHE.clear()
        return super().drop(name, **kwargs)

    def sign(self, values):
        return signature(values)

    def get(self, name, sql=None, chunksize=None, raw=False, sqlvars=None,
            secrets=None, index=True, keep=None, lazy=False, table=None, trusted=False, *args, **kwargs):
        """ retrieve a stored connection or query data from connection

        Args:
            name (str): the name of the connection
            secrets (dict): dict to resolve variables in the connection string
            keep (bool): if True connection is kept open, defaults to True (change
                default by setting om.defaults.SQLALCHEMY_ALWAYS_CACHE = False)
            table (str): the name of the table, will be prefixed with the
                store's bucket name unless the table is specified as ':name'
            trusted (bool|str): if passed must be the value for store.sign(sqlvars or kwargs),
                otherwise a warning is issued for any remaining variables in the sql statement

        Returns:
            connection

        To query data and return a DataFrame, specify ``sql='select ...'``:

        Args:
            sql (str): the sql query, defaults to the query specified on .put()
            chunksize (int): the number of records for each chunk, if
                specified returns an iterator
            sqlvars (dict): optional, if specified will be used to format sql

        Returns:
            pd.DataFrame

        To get the connection for a data query, instead of a DataFrame:

        Args:
            raw (bool): if True, returns the raw sqlalchemy connection
            keep (bool): optional, if True keeps the connection open. lazy=True
                implies keep=True. This is potentially unsafe in a multi-user
                environment where connection strings contain user-specific
                secrets. To always keep connections open, set
                ``om.datasets.defaults.SQLALCHEMY_ALWAYS_CACHE=True``

        Returns:
            connection

        To get a cursor for a data query instead of a DataFrame (this
        implies keep=True):

        Args:
            lazy (bool): if True, returns a cursor instead of a DataFrame
            sql (str): the sql query, defaults to the query specified on .put()

        Returns:
            cursor
        """
        meta = self.data_store.metadata(name)
        connection_str = meta.kind_meta.get('sqlalchemy_connection')
        valid_sql = lambda v: isinstance(v, str) or v is not None
        sql = sql if valid_sql(sql) else meta.kind_meta.get('sql')
        sqlvars = sqlvars or {}
        table = self._default_table(table or meta.kind_meta.get('table') or name)
        if not raw and not valid_sql(sql):
            sql = 'select * from :sqltable'
        chunksize = chunksize or meta.kind_meta.get('chunksize')
        _default_keep = getattr(self.data_store.defaults,
                                'SQLALCHEMY_ALWAYS_CACHE',
                                ALWAYS_CACHE)
        keep = keep if keep is not None else _default_keep
        if connection_str:
            secrets = self._get_secrets(meta, secrets)
            connection = self._get_connection(name, connection_str, secrets=secrets, keep=keep)
        else:
            raise ValueError('no connection string')
        if not raw and valid_sql(sql):
            sql = sql.replace(':sqltable', table)
            index_cols = _meta_to_indexcols(meta) if index else kwargs.get('index_col')
            stmt = self._sanitize_statement(sql, sqlvars, trusted=trusted)
            # merge kwargs stored on put() with call-time kwargs
            stored_kwargs = meta.kind_meta.get('kwargs') or {}
            kwargs = {**stored_kwargs, **kwargs}
            if not lazy:
                logger.debug(f'executing sql {stmt} with parameters {sqlvars}')
                pd_kwargs = {**dict(chunksize=chunksize, index_col=index_cols,
                                    params=(sqlvars or {})), **kwargs}
                result = pd.read_sql(stmt, connection, **pd_kwargs)
            else:
                # lazy returns a cursor
                logger.debug(f'preparing a cursor for sql {sql} with parameters {sqlvars}')
                result = connection.execute(stmt, sqlvars)
                keep = True
            if not keep:
                connection.close()
            return result
        return connection
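
    # a usage sketch for get(), assuming a connection stored as 'mydb' against
    # a database with a 'sales' table (names are illustrative, not part of the
    # original source):
    #
    #   om.datasets.put('sqlite:///mydb.db', 'mydb',
    #                   sql='select * from sales where qty > {minqty}')
    #   df = om.datasets.get('mydb', sqlvars=dict(minqty=10))
    #   # => {minqty} is rewritten to the bound parameter :minqty by
    #   #    _sanitize_statement() before pd.read_sql() executes the query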

    def put(self, obj, name, sql=None, copy=False, append=True, chunksize=None,
            transform=None, table=None, attributes=None, insert=False,
            secrets=None, *args, **kwargs):
        """ store sqlalchemy connection or insert data into an existing connection

        Args:
            obj (str|pd.DataFrame): the sqlalchemy connection string or a dataframe object
            name (str): the name of the object
            table (str): optional, if specified is stored along with the connection
            sql (str): optional, if specified is stored along with the connection
            copy (bool): optional, if True the connection is queried using sql
                and the resulting data is stored instead, see below
            attributes (dict): optional, set or update metadata.attributes

        Returns:
            metadata of the stored connection

        Instead of storing the connection, specify ``copy=True`` to query data
        and store it as a DataFrame dataset given by ``name``:

        Args:
            sql (str): sql to query
            append (bool): if True the data is appended if it exists already
            chunksize (int): number of records to query in each chunk
            transform (callable): passed as DataFrame.to_sql(method=)

        Returns:
            metadata of the inserted dataframe

        To insert data via a previously stored connection, specify ``insert=True``:

        Args:
            insert (bool): specify True to insert via the connection
            table (str): the table name to use for inserting data
            append (bool): if False will replace any existing table, defaults to True
            index (bool): if False will not attempt to create an index in target, defaults to False
            chunksize (int): number of records to insert in each chunk

        Returns:
            metadata of the connection
        """
        meta = self.data_store.metadata(name)
        if not insert and self._is_valid_url(obj):
            # store a connection object
            url = obj
            cnx_name = name if not copy else '_cnx_{}'.format(name)
            table = self._default_table(table or name)
            metadata = self._put_as_connection(url, cnx_name, sql=sql, chunksize=chunksize,
                                               table=table, attributes=attributes, **kwargs)
            if copy:
                secrets = self._get_secrets(metadata, secrets)
                metadata = self._put_as_data(url, name, cnx_name,
                                             sql=sql, chunksize=chunksize,
                                             append=append, transform=transform,
                                             secrets=secrets,
                                             **kwargs)
        elif meta is not None:
            table = self._default_table(table or meta.kind_meta.get('table') or name)
            metadata = self._put_via(obj, name, append=append, table=table, chunksize=chunksize,
                                     transform=transform, **kwargs)
        else:
            raise ValueError('type {} is not supported by {}'.format(type(obj), self.KIND))
        if attributes:
            metadata.attributes.update(attributes)
        return metadata.save()
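
    # a usage sketch for put() covering the three modes documented above
    # (names are illustrative, not part of the original source):
    #
    #   om.datasets.put('sqlite:///mydb.db', 'mydb')            # store the connection
    #   om.datasets.put(df, 'mydb', table=':sales')             # insert df via the connection
    #   om.datasets.put('sqlite:///mydb.db', 'salescopy',
    #                   sql='select * from sales', copy=True)   # copy query result into omegaml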

    def _put_via(self, obj, name, append=True, table=None, chunksize=None, transform=None,
                 index_columns=None, index=True, **kwargs):
        # write data back through the connection
        # -- ensure we have a valid object
        if not hasattr(obj, 'to_sql'):
            warning('obj.to_sql() does not exist, trying pd.DataFrame(obj)')
            obj = pd.DataFrame(obj)
        # -- get the connection
        connection = self.get(name, raw=True)
        metadata = self.data_store.metadata(name)
        if isinstance(obj, pd.DataFrame) and index:
            index_cols = _dataframe_to_indexcols(obj, metadata, index_columns=index_columns)
        else:
            index_cols = index_columns
        metadata.kind_meta['index_columns'] = index_cols
        exists_action = 'append' if append else 'replace'
        self._chunked_to_sql(obj, table, connection, chunksize=chunksize, method=transform,
                             if_exists=exists_action, index=index, index_label=index_cols, **kwargs)
        connection.close()
        return metadata

    def _put_as_data(self, url, name, cnx_name, sql=None, chunksize=None, append=True,
                     transform=None, **kwargs):
        # use the url to query the connection and store resulting data instead
        if not sql:
            raise ValueError('a valid SQL statement is required with copy=True')
        metadata = self.copy_from_sql(sql, url, name, chunksize=chunksize,
                                      append=append, transform=transform,
                                      **kwargs)
        metadata.attributes['created_from'] = cnx_name
        return metadata

    def _put_as_connection(self, url, name, sql=None, chunksize=None, attributes=None,
                           table=None, index_columns=None, secrets=None, **kwargs):
        kind_meta = {
            'sqlalchemy_connection': str(url),
            'sql': sql,
            'chunksize': chunksize,
            'table': ':' + table,
            'index_columns': index_columns,
            'kwargs': kwargs,
        }
        if secrets is True or (secrets is None and '{' in url and '}' in url):
            kind_meta['secrets'] = {
                'dsname': '.omega/vault',
                'query': {
                    'data_userid': '{user}'
                }
            }
        else:
            kind_meta['secrets'] = secrets
        metadata = self.data_store.metadata(name)
        if metadata is not None:
            metadata.kind_meta.update(kind_meta)
        else:
            metadata = self.data_store.make_metadata(name, self.KIND,
                                                     kind_meta=kind_meta,
                                                     attributes=attributes)
        return metadata.save()

    def _get_connection(self, name, connection_str, secrets=None, keep=False):
        import sqlalchemy as sqa
        # passwords should be encoded
        # https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
        encoded = lambda d: {
            k: (quote_plus(v.decode('utf-8')) if isinstance(v, bytes)
                else quote_plus(v)) for k, v in d.items() if isinstance(v, (str, bytes))
        }
        connection = None
        cache_key = None
        try:
            # SECDEV: the cache key is a secret in order to avoid privilege escalation
            # -- if it were not secret, user A could create the connection (=> cache)
            #    and user B could reuse it by retrieving the dataset without secrets
            # -- this way a user needs the same secrets in order to reuse the connection
            enc_secrets = encoded(secrets or {})
            connection_str = connection_str.format(**enc_secrets)
            cache_key = sha256(f'{name}:{connection_str}'.encode('utf8')).hexdigest()
            engine = self.__CNX_CACHE.get(cache_key) or sqa.create_engine(connection_str, **ENGINE_KWARGS)
            connection = engine.connect()
        except KeyError as e:
            msg = (f'{e}, ensure secrets are specified for connection '
                   f'>{connection_str}<')
            raise KeyError(msg)
        except Exception:
            if connection is not None:
                connection.close()
            self.__CNX_CACHE.pop(cache_key, None)
            raise
        else:
            if keep:
                self.__CNX_CACHE[cache_key] = engine
            else:
                self.__CNX_CACHE.pop(cache_key, None)
        return connection

    def copy_from_sql(self, sql, connstr, name, chunksize=10000,
                      append=False, transform=None, secrets=None, **kwargs):
        connection = self._get_connection(name, connstr, secrets=secrets)
        chunksize = chunksize or 10000  # avoid None
        pditer = pd.read_sql(sql, connection, chunksize=chunksize, **kwargs)
        with tqdm_if_interactive().tqdm(unit='rows') as pbar:
            meta = self._chunked_insert(pditer, name, append=append,
                                        transform=transform, pbar=pbar)
        connection.close()
        return meta

    def _chunked_to_sql(self, df, table, connection, if_exists='append', chunksize=None, pbar=True, **kwargs):
        # insert large df in chunks and with a progress bar
        # from https://stackoverflow.com/a/39495229
        chunksize = chunksize if chunksize is not None else 10000

        def chunker(seq, size):
            return (seq.iloc[pos:pos + size] for pos in range(0, len(seq), size))

        def to_sql(df, table, connection, pbar=None):
            for i, cdf in enumerate(chunker(df, chunksize)):
                exists_action = if_exists if i == 0 else "append"
                cdf.to_sql(table, con=connection, if_exists=exists_action, **kwargs)
                if pbar:
                    pbar.update(len(cdf))
                else:
                    print("writing chunk {}".format(i))

        with tqdm_if_interactive().tqdm(total=len(df), unit='rows') as pbar:
            to_sql(df, table, connection, pbar=pbar)

    def _chunked_insert(self, pditer, name, append=True, transform=None, pbar=None):
        # insert into om dataset
        for i, df in enumerate(pditer):
            if pbar is not None:
                pbar.update(len(df))
            should_append = (i > 0) or append
            if transform:
                df = transform(df)
            try:
                meta = self.data_store.put(df, name, append=should_append)
            except Exception as e:
                rows = df.iloc[0:10].to_dict()
                raise ValueError(f'{e}: {rows}')
        return meta

    def _is_valid_url(self, url):
        # enable subclass override
        return _is_valid_url(url)

    def _supports_via(self, data_store, name, obj):
        obj_ok = isinstance(obj, (pd.Series, pd.DataFrame, dict))
        if obj_ok and data_store:
            meta = data_store.metadata(name)
            same_kind = meta and meta.kind == self.KIND
            return obj_ok and same_kind
        return False

    def _get_secrets(self, meta, secrets):
        secrets_specs = meta.kind_meta.get('secrets')
        values = ({k: v for k, v in os.environ.items() if k.isupper() and isinstance(v, (str, bytes))}
                  if self.data_store.defaults.OMEGA_ALLOW_ENV_CONFIG else dict())
        values.update(**self.data_store.defaults)
        if not secrets and secrets_specs:
            dsname = secrets_specs['dsname']
            query = secrets_specs['query']
            # -- format query values
            query = _format_dict(query, replace=('_', '.'), **values, user=self._getuser())
            # -- run query
            secrets = self.data_store.get(dsname, filter=query)
            secrets = secrets[0] if isinstance(secrets, list) and len(secrets) == 1 else {}
            secrets.update(values)
        # -- format secrets
        if secrets:
            secrets = _format_dict(secrets, **values, user=self._getuser())
        return secrets

    def _getuser(self):
        return getattr(self.data_store.defaults, 'OMEGA_USERNAME', getuser())

    def _default_table(self, name):
        if name is None:
            return name
        if not name.startswith(':'):
            name = f'{self.data_store.bucket}_{name}'
        else:
            name = name[1:]
        return name

    def _sanitize_statement(self, sql, sqlvars, trusted=False):
        # sanitize a sql:string statement in three steps
        # -- step 1: replace all {} variables by :notation
        # -- step 2: replace all remaining {{}} variables from sqlvars
        #    and issue a warning. step 2 is considered unsafe if
        #    the sqlvars source cannot be trusted
        # -- step 3: prepare a SQL statement with bound variables
        # see https://realpython.com/prevent-python-sql-injection/#using-query-parameters-in-sql
        if not isinstance(sql, str):
            return sql
        # replace all {...} variables with bound parameters
        # sql = "select * from foo where user={username}"
        # => "select * from foo where user=:username"
        placeholders = list(string.Formatter().parse(sql))
        vars = [spec[1] for spec in placeholders if spec[1]]
        safe_replacements = {var: f':{var}' for var in vars}
        sql = sql.format(**safe_replacements)
        # build parameter lists for tuples and lists
        # -- sqlalchemy+pyodbc do not support lists of values
        # -- lists of values must be passed as single parameters
        # -- e.g. sql=select * from x in :list
        #    => select * from x in (:x_1, :x_2, :x_3, ...)
        for k in vars:
            # note we are iterating vars, not sqlvars
            # -- sqlvars is not used in constructing sql text
            v = sqlvars[k]
            if isinstance(v, (list, tuple)):
                bind_vars = {f'{k}_{i}': lv for i, lv in enumerate(v)}
                placeholders = ','.join(f':{bk}' for bk in bind_vars)
                sql = sql.replace(f':{k}', f'({placeholders})')
                sqlvars.update(bind_vars)
        try:
            # format remaining {{}} for selection
            # sql = "select {{cols}} from foo where user=:username"
            # => "select a, b from foo where user=:username"
            placeholders = list(string.Formatter().parse(sql))
            vars = [spec[1] for spec in placeholders if spec[1]]
            if vars and trusted != self.sign(sqlvars):
                warnings.warn(f'Statement >{sql}< contains unsafe variables {vars}. '
                              'Use :notation or sanitize input.')
            sql = sql.format(**{**sqlvars, **safe_replacements})
        except KeyError as e:
            raise KeyError(f'{e}, specify sqlvars= to build query >{sql}<')
        # prepare the sql statement with bound variables
        stmt = sqlalchemy.sql.text(sql)
        return stmt
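
    # an illustrative walkthrough of _sanitize_statement() (not part of the
    # original source):
    #
    #   sql     = 'select {{cols}} from T where user = {user} and id in {ids}'
    #   sqlvars = dict(cols='a, b', user='jdoe', ids=[1, 2])
    #   # step 1: {user} => :user, {ids} => :ids; escaped {{cols}} survives as {cols}
    #   # the list value then expands :ids => (:ids_0,:ids_1) and updates sqlvars
    #   # step 2: {cols} is string-formatted, warning unless trusted == self.sign(sqlvars)
    #   # step 3: => text('select a, b from T where user = :user and id in (:ids_0,:ids_1)')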


def _is_valid_url(url):
    # check if we have a valid url with a registered backend
    import sqlalchemy

    try:
        url = sqlalchemy.engine.url.make_url(url)
        drivername = url.drivername.split('+')[0]  # e.g. mssql+pyodbc => mssql
        valid = url.drivername in sqlalchemy.dialects.__all__
        valid |= sqlalchemy.dialects.registry.load(drivername) is not None
    except Exception:
        valid = False
    return valid


def _dataframe_to_indexcols(df, metadata, index_columns=None):
    # from a dataframe get index column names
    # works like pd.DataFrame.to_sql except for creating default index_i cols
    # for any missing (None) index labels in a MultiIndex
    index_cols = metadata.kind_meta.get('index_columns') or index_columns or list(df.index.names)
    multi = isinstance(df.index, pd.MultiIndex)
    if index_cols is not None:
        for i, col in enumerate(index_cols):
            if col is None:
                index_cols[i] = 'index' if not multi else 'index_{}'.format(i)
    return index_cols


def _meta_to_indexcols(meta):
    index_cols = meta.kind_meta.get('index_columns')
    multi = isinstance(index_cols, (list, tuple)) and len(index_cols) > 1
    if index_cols is not None and not isinstance(index_cols, str):
        for i, col in enumerate(index_cols):
            if col is None:
                index_cols[i] = 'index' if not multi else 'index_{}'.format(i)
    return index_cols


def _format_dict(d, replace=None, **kwargs):
    for k, v in dict(d).items():
        if replace:
            del d[k]
        k = k.replace(*replace) if replace else k
        d[k] = v.format_map(KeepMissing(kwargs)) if isinstance(v, str) else v
    return d


def load_sql(om=None, kind=SQLAlchemyBackend.KIND):
    """
    load ipython sql magic, loading all sqlalchemy connections

    Usage:
        !pip install ipython-sql

        # prepare some connection, insert some data
        df = pd.DataFrame(...)
        om.datasets.put('sqlite:///test.db', 'testdb')
        om.datasets.put(df, 'testdb', table='foobar', insert=True)

        from omegaml.backends.sqlalchemy import load_sql
        load_sql()

        # list registered connections
        %sql
        omsql://testdb

        # run queries
        %sql omsql://testdb select * from foobar

    See Also:
        https://pypi.org/project/ipython-sql/

    Args:
        om (Omega): optional, specify omega instance, defaults to om.setup()
        kind (str): the backend's kind, used to find connections to register
            with the sql magic

    Returns:
        None
    """
    from unittest.mock import MagicMock

    from IPython import get_ipython
    from sql.connection import Connection  # noqa

    import omegaml

    class ConnectionShim:
        # this is required to trick sql magic into accepting existing connection objects
        # (by default sql magic expects connection strings, not connection objects)
        def __init__(self, url, conn):
            self.session = conn
            self.metadata = MagicMock()
            self.metadata.bind.url = url
            self.dialect = getattr(conn, 'dialect', 'omsql')

    # load sql magic
    ipython = get_ipython()
    ipython.magic('load_ext sql')
    # load registered sqlalchemy datasets
    # -- note: import as 'omegaml' (not 'om') so the om= argument is not shadowed
    om = om or omegaml.setup()
    for ds in om.datasets.list(kind=kind, raw=True):
        cnxstr = 'omsql://{ds.name}'.format(**locals())
        conn = om.datasets.get(ds.name, raw=True)
        Connection.connections[cnxstr] = ConnectionShim(cnxstr, conn)
    ipython.magic('sql')
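
# a minimal end-to-end sketch of this backend, assuming a configured omegaml
# instance (dataset and table names are illustrative, not part of the
# original source):
#
#   import omegaml as om
#   om.datasets.put('sqlite:///test.db', 'testdb', table=':foo')
#   om.datasets.put(pd.DataFrame({'x': range(3)}), 'testdb')
#   df = om.datasets.get('testdb', sql='select * from :sqltable')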