from interface_meta import override
from omniduct.databases.base import DatabaseClient
from omniduct.databases.hiveserver2 import HiveServer2Client


class PySparkClient(DatabaseClient):
"""
This Duct connects to a local PySpark session using the `pyspark` library.
"""
PROTOCOLS = ['pyspark']
DEFAULT_PORT = None
SUPPORTS_SESSION_PROPERTIES = True
NAMESPACE_NAMES = ['schema', 'table']
NAMESPACE_QUOTECHAR = '`'
NAMESPACE_SEPARATOR = '.'

    @override
    def _init(self, app_name='omniduct', config=None, master=None, enable_hive_support=False):
        """
        Args:
            app_name (str): The application name of the SparkSession.
            config (dict or None): Any additional configuration to pass through
                to the SparkSession builder.
            master (str): The Spark master URL to connect to (only necessary
                if the environment does not already specify one).
            enable_hive_support (bool): Whether to enable Hive support for the
                Spark session.

        Note: `pyspark` must be installed in order to use this backend.
        """
        self.app_name = app_name
        self.config = config or {}
        self.master = master
        self.enable_hive_support = enable_hive_support
        self._spark_session = None

    # Connection management

    @override
    def _connect(self):
        from pyspark.sql import SparkSession

        builder = SparkSession.builder.appName(self.app_name)
        if self.master:
            builder = builder.master(self.master)
        if self.enable_hive_support:
            builder = builder.enableHiveSupport()
        for key, value in self.config.items():
            builder = builder.config(key, value)
        self._spark_session = builder.getOrCreate()

    @override
    def _is_connected(self):
        return self._spark_session is not None

    @override
    def _disconnect(self):
        self._spark_session.sparkContext.stop()
        # Drop the reference so that `_is_connected` reflects the stopped session.
        self._spark_session = None

    # Database operations

    @override
    def _statement_prepare(self, statement, session_properties, **kwargs):
        # Prefix the statement with one `SET key = value;` line per session
        # property; note the trailing newline on each line, so that the
        # statement is not glued onto the last property assignment.
        return (
            "".join(
                "SET {key} = {value};\n".format(key=key, value=value)
                for key, value in session_properties.items()
            )
            + statement
        )

    @override
    def _execute(self, statement, cursor, wait, session_properties):
        assert wait is True, "This Spark backend does not support asynchronous operations."
        # Wrap the lazily-evaluated DataFrame in a DBAPI2-like cursor.
        return SparkCursor(self._spark_session.sql(statement))
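
    # The table-level operations below behave the same way under Spark SQL as
    # they do on a Hive server, so they are delegated to the (unbound)
    # implementations on `HiveServer2Client`.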
    @override
    def _query_to_table(self, statement, table, if_exists, **kwargs):
        return HiveServer2Client._query_to_table(self, statement, table, if_exists, **kwargs)

    @override
    def _table_list(self, namespace, **kwargs):
        return HiveServer2Client._table_list(self, namespace, **kwargs)

    @override
    def _table_exists(self, table, **kwargs):
        return HiveServer2Client._table_exists(self, table, **kwargs)

    @override
    def _table_drop(self, table, **kwargs):
        return HiveServer2Client._table_drop(self, table, **kwargs)

    @override
    def _table_desc(self, table, **kwargs):
        return HiveServer2Client._table_desc(self, table, **kwargs)

    @override
    def _table_head(self, table, n=10, **kwargs):
        return HiveServer2Client._table_head(self, table, n=n, **kwargs)

    @override
    def _table_props(self, table, **kwargs):
        return HiveServer2Client._table_props(self, table, **kwargs)


class SparkCursor(object):
    """
    A DBAPI2-compatible cursor that wraps a Spark DataFrame.
    """

    arraysize = 1

    def __init__(self, df):
        self.df = df
        self._df_iter = None

    @property
    def df_iter(self):
        # Lazily create a local iterator over the DataFrame's rows.
        if self._df_iter is None:
            self._df_iter = self.df.toLocalIterator()
        return self._df_iter

    @property
    def description(self):
        # DBAPI2 seven-item column descriptions; only the name and type code
        # are populated here.
        return tuple(
            (name, type_, None, None, None, None, None)
            for name, type_ in self.df.dtypes
        )

    @property
    def rowcount(self):
        # -1 signals that the number of rows is unknown, as DBAPI2 permits.
        return -1

    def close(self):
        pass

    def execute(self, operation, parameters=None):
        # Statements are executed by the client, not by the cursor.
        raise NotImplementedError

    def executemany(self, operation, seq_of_parameters=None):
        raise NotImplementedError

    def fetchone(self):
        try:
            # Convert the Spark `Row` object into a plain list of values.
            return list(next(self.df_iter))
        except StopIteration:
            # DBAPI2 requires None once the result set is exhausted.
            return None

    def fetchmany(self, size=None):
        size = size or self.arraysize
        rows = []
        for _ in range(size):
            row = self.fetchone()
            if row is None:
                break
            rows.append(row)
        return rows

    def fetchall(self):
        # Note that this collects the *entire* DataFrame, independently of any
        # rows already consumed via `fetchone`/`fetchmany`.
        return self.df.collect()

    def setinputsizes(self, sizes):
        pass

    def setoutputsize(self, size, column=None):
        pass