Module pyracmon.dialect.mysql

A dialect module for MySQL.

Expand source code
"""
A dialect module for MySQL.
"""
from itertools import groupby
from decimal import Decimal
from enum import Enum
from datetime import date, datetime, time, timedelta
from typing import Optional
from pyracmon.connection import Connection
from pyracmon.model import Table, Column, Relations, ForeignKey
from pyracmon.dialect.shared import MultiInsertMixin, TruncateMixin
from pyracmon.query import Q, where
from pyracmon.clause import holders


def read_schema(db, excludes: Optional[list[str]] = None, includes: Optional[list[str]] = None) -> list[Table]:
    """
    Collect tables in current database.

    Columns, foreign keys and table comments are all read from `information_schema`
    and restricted to the schema returned by `DATABASE()`.

    Args:
        db: Connection object; it must provide `stmt()` and `context.config.type_mapping`.
        excludes: Names of tables to exclude.
        includes: Names of tables to include. If not specified, all tables are collected.
    Returns:
        Table schemas.
    """
    def fetch_all(sql, *args):
        # Always release the cursor, even when fetching the rows fails.
        cursor = db.stmt().execute(sql, *args)
        try:
            return cursor.fetchall()
        finally:
            cursor.close()

    def map_types(t):
        # A configured type mapping takes precedence; the built-in mapping is used
        # when none is configured or when it returns a falsy value.
        base = db.context.config.type_mapping
        ptype = base and base(t)
        return ptype or _map_types(t)

    def column_of(n, t, nullable, ct, key, extra, comment):
        return Column(n, map_types(t), ct, key == "PRI", None, True if extra == "auto_increment" else None, nullable == "YES", comment or "")

    q = Q(excludes = excludes, includes = includes)

    cond = Q.of("c.table_schema = DATABASE()") & q.excludes.not_in("c.table_name") & q.includes.in_("c.table_name")

    w, params = where(cond)

    column_rows = fetch_all(f"""\
        SELECT
            c.table_name, c.column_name, c.data_type, c.is_nullable, c.column_type, c.column_key, c.extra, c.column_comment
        FROM
            information_schema.columns AS c
        {w}
        ORDER BY
            c.table_name, c.ordinal_position ASC
        """, *params)

    # Rows are ordered by table name, so consecutive rows sharing a name form one table.
    tables = [
        Table(t, [column_of(*c[1:]) for c in cols])
        for t, cols in groupby(column_rows, lambda row: row[0])
    ]

    if not tables:
        return []

    table_map = {t.name: t for t in tables}

    fk_rows = fetch_all("""\
        SELECT
            table_name, column_name, referenced_table_name, referenced_column_name
        FROM
            information_schema.key_column_usage
        WHERE
            table_schema = DATABASE() AND referenced_table_name IS NOT NULL
        """)

    for row in fk_rows:
        table_from = table_map.get(row[0], None)
        col_from = table_from.find(row[1]) if table_from else None

        if col_from:
            # The referenced table may not be among the collected tables; in that case
            # the raw table/column names are stored instead of schema objects.
            table_to = table_map.get(row[2], None)
            col_to = table_to.find(row[3]) if table_to else None
            col_from.fk = col_from.fk or Relations()
            col_from.fk.add(ForeignKey(table_to or row[2], col_to or row[3]))

    # Restrict to the current schema; otherwise a table with the same name in
    # another schema could overwrite the comment.
    comment_rows = fetch_all(f"""\
        SELECT
            table_name, table_comment
        FROM
            information_schema.tables
        WHERE
            table_schema = DATABASE() AND table_name IN ({holders(len(tables))})
        """, *[t.name for t in tables])

    for n, cmt in comment_rows:
        if n in table_map:
            table_map[n].comment = cmt or ""

    return tables


def _map_types(t):
    """
    Map a MySQL data type name to the Python type used for the column.

    Args:
        t: Type name; it is compared case-sensitively against lower-case MySQL type names.
    Returns:
        Python type for the column, or `object` when the name is not recognized.
    """
    if t in ("tinyint", "smallint", "mediumint", "int", "bigint", "bit"):
        return int
    if t == "decimal":
        return Decimal
    if t in ("float", "double"):
        return float
    # NOTE(review): binary/varbinary stay mapped to `str` for backward compatibility;
    # confirm whether `bytes` would be the more accurate mapping.
    if t in ("char", "varchar", "binary", "varbinary", "tinytext", "text", "mediumtext", "longtext"):
        return str
    if t in ("tinyblob", "blob", "mediumblob", "longblob"):
        return bytes
    if t == "enum":
        return Enum
    if t == "date":
        return date
    if t in ("datetime", "timestamp"):
        return datetime
    return object


class MySQLMixin(MultiInsertMixin, TruncateMixin):
    """
    Model mixin whose methods are available in MySQL.
    """
    @classmethod
    def last_sequences(cls, db: Connection, num: int) -> list[tuple[Column, int]]:
        """
        Compute the latest auto-increment value from `LAST_INSERT_ID()`.

        The returned value is `LAST_INSERT_ID() + num - 1`.
        NOTE(review): this presumes `num` is the number of rows inserted by the latest
        statement and that their IDs are consecutive - confirm against callers.

        Args:
            db: DB connection.
            num: Count added to the reported ID (minus one) to obtain the latest value.
        Returns:
            A list holding one pair of the auto-increment column and the computed value,
            or an empty list when the model has no auto-increment column.
        Raises:
            ValueError: More than one column is marked as incremental.
        """
        cols = [c for c in cls.columns if c.incremental]

        if len(cols) > 1:
            raise ValueError("MySQL allows tables having only an auto-increment column.")
        if not cols:
            return []

        cursor = db.cursor()
        try:
            cursor.execute("SELECT LAST_INSERT_ID()")
            sequence = cursor.fetchone()[0] + num - 1 # type: ignore
        finally:
            # Release the cursor even when the query or the fetch fails.
            cursor.close()
        return [(cols[0], sequence)]

    @classmethod
    def truncate(cls, db: Connection):
        """
        Delete every row of the table and reset its auto-increment counter to 1.

        Args:
            db: DB connection.
        """
        statements = (
            f"DELETE FROM {cls.name}",
            f"ALTER TABLE {cls.name} auto_increment = 1",
        )
        for sql in statements:
            cursor = db.cursor()
            try:
                cursor.execute(sql)
            finally:
                # Cursors were previously left open; close each one after use.
                cursor.close()


# Mixin classes provided by this dialect module.
# NOTE(review): presumably picked up by the code that loads a dialect and applied to models - confirm.
mixins = [MySQLMixin]


def found_rows(db):
    """
    Run `SELECT FOUND_ROWS()` and return the first column of the returned row.

    Args:
        db: Connection whose `cursor()` returns a context-manager cursor.
    Returns:
        Value reported by `FOUND_ROWS()`.
    """
    with db.cursor() as cursor:
        cursor.execute("SELECT FOUND_ROWS()")
        first_row = cursor.fetchone()
        return first_row[0]

Functions

def found_rows(db)
Expand source code
def found_rows(db):
    """
    Run `SELECT FOUND_ROWS()` and return the first column of the returned row.

    Args:
        db: Connection whose `cursor()` returns a context-manager cursor.
    Returns:
        Value reported by `FOUND_ROWS()`.
    """
    with db.cursor() as cursor:
        cursor.execute("SELECT FOUND_ROWS()")
        first_row = cursor.fetchone()
        return first_row[0]
def read_schema(db, excludes: Optional[list[str]] = None, includes: Optional[list[str]] = None) ‑> list[Table]

Collect tables in current database.

Args

excludes
Names of tables to exclude.
includes
Names of tables to include. If not specified, all tables are collected.

Returns

Table schemas.

Expand source code
def read_schema(db, excludes: Optional[list[str]] = None, includes: Optional[list[str]] = None) -> list[Table]:
    """
    Collect tables in current database.

    Columns, foreign keys and table comments are all read from `information_schema`
    and restricted to the schema returned by `DATABASE()`.

    Args:
        db: Connection object; it must provide `stmt()` and `context.config.type_mapping`.
        excludes: Names of tables to exclude.
        includes: Names of tables to include. If not specified, all tables are collected.
    Returns:
        Table schemas.
    """
    def fetch_all(sql, *args):
        # Always release the cursor, even when fetching the rows fails.
        cursor = db.stmt().execute(sql, *args)
        try:
            return cursor.fetchall()
        finally:
            cursor.close()

    def map_types(t):
        # A configured type mapping takes precedence; the built-in mapping is used
        # when none is configured or when it returns a falsy value.
        base = db.context.config.type_mapping
        ptype = base and base(t)
        return ptype or _map_types(t)

    def column_of(n, t, nullable, ct, key, extra, comment):
        return Column(n, map_types(t), ct, key == "PRI", None, True if extra == "auto_increment" else None, nullable == "YES", comment or "")

    q = Q(excludes = excludes, includes = includes)

    cond = Q.of("c.table_schema = DATABASE()") & q.excludes.not_in("c.table_name") & q.includes.in_("c.table_name")

    w, params = where(cond)

    column_rows = fetch_all(f"""\
        SELECT
            c.table_name, c.column_name, c.data_type, c.is_nullable, c.column_type, c.column_key, c.extra, c.column_comment
        FROM
            information_schema.columns AS c
        {w}
        ORDER BY
            c.table_name, c.ordinal_position ASC
        """, *params)

    # Rows are ordered by table name, so consecutive rows sharing a name form one table.
    tables = [
        Table(t, [column_of(*c[1:]) for c in cols])
        for t, cols in groupby(column_rows, lambda row: row[0])
    ]

    if not tables:
        return []

    table_map = {t.name: t for t in tables}

    fk_rows = fetch_all("""\
        SELECT
            table_name, column_name, referenced_table_name, referenced_column_name
        FROM
            information_schema.key_column_usage
        WHERE
            table_schema = DATABASE() AND referenced_table_name IS NOT NULL
        """)

    for row in fk_rows:
        table_from = table_map.get(row[0], None)
        col_from = table_from.find(row[1]) if table_from else None

        if col_from:
            # The referenced table may not be among the collected tables; in that case
            # the raw table/column names are stored instead of schema objects.
            table_to = table_map.get(row[2], None)
            col_to = table_to.find(row[3]) if table_to else None
            col_from.fk = col_from.fk or Relations()
            col_from.fk.add(ForeignKey(table_to or row[2], col_to or row[3]))

    # Restrict to the current schema; otherwise a table with the same name in
    # another schema could overwrite the comment.
    comment_rows = fetch_all(f"""\
        SELECT
            table_name, table_comment
        FROM
            information_schema.tables
        WHERE
            table_schema = DATABASE() AND table_name IN ({holders(len(tables))})
        """, *[t.name for t in tables])

    for n, cmt in comment_rows:
        if n in table_map:
            table_map[n].comment = cmt or ""

    return tables

Classes

class MySQLMixin

Model mixin whose methods are available in MySQL.

Expand source code
class MySQLMixin(MultiInsertMixin, TruncateMixin):
    """
    Model mixin whose methods are available in MySQL.
    """
    @classmethod
    def last_sequences(cls, db: Connection, num: int) -> list[tuple[Column, int]]:
        """
        Compute the latest auto-increment value from `LAST_INSERT_ID()`.

        The returned value is `LAST_INSERT_ID() + num - 1`.
        NOTE(review): this presumes `num` is the number of rows inserted by the latest
        statement and that their IDs are consecutive - confirm against callers.

        Args:
            db: DB connection.
            num: Count added to the reported ID (minus one) to obtain the latest value.
        Returns:
            A list holding one pair of the auto-increment column and the computed value,
            or an empty list when the model has no auto-increment column.
        Raises:
            ValueError: More than one column is marked as incremental.
        """
        cols = [c for c in cls.columns if c.incremental]

        if len(cols) > 1:
            raise ValueError("MySQL allows tables having only an auto-increment column.")
        if not cols:
            return []

        cursor = db.cursor()
        try:
            cursor.execute("SELECT LAST_INSERT_ID()")
            sequence = cursor.fetchone()[0] + num - 1 # type: ignore
        finally:
            # Release the cursor even when the query or the fetch fails.
            cursor.close()
        return [(cols[0], sequence)]

    @classmethod
    def truncate(cls, db: Connection):
        """
        Delete every row of the table and reset its auto-increment counter to 1.

        Args:
            db: DB connection.
        """
        statements = (
            f"DELETE FROM {cls.name}",
            f"ALTER TABLE {cls.name} auto_increment = 1",
        )
        for sql in statements:
            cursor = db.cursor()
            try:
                cursor.execute(sql)
            finally:
                # Cursors were previously left open; close each one after use.
                cursor.close()

Ancestors

Static methods

def last_sequences(db: Connection, num: int) ‑> list[tuple[Column, int]]
Expand source code
@classmethod
def last_sequences(cls, db: Connection, num: int) -> list[tuple[Column, int]]:
    """
    Compute the latest auto-increment value from `LAST_INSERT_ID()`.

    The returned value is `LAST_INSERT_ID() + num - 1`.
    NOTE(review): this presumes `num` is the number of rows inserted by the latest
    statement and that their IDs are consecutive - confirm against callers.

    Args:
        db: DB connection.
        num: Count added to the reported ID (minus one) to obtain the latest value.
    Returns:
        A list holding one pair of the auto-increment column and the computed value,
        or an empty list when the model has no auto-increment column.
    Raises:
        ValueError: More than one column is marked as incremental.
    """
    cols = [c for c in cls.columns if c.incremental]

    if len(cols) > 1:
        raise ValueError("MySQL allows tables having only an auto-increment column.")
    if not cols:
        return []

    cursor = db.cursor()
    try:
        cursor.execute("SELECT LAST_INSERT_ID()")
        sequence = cursor.fetchone()[0] + num - 1 # type: ignore
    finally:
        # Release the cursor even when the query or the fetch fails.
        cursor.close()
    return [(cols[0], sequence)]
def truncate(db: Connection)
Expand source code
@classmethod
def truncate(cls, db: Connection):
    """
    Delete every row of the table and reset its auto-increment counter to 1.

    Args:
        db: DB connection.
    """
    statements = (
        f"DELETE FROM {cls.name}",
        f"ALTER TABLE {cls.name} auto_increment = 1",
    )
    for sql in statements:
        cursor = db.cursor()
        try:
            cursor.execute(sql)
        finally:
            # Cursors were previously left open; close each one after use.
            cursor.close()

Inherited members