Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(mysql): port to MySQLdb instead of pymysql #10077

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/renovate.json
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
"addLabels": ["druid"]
},
{
"matchPackagePatterns": ["pymysql", "mariadb"],
"matchPackagePatterns": ["mysqlclient", "mariadb"],
"addLabels": ["mysql"]
},
{
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/ibis-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ jobs:
- polars
sys-deps:
- libgeos-dev
- default-libmysqlclient-dev
- name: postgres
title: PostgreSQL
extras:
Expand Down Expand Up @@ -271,6 +272,7 @@ jobs:
- mysql
sys-deps:
- libgeos-dev
- default-libmysqlclient-dev
- os: windows-latest
backend:
name: clickhouse
Expand Down
2 changes: 1 addition & 1 deletion conda/environment-arm64-flink.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies:
- pyarrow-hotfix >=0.4
- pydata-google-auth
- pydruid >=0.6.5
- pymysql >=1
- mysqlclient >=2.2.4
- pyspark >=3
- python-dateutil >=2.8.2
- python-duckdb >=0.8.1
Expand Down
2 changes: 1 addition & 1 deletion conda/environment-arm64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies:
- pyarrow-hotfix >=0.4
- pydata-google-auth
- pydruid >=0.6.5
- pymysql >=1
- mysqlclient >=2.2.4
- pyodbc >=4.0.39
- pyspark >=3
- python-dateutil >=2.8.2
Expand Down
2 changes: 1 addition & 1 deletion conda/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies:
- pyarrow-hotfix >=0.4
- pydata-google-auth
- pydruid >=0.6.5
- pymysql >=1
- mysqlclient >=2.2.4
- pyodbc >=4.0.39
- pyspark >=3
- python >=3.10
Expand Down
158 changes: 76 additions & 82 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
from __future__ import annotations

import contextlib
import re
import warnings
from functools import cached_property
from operator import itemgetter
from typing import TYPE_CHECKING, Any
from urllib.parse import unquote_plus

import pymysql
import MySQLdb
import sqlglot as sg
import sqlglot.expressions as sge
from pymysql.constants import ER
from pymysql.err import ProgrammingError
from MySQLdb import ProgrammingError
from MySQLdb.constants import ER

import ibis
import ibis.backends.sql.compilers as sc
Expand All @@ -24,7 +23,6 @@
import ibis.expr.types as ir
from ibis import util
from ibis.backends import CanCreateDatabase
from ibis.backends.mysql.datatypes import _type_from_cursor_info
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers.base import STAR, TRUE, C

Expand Down Expand Up @@ -89,16 +87,14 @@

@cached_property
def version(self):
matched = re.search(r"(\d+)\.(\d+)\.(\d+)", self.con.server_version)
return ".".join(matched.groups())
return ".".join(map(str, self.con._server_version))

def do_connect(
self,
host: str = "localhost",
user: str | None = None,
password: str | None = None,
port: int = 3306,
database: str | None = None,
autocommit: bool = True,
**kwargs,
) -> None:
Expand All @@ -114,12 +110,10 @@
Password
port
Port
database
Database to connect to
autocommit
Autocommit mode
kwargs
Additional keyword arguments passed to `pymysql.connect`
Additional keyword arguments passed to `MySQLdb.connect`

Examples
--------
Expand Down Expand Up @@ -149,22 +143,20 @@
year int32
month int32
"""
self.con = pymysql.connect(
self.con = MySQLdb.connect(
user=user,
host=host,
host="127.0.0.1" if host == "localhost" else host,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

localhost will try to use a socket, which is different from the behavior of pymysql, so I made localhost usage backwards compatible. There's a unix_socket argument if anyone really needs it.

port=port,
password=password,
database=database,
autocommit=autocommit,
conv=pymysql.converters.conversions,
**kwargs,
)

self._post_connect()

@util.experimental
@classmethod
def from_connection(cls, con: pymysql.Connection) -> Backend:
def from_connection(cls, con) -> Backend:
"""Create an Ibis client from an existing connection to a MySQL database.

Parameters
Expand All @@ -179,7 +171,7 @@
return new_backend

def _post_connect(self) -> None:
with contextlib.closing(self.con.cursor()) as cur:
with self.con.cursor() as cur:
try:
cur.execute("SET @@session.time_zone = 'UTC'")
except Exception as e: # noqa: BLE001
Expand All @@ -198,23 +190,34 @@
return self._filter_with_like(databases, like)

def _get_schema_using_query(self, query: str) -> sch.Schema:
with self.begin() as cur:
cur.execute(
sg.select(STAR)
.from_(
sg.parse_one(query, dialect=self.dialect).subquery(
sg.to_identifier("tmp", quoted=self.compiler.quoted)
)
from ibis.backends.mysql.datatypes import _type_from_cursor_info

sql = (
sg.select(STAR)
.from_(
sg.parse_one(query, dialect=self.dialect).subquery(
sg.to_identifier("tmp", quoted=self.compiler.quoted)
)
.limit(0)
.sql(self.dialect)
)
return sch.Schema(
{
field.name: _type_from_cursor_info(descr, field)
for descr, field in zip(cur.description, cur._result.fields)
}
.limit(0)
.sql(self.dialect)
)
with self.begin() as cur:
cur.execute(sql)
descr, flags = cur.description, cur.description_flags

items = {}
for (name, type_code, _, _, field_length, scale, _), raw_flags in zip(
descr, flags
):
item = _type_from_cursor_info(
flags=raw_flags,
type_code=type_code,
field_length=field_length,
scale=scale,
)
items[name] = item
return sch.Schema(items)

def get_schema(
self, name: str, *, catalog: str | None = None, database: str | None = None
Expand Down Expand Up @@ -258,38 +261,52 @@
def begin(self):
con = self.con
cur = con.cursor()
autocommit = con.get_autocommit()

if not autocommit:
con.begin()

Check warning on line 267 in ibis/backends/mysql/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/mysql/__init__.py#L267

Added line #L267 was not covered by tests

try:
yield cur
except Exception:
con.rollback()
if not autocommit:
con.rollback()

Check warning on line 273 in ibis/backends/mysql/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/mysql/__init__.py#L273

Added line #L273 was not covered by tests
raise
else:
con.commit()
if not autocommit:
con.commit()

Check warning on line 277 in ibis/backends/mysql/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/mysql/__init__.py#L277

Added line #L277 was not covered by tests
finally:
cur.close()

# TODO(kszucs): should make it an abstract method or remove the use of it
# from .execute()
@contextlib.contextmanager
def _safe_raw_sql(self, *args, **kwargs):
with contextlib.closing(self.raw_sql(*args, **kwargs)) as result:
with self.raw_sql(*args, **kwargs) as result:
yield result

def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
with contextlib.suppress(AttributeError):
query = query.sql(dialect=self.name)

con = self.con
autocommit = con.get_autocommit()

cursor = con.cursor()

if not autocommit:
con.begin()

Check warning on line 298 in ibis/backends/mysql/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/mysql/__init__.py#L298

Added line #L298 was not covered by tests

try:
cursor.execute(query, **kwargs)
except Exception:
con.rollback()
if not autocommit:
con.rollback()

Check warning on line 304 in ibis/backends/mysql/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/mysql/__init__.py#L304

Added line #L304 was not covered by tests
cursor.close()
raise
else:
con.commit()
if not autocommit:
con.commit()

Check warning on line 309 in ibis/backends/mysql/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/mysql/__init__.py#L309

Added line #L309 was not covered by tests
return cursor

# TODO: disable positional arguments
Expand Down Expand Up @@ -406,11 +423,9 @@
if temp:
properties.append(sge.TemporaryProperty())

temp_memtable_view = None
if obj is not None:
if not isinstance(obj, ir.Expr):
table = ibis.memtable(obj)
temp_memtable_view = table.op().name
else:
table = obj

Expand All @@ -428,39 +443,33 @@
if not schema:
schema = table.schema()

table_expr = sg.table(temp_name, catalog=database, quoted=self.compiler.quoted)
target = sge.Schema(
this=table_expr, expressions=schema.to_sqlglot(self.dialect)
)
quoted = self.compiler.quoted
dialect = self.dialect

table_expr = sg.table(temp_name, catalog=database, quoted=quoted)
target = sge.Schema(this=table_expr, expressions=schema.to_sqlglot(dialect))

create_stmt = sge.Create(
kind="TABLE",
this=target,
properties=sge.Properties(expressions=properties),
kind="TABLE", this=target, properties=sge.Properties(expressions=properties)
)

this = sg.table(name, catalog=database, quoted=self.compiler.quoted)
this = sg.table(name, catalog=database, quoted=quoted)
with self._safe_raw_sql(create_stmt) as cur:
if query is not None:
insert_stmt = sge.Insert(this=table_expr, expression=query).sql(
self.name
)
cur.execute(insert_stmt)
cur.execute(sge.Insert(this=table_expr, expression=query).sql(dialect))

if overwrite:
cur.execute(sge.Drop(kind="TABLE", this=this, exists=True).sql(dialect))
cur.execute(
sge.Drop(kind="TABLE", this=this, exists=True).sql(self.name)
)
cur.execute(
f"ALTER TABLE IF EXISTS {table_expr.sql(self.name)} RENAME TO {this.sql(self.name)}"
sge.Alter(
kind="TABLE",
this=table_expr,
exists=True,
actions=[sge.RenameTable(this=this)],
).sql(dialect)
)

if schema is None:
# Clean up temporary memtable if we've created one
# for in-memory reads
if temp_memtable_view is not None:
self.drop_table(temp_memtable_view)

return self.table(name, database=database)

# preserve the input schema if it was provided
Expand All @@ -477,7 +486,7 @@
with self.begin() as cur:
cur.execute(sql)
cur.fetchall()
except pymysql.err.ProgrammingError as e:
except MySQLdb.ProgrammingError as e:
err_code, _ = e.args
if err_code == ER.NO_SUCH_TABLE:
return False
Expand All @@ -495,16 +504,17 @@

name = op.name
quoted = self.compiler.quoted
dialect = self.dialect

create_stmt = sg.exp.Create(
kind="TABLE",
this=sg.exp.Schema(
this=sg.to_identifier(name, quoted=quoted),
expressions=schema.to_sqlglot(self.dialect),
expressions=schema.to_sqlglot(dialect),
),
properties=sg.exp.Properties(expressions=[sge.TemporaryProperty()]),
)
create_stmt_sql = create_stmt.sql(self.name)
create_stmt_sql = create_stmt.sql(dialect)

df = op.data.to_frame()
# nan can not be used with MySQL
Expand Down Expand Up @@ -549,23 +559,7 @@

from ibis.backends.mysql.converter import MySQLPandasData

try:
df = pd.DataFrame.from_records(
cursor, columns=schema.names, coerce_float=True
)
except Exception:
# clean up the cursor if we fail to create the DataFrame
#
# in the sqlite case failing to close the cursor results in
# artificially locked tables
cursor.close()
raise
df = MySQLPandasData.convert_table(df, schema)
return df

def _finalize_memtable(self, name: str) -> None:
"""No-op.

Executing **any** SQL in a finalizer causes the underlying connection
socket to be set to `None`. It is unclear why this happens.
"""
df = pd.DataFrame.from_records(
cursor.fetchall(), columns=schema.names, coerce_float=True
)
return MySQLPandasData.convert_table(df, schema)
Loading