Module pyracmon.dialect.postgresql
A dialect module for PostgreSQL.
Expand source code
"""
A dialect module for PostgreSQL.
"""
import re
from decimal import Decimal
from datetime import date, datetime, time, timedelta
from uuid import UUID
from itertools import groupby
from typing import Optional
from pyracmon.connection import Connection
from pyracmon.model import Table, Column, ForeignKey, Relations
from pyracmon.dialect.shared import MultiInsertMixin, TruncateMixin
from pyracmon.query import Q, where
from pyracmon.clause import holders
# Matches a column default of the form nextval('<name>') with an optional
# ::regclass cast. Group 1 captures <name>, which may only contain ASCII letters,
# digits and underscores, so schema-qualified or quoted sequence names do not match.
SequencePattern = re.compile(r"nextval\(\'([a-zA-Z0-9_]+)\'(\:\:regclass)?\)")
def read_schema(db: Connection, excludes: Optional[list[str]] = None, includes: Optional[list[str]] = None) -> list[Table]:
    """
    Collect tables in current database.

    Base tables and views of the `public` schema are read from `information_schema`,
    materialized views are read from `pg_class`, and the comments of every collected
    table and column are fetched with `col_description()`.

    Args:
        excludes: Excluding table names.
        includes: Including table names. If not specified, all tables are collected.
    Returns:
        Table schemas.
    """
    def fetch_all(sql, *args):
        # Every cursor is closed as soon as its rows are read, even if fetching fails.
        cursor = db.stmt().execute(sql, *args)
        try:
            return cursor.fetchall()
        finally:
            cursor.close()

    def fetch_comment(sql, *args):
        # col_description() yields NULL when no comment exists; normalize it to "".
        cursor = db.stmt().execute(sql, *args)
        try:
            return cursor.fetchone()[0] or "" # type: ignore
        finally:
            cursor.close()

    def map_types(t, udt):
        # A mapping function configured on the context takes precedence over the default one.
        base = db.context.config.type_mapping
        ptype = base and base(t, udt_name=udt)
        return ptype or _map_types(t)

    def column_of(n, t, udt, nullable, et, eudt, constraint, default, pos):
        m = SequencePattern.match(default or "")
        cs = (constraint or "").split(',')
        seq = m.group(1) if m else None
        null = nullable == 'YES'
        # Array columns are typed by their element type.
        ptype = map_types(t, udt) if t != 'ARRAY' else list[map_types(et, eudt)]
        info = (t, udt) if t != 'ARRAY' else (et, eudt)
        return Column(n, ptype, info, 'PRIMARY KEY' in cs, Relations() if 'FOREIGN KEY' in cs else None, seq, null)

    def mv_column_of(n, not_null, udt, eudt, pos):
        # pg_type exposes internal type names only; they are translated before mapping.
        ptype = map_types(_map_alternates(udt), udt) if eudt is None else list[map_types(_map_alternates(eudt), eudt)]
        info = (_map_alternates(udt), udt) if eudt is None else (_map_alternates(eudt), eudt)
        return Column(n, ptype, info, False, None, None, not not_null)

    q = Q(excludes = excludes, includes = includes)

    cond = Q.of("c.table_catalog = current_catalog") & Q.eq("c", table_schema="public") & Q.in_("t", table_type=["BASE TABLE", "VIEW"]) \
        & q.excludes.not_in("c.table_name") & q.includes.in_("c.table_name")
    w, params = where(cond)

    rows = fetch_all(f"""\
        SELECT
            c.table_name, c.column_name, c.data_type, c.udt_name, c.is_nullable,
            e.data_type, e.udt_name, k.constraint_type, c.column_default, c.ordinal_position
        FROM
            information_schema.columns AS c
            INNER JOIN information_schema.tables AS t ON c.table_name = t.table_name
            LEFT JOIN (
                SELECT
                    tc.table_name, k.column_name, string_agg(tc.constraint_type, ',') AS constraint_type
                FROM
                    information_schema.key_column_usage AS k
                    INNER JOIN information_schema.table_constraints AS tc ON k.constraint_name = tc.constraint_name
                WHERE
                    tc.constraint_type = 'PRIMARY KEY' OR tc.constraint_type = 'FOREIGN KEY'
                GROUP BY
                    tc.table_name, k.column_name
            ) AS k ON t.table_name = k.table_name AND c.column_name = k.column_name
            LEFT JOIN information_schema.element_types AS e
                ON ((c.table_catalog, c.table_schema, c.table_name, 'TABLE', c.dtd_identifier) =
                    (e.object_catalog, e.object_schema, e.object_name, e.object_type, e.collection_type_identifier))
        {w}
        ORDER BY c.table_name ASC, c.ordinal_position ASC
        """, *params)

    tables = []
    column_positions = {}

    # Rows arrive ordered by table name, so consecutive grouping yields one group per table.
    for t, cols in groupby(rows, lambda row: row[0]):
        cols = list(cols)
        columns = [column_of(*c[1:]) for c in cols]
        tables.append(Table(t, columns))
        column_positions[t] = {c[1]:c[-1] for c in cols}

    rows = fetch_all("""\
        SELECT
            k.table_name AS t1, k.column_name AS c1, k2.table_name AS t2, k2.column_name AS c2
        FROM
            information_schema.referential_constraints AS r
            INNER JOIN information_schema.key_column_usage AS k ON r.constraint_name = k.constraint_name
            INNER JOIN information_schema.key_column_usage AS k2
                ON r.unique_constraint_name = k2.constraint_name AND k.ordinal_position = k2.ordinal_position
        ORDER BY
            k.table_name ASC
        """)

    table_map = {t.name:t for t in tables}

    for row in rows:
        table_from = table_map.get(row[0], None)
        col_from = table_from.find(row[1]) if table_from else None

        if col_from:
            # A referenced table or column which was not collected is kept by its name.
            table_to = table_map.get(row[2], None)
            col_to = table_to.find(row[3]) if table_to else None
            col_from.fk.add(ForeignKey(table_to or row[2], col_to or row[3]))

    # Materialized views
    cond = Q.eq("c", relkind = "m") & Q.ge("a", attnum = 1) \
        & q.excludes.not_in("c.relname") & q.includes.in_("c.relname")
    w, params = where(cond)

    rows = fetch_all(f"""\
        SELECT
            c.relname, a.attname, a.attnotnull, t.typname, et.typname, a.attnum
        FROM
            pg_class AS c
            INNER JOIN pg_attribute AS a ON c.oid = a.attrelid
            INNER JOIN pg_type AS t ON a.atttypid = t.oid
            LEFT JOIN pg_type AS et ON t.typelem = et.oid
        {w}
        ORDER BY
            c.oid ASC, a.attnum ASC
        """, *params)

    for t, cols in groupby(rows, lambda row: row[0]):
        cols = list(cols)
        columns = [mv_column_of(*c[1:]) for c in cols]
        tables.append(Table(t, columns))
        column_positions[t] = {c[1]:c[-1] for c in cols}

    if not tables:
        return tables

    rows = fetch_all(f"""\
        SELECT
            relname, oid
        FROM
            pg_class
        WHERE
            relname IN ({holders(len(tables))})
        """, *[t.name for t in tables])

    table_oids = {n: oid for n, oid in rows}

    for t in tables:
        oid = table_oids[t.name]
        # Position 0 addresses the comment of the table itself.
        t.comment = fetch_comment("SELECT col_description($_, 0)", oid)

        for col in t.columns:
            col.comment = fetch_comment("SELECT col_description($_, $_)", oid, column_positions[t.name][col.name])

    return tables
def _map_types(t):
    """
    Map a PostgreSQL data type name to a python type.

    Args:
        t: Data type name such as `integer` or `timestamp with time zone`.
    Returns:
        Python type for the data type. `object` is returned for unknown names
        and for values which are not strings (e.g. `None`).
    """
    if not isinstance(t, str):
        return object

    exact = {
        "boolean": bool,
        "real": float,
        "double precision": float,
        "smallint": int,
        "integer": int,
        "bigint": int,
        "numeric": Decimal,
        "decimal": Decimal,
        "character varying": str,
        "text": str,
        "character": str,
        "bytea": bytes,
        "date": date,
        "interval": timedelta,
        "uuid": UUID,
        "json": dict,
        "jsonb": dict,
    }
    ptype = exact.get(t)
    if ptype is not None:
        return ptype

    # Timestamp and time names may be bare or carry a "with/without time zone" suffix.
    # The timestamp check must come first because of the shared "time" prefix.
    if t == "timestamp" or t.startswith("timestamp "):
        return datetime
    if t == "time" or t.startswith("time "):
        return time
    return object
def _map_alternates(n):
    """
    Translate an alternative (internal) type name into the name handled by `_map_types`.

    Args:
        n: Type name such as `int4` or `timestamptz`.
    Returns:
        Translated type name. Names without a known translation are returned as they are.
    """
    alternates = {
        "int2": "smallint",
        "int": "integer",
        "int4": "integer",
        "int8": "bigint",
        "float4": "real",
        "float8": "double precision",
        "decimal": "numeric",
        "bool": "boolean",
        "char": "character",
        # character(n) is stored in pg_type under the name bpchar.
        "bpchar": "character",
        "varchar": "character varying",
        "timetz": "time with time zone",
        # pg_type names the zone-less timestamp type simply "timestamp".
        "timestamp": "timestamp without time zone",
        "timestamptz": "timestamp with time zone",
    }
    return alternates.get(n, n)
class PostgreSQLMixin(MultiInsertMixin, TruncateMixin):
    """
    Model mixin whose methods are available in PostgreSQL.
    """
    @classmethod
    def last_sequences(cls, db: Connection, num: int) -> list[tuple[Column, int]]:
        """
        Read the current value of the sequence behind each incremental column.

        Args:
            db: DB connection.
            num: Not used by this implementation.
        Returns:
            Pairs of a column and the value `currval()` reports for its sequence.
            Empty when the model has no incremental column.
        """
        cols = [c for c in cls.columns if c.incremental]
        if not cols:
            return []

        sequences = []
        d = db.cursor()
        try:
            for c in cols:
                d.execute(f"SELECT currval('{c.incremental}')")
                sequences.append((c, d.fetchone()[0])) # type: ignore
        finally:
            # Release the cursor even when a query fails.
            d.close()
        return sequences

    @classmethod
    def support_returning(cls, db: Connection) -> bool:
        """
        Tell whether returning is supported; this implementation always answers `True`.

        Args:
            db: DB connection. Not used by this implementation.
        Returns:
            Always `True`.
        """
        return True

    @classmethod
    def truncate(cls, db: Connection):
        """
        Truncate the table of this model with `RESTART IDENTITY CASCADE`.

        Args:
            db: DB connection.
        """
        c = db.cursor()
        try:
            c.execute(f"TRUNCATE {cls.name} RESTART IDENTITY CASCADE")
        finally:
            c.close()
# Mixin classes provided by this dialect module.
# NOTE(review): presumably read by the dialect loader to extend model classes - confirm against the caller.
mixins = [PostgreSQLMixin]
Functions
def read_schema(db: Connection, excludes: Optional[list[str]] = None, includes: Optional[list[str]] = None) ‑> list[Table]
-
Collect tables in the current database.
Args
excludes
- Names of tables to exclude.
includes
- Names of tables to include. If not specified, all tables are collected.
Returns
Table schemas.
Expand source code
def read_schema(db: Connection, excludes: Optional[list[str]] = None, includes: Optional[list[str]] = None) -> list[Table]:
    """
    Collect tables in current database.

    Base tables and views of the `public` schema are read from `information_schema`,
    materialized views are read from `pg_class`, and the comments of every collected
    table and column are fetched with `col_description()`.

    Args:
        excludes: Excluding table names.
        includes: Including table names. If not specified, all tables are collected.
    Returns:
        Table schemas.
    """
    def fetch_all(sql, *args):
        # Every cursor is closed as soon as its rows are read, even if fetching fails.
        cursor = db.stmt().execute(sql, *args)
        try:
            return cursor.fetchall()
        finally:
            cursor.close()

    def fetch_comment(sql, *args):
        # col_description() yields NULL when no comment exists; normalize it to "".
        cursor = db.stmt().execute(sql, *args)
        try:
            return cursor.fetchone()[0] or "" # type: ignore
        finally:
            cursor.close()

    def map_types(t, udt):
        # A mapping function configured on the context takes precedence over the default one.
        base = db.context.config.type_mapping
        ptype = base and base(t, udt_name=udt)
        return ptype or _map_types(t)

    def column_of(n, t, udt, nullable, et, eudt, constraint, default, pos):
        m = SequencePattern.match(default or "")
        cs = (constraint or "").split(',')
        seq = m.group(1) if m else None
        null = nullable == 'YES'
        # Array columns are typed by their element type.
        ptype = map_types(t, udt) if t != 'ARRAY' else list[map_types(et, eudt)]
        info = (t, udt) if t != 'ARRAY' else (et, eudt)
        return Column(n, ptype, info, 'PRIMARY KEY' in cs, Relations() if 'FOREIGN KEY' in cs else None, seq, null)

    def mv_column_of(n, not_null, udt, eudt, pos):
        # pg_type exposes internal type names only; they are translated before mapping.
        ptype = map_types(_map_alternates(udt), udt) if eudt is None else list[map_types(_map_alternates(eudt), eudt)]
        info = (_map_alternates(udt), udt) if eudt is None else (_map_alternates(eudt), eudt)
        return Column(n, ptype, info, False, None, None, not not_null)

    q = Q(excludes = excludes, includes = includes)

    cond = Q.of("c.table_catalog = current_catalog") & Q.eq("c", table_schema="public") & Q.in_("t", table_type=["BASE TABLE", "VIEW"]) \
        & q.excludes.not_in("c.table_name") & q.includes.in_("c.table_name")
    w, params = where(cond)

    rows = fetch_all(f"""\
        SELECT
            c.table_name, c.column_name, c.data_type, c.udt_name, c.is_nullable,
            e.data_type, e.udt_name, k.constraint_type, c.column_default, c.ordinal_position
        FROM
            information_schema.columns AS c
            INNER JOIN information_schema.tables AS t ON c.table_name = t.table_name
            LEFT JOIN (
                SELECT
                    tc.table_name, k.column_name, string_agg(tc.constraint_type, ',') AS constraint_type
                FROM
                    information_schema.key_column_usage AS k
                    INNER JOIN information_schema.table_constraints AS tc ON k.constraint_name = tc.constraint_name
                WHERE
                    tc.constraint_type = 'PRIMARY KEY' OR tc.constraint_type = 'FOREIGN KEY'
                GROUP BY
                    tc.table_name, k.column_name
            ) AS k ON t.table_name = k.table_name AND c.column_name = k.column_name
            LEFT JOIN information_schema.element_types AS e
                ON ((c.table_catalog, c.table_schema, c.table_name, 'TABLE', c.dtd_identifier) =
                    (e.object_catalog, e.object_schema, e.object_name, e.object_type, e.collection_type_identifier))
        {w}
        ORDER BY c.table_name ASC, c.ordinal_position ASC
        """, *params)

    tables = []
    column_positions = {}

    # Rows arrive ordered by table name, so consecutive grouping yields one group per table.
    for t, cols in groupby(rows, lambda row: row[0]):
        cols = list(cols)
        columns = [column_of(*c[1:]) for c in cols]
        tables.append(Table(t, columns))
        column_positions[t] = {c[1]:c[-1] for c in cols}

    rows = fetch_all("""\
        SELECT
            k.table_name AS t1, k.column_name AS c1, k2.table_name AS t2, k2.column_name AS c2
        FROM
            information_schema.referential_constraints AS r
            INNER JOIN information_schema.key_column_usage AS k ON r.constraint_name = k.constraint_name
            INNER JOIN information_schema.key_column_usage AS k2
                ON r.unique_constraint_name = k2.constraint_name AND k.ordinal_position = k2.ordinal_position
        ORDER BY
            k.table_name ASC
        """)

    table_map = {t.name:t for t in tables}

    for row in rows:
        table_from = table_map.get(row[0], None)
        col_from = table_from.find(row[1]) if table_from else None

        if col_from:
            # A referenced table or column which was not collected is kept by its name.
            table_to = table_map.get(row[2], None)
            col_to = table_to.find(row[3]) if table_to else None
            col_from.fk.add(ForeignKey(table_to or row[2], col_to or row[3]))

    # Materialized views
    cond = Q.eq("c", relkind = "m") & Q.ge("a", attnum = 1) \
        & q.excludes.not_in("c.relname") & q.includes.in_("c.relname")
    w, params = where(cond)

    rows = fetch_all(f"""\
        SELECT
            c.relname, a.attname, a.attnotnull, t.typname, et.typname, a.attnum
        FROM
            pg_class AS c
            INNER JOIN pg_attribute AS a ON c.oid = a.attrelid
            INNER JOIN pg_type AS t ON a.atttypid = t.oid
            LEFT JOIN pg_type AS et ON t.typelem = et.oid
        {w}
        ORDER BY
            c.oid ASC, a.attnum ASC
        """, *params)

    for t, cols in groupby(rows, lambda row: row[0]):
        cols = list(cols)
        columns = [mv_column_of(*c[1:]) for c in cols]
        tables.append(Table(t, columns))
        column_positions[t] = {c[1]:c[-1] for c in cols}

    if not tables:
        return tables

    rows = fetch_all(f"""\
        SELECT
            relname, oid
        FROM
            pg_class
        WHERE
            relname IN ({holders(len(tables))})
        """, *[t.name for t in tables])

    table_oids = {n: oid for n, oid in rows}

    for t in tables:
        oid = table_oids[t.name]
        # Position 0 addresses the comment of the table itself.
        t.comment = fetch_comment("SELECT col_description($_, 0)", oid)

        for col in t.columns:
            col.comment = fetch_comment("SELECT col_description($_, $_)", oid, column_positions[t.name][col.name])

    return tables
Classes
class PostgreSQLMixin
-
Model mixin whose methods are available in PostgreSQL.
Expand source code
class PostgreSQLMixin(MultiInsertMixin, TruncateMixin):
    """
    Model mixin whose methods are available in PostgreSQL.
    """
    @classmethod
    def last_sequences(cls, db: Connection, num: int) -> list[tuple[Column, int]]:
        """
        Read the current value of the sequence behind each incremental column.

        Args:
            db: DB connection.
            num: Not used by this implementation.
        Returns:
            Pairs of a column and the value `currval()` reports for its sequence.
            Empty when the model has no incremental column.
        """
        cols = [c for c in cls.columns if c.incremental]
        if not cols:
            return []

        sequences = []
        d = db.cursor()
        try:
            for c in cols:
                d.execute(f"SELECT currval('{c.incremental}')")
                sequences.append((c, d.fetchone()[0])) # type: ignore
        finally:
            # Release the cursor even when a query fails.
            d.close()
        return sequences

    @classmethod
    def support_returning(cls, db: Connection) -> bool:
        """
        Tell whether returning is supported; this implementation always answers `True`.

        Args:
            db: DB connection. Not used by this implementation.
        Returns:
            Always `True`.
        """
        return True

    @classmethod
    def truncate(cls, db: Connection):
        """
        Truncate the table of this model with `RESTART IDENTITY CASCADE`.

        Args:
            db: DB connection.
        """
        c = db.cursor()
        try:
            c.execute(f"TRUNCATE {cls.name} RESTART IDENTITY CASCADE")
        finally:
            c.close()
Ancestors
Static methods
def last_sequences(db: Connection, num: int) ‑> list[tuple[Column, int]]
-
Expand source code
@classmethod
def last_sequences(cls, db: Connection, num: int) -> list[tuple[Column, int]]:
    """
    Read the current value of the sequence behind each incremental column.

    Args:
        db: DB connection.
        num: Not used by this implementation.
    Returns:
        Pairs of a column and the value `currval()` reports for its sequence.
        Empty when the model has no incremental column.
    """
    cols = [c for c in cls.columns if c.incremental]
    if not cols:
        return []

    sequences = []
    d = db.cursor()
    try:
        for c in cols:
            d.execute(f"SELECT currval('{c.incremental}')")
            sequences.append((c, d.fetchone()[0])) # type: ignore
    finally:
        # Release the cursor even when a query fails.
        d.close()
    return sequences
def support_returning(db: Connection) ‑> bool
-
Expand source code
@classmethod
def support_returning(cls, db: Connection) -> bool:
    """
    Tell whether returning is supported; this implementation always answers `True`.

    Args:
        db: DB connection. Not used by this implementation.
    Returns:
        Always `True`.
    """
    return True
def truncate(db: Connection)
-
Expand source code
@classmethod
def truncate(cls, db: Connection):
    """
    Truncate the table of this model with `RESTART IDENTITY CASCADE`.

    Args:
        db: DB connection.
    """
    c = db.cursor()
    try:
        c.execute(f"TRUNCATE {cls.name} RESTART IDENTITY CASCADE")
    finally:
        # The cursor is only needed for this single statement.
        c.close()
Inherited members