Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 152 additions & 15 deletions python/grass/script/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@
.. sectionauthor:: Martin Landa <landa.martin gmail.com>
"""

from __future__ import annotations

import os

from ctypes import byref

from .core import (
gisenv,
run_command,
parse_command,
read_command,
Expand Down Expand Up @@ -137,6 +143,15 @@ def db_connection(force=False, env=None):
run_command("db.connect", flags="c", env=env)
conn = parse_command("db.connect", flags="g", env=env)

if conn and conn.get("driver") == "sqlite":
gis_env = gisenv()
conn["database"] = (
conn["database"]
.replace("$GISDBASE", gis_env["GISDBASE"])
.replace("$LOCATION_NAME", gis_env["LOCATION_NAME"])
.replace("$MAPSET", gis_env["MAPSET"])
)

return conn


Expand Down Expand Up @@ -233,23 +248,145 @@ def db_table_in_vector(table, mapset=".", env=None):
return None


def db_begin_transaction(driver):
"""Begin transaction.
class DBHandler:
"""DB handler

:return: SQL command as string
Allow execute SQL command(s) in transaction mode.

Public methods:

::execute
"""
if driver in {"sqlite", "pg"}:
return "BEGIN"
if driver == "mysql":
return "START TRANSACTION"
return ""

def __init__(self, driver_name: str, database: str) -> None:
"""Constructor

:param driver_name: DB driver name
:param database: database name
"""
self._driver_name = driver_name
self._database = database
self._import_c_interface()

def _import_c_interface(self):
"""Import C interface"""
try:
from grass.lib.dbmi import (
db_begin_transaction,
db_close_database_shutdown_driver,
db_commit_transaction,
db_execute_immediate,
db_free_string,
db_init_string,
db_set_string,
db_start_driver_open_database,
dbString,
DB_OK,
)
from grass.lib.gis import G_gisinit
from grass.lib.vector import (
Map_info,
Vect_subst_var,
)
except (ImportError, OSError, TypeError) as e:
fatal(_("Unable to import C functions: {e}").format(e))

class CInterface:
def __init__(self):
self.db_begin_transaction = db_begin_transaction
self.db_execute_immediate = db_execute_immediate
self.db_free_string = db_free_string
self.db_init_string = db_init_string
self.db_close_database_shutdown_driver = (
db_close_database_shutdown_driver
)
self.db_commit_transaction = db_commit_transaction
self.db_set_string = db_set_string
self.db_start_driver_open_database = db_start_driver_open_database
self.dbString = dbString
self.DB_OK = DB_OK
self.G_gisinit = G_gisinit
self.Map_info = Map_info
self.Vect_subst_var = Vect_subst_var

self._c_interface = CInterface()

def _init_driver(self):
"""Init DB driver"""
map = self._c_interface.Map_info()
self._pdriver = self._c_interface.db_start_driver_open_database(
self._driver_name,
self._c_interface.Vect_subst_var(self._database, byref(map)),
)
if not self._pdriver:
fatal(
_("Unable to open database <{db}> by driver <{driver}>.").format(
db=self._database, driver=self._driver_name
)
)

def db_commit_transaction(driver):
"""Commit transaction.
def _begin_transaction(self):
"""Begin DB transaction."""
ret = self._c_interface.db_begin_transaction(self._pdriver)
if ret != self._c_interface.DB_OK:
self._shutdown_driver()
fatal(
_(
"Error while starting database <{db}> transaction by"
" driver <{driver}>."
).format(db=self._database, driver=self._driver_name)
)

:return: SQL command as string
"""
if driver in {"sqlite", "pg", "mysql"}:
return "COMMIT"
return ""
def _commit_transaction(self):
"""Commit DB transaction."""
ret = self._c_interface.db_commit_transaction(self._pdriver)
if ret != self._c_interface.DB_OK:
self._shutdown_driver()
fatal(
_(
"Error while commit database <{db}> transaction"
" by driver <{driver}>."
).format(db=self._database, driver=self._driver_name)
)

def _execute(self, sql: str | list | tuple) -> None:
"""Execute SQL

:param sql: SQL command string or list of SQLs commands
"""
stmt = self._c_interface.dbString()
self._c_interface.db_init_string(byref(stmt))
self._c_interface.db_set_string(byref(stmt), sql)
if (
self._c_interface.db_execute_immediate(self._pdriver, byref(stmt))
!= self._c_interface.DB_OK
):
self._c_interface.db_free_string(byref(stmt))
self._shutdown_driver()
fatal(_("Error while executing SQL <{}>.").format(sql))
self._c_interface.db_free_string(byref(stmt))

def _shutdown_driver(self):
"""Close DB and shutdown driver"""
self._c_interface.db_close_database_shutdown_driver(self._pdriver)

def execute(self, sql: str | list | tuple) -> None:
"""Execute SQL

:param sql: SQL command string or list of SQLs statement
"""
self._c_interface.G_gisinit("")
self._init_driver()

# Begin DB transaction
self._begin_transaction()
# Execute SQL string
if isinstance(sql, (list, tuple)):
for statement in sql:
self._execute(sql=statement)
else:
self._execute(sql)
# Commit DB transaction
self._commit_transaction()
Comment thread
tmszi marked this conversation as resolved.

self._shutdown_driver()
6 changes: 3 additions & 3 deletions python/grass/temporal/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,7 +1521,7 @@ def fetchall(self):
def execute_transaction(self, statement, mapset=None):
"""Execute a transactional SQL statement

The BEGIN and END TRANSACTION statements will be added automatically
The BEGIN and COMMIT statements will be added automatically
to the sql statement

:param statement: The executable SQL statement or SQL script
Expand All @@ -1532,9 +1532,9 @@ def execute_transaction(self, statement, mapset=None):
connected = True

sql_script = ""
sql_script += "BEGIN TRANSACTION;\n"
sql_script += "BEGIN;\n"
sql_script += statement
sql_script += "END TRANSACTION;"
sql_script += "COMMIT;"

try:
if self.dbmi.__name__ == "sqlite3":
Expand Down
19 changes: 10 additions & 9 deletions scripts/db.dropcolumn/db.dropcolumn.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@
import sys
import string

from grass.exceptions import CalledModuleError
import grass.script as gs
from grass.exceptions import CalledModuleError
from grass.script.db import DBHandler


def main():
Expand All @@ -56,6 +57,8 @@ def main():
driver = options["driver"]
force = flags["f"]

db_handler = DBHandler(driver_name=driver, database=database)

# check if DB parameters are set, and if not set them.
gs.run_command("db.connect", flags="c")

Expand Down Expand Up @@ -93,6 +96,7 @@ def main():
)
return 0

sqls = []
if driver == "sqlite":
sqlite3_version = gs.read_command(
"db.select",
Expand All @@ -103,9 +107,9 @@ def main():
).split(".")[0:2]

if [int(i) for i in sqlite3_version] >= [int(i) for i in "3.35".split(".")]:
sql = "ALTER TABLE %s DROP COLUMN %s" % (table, column)
if column == "cat":
sql = "DROP INDEX %s_%s; %s" % (table, column, sql)
sqls.append(f"DROP INDEX {table}_{column};")
sqls.append(f"ALTER TABLE {table} DROP COLUMN {column};")
else:
# for older sqlite3 versions, use old way to remove column
colnames = []
Expand All @@ -119,24 +123,21 @@ def main():
coltypes = ", ".join(coltypes)

cmds = [
"BEGIN TRANSACTION",
"CREATE TEMPORARY TABLE ${table}_backup(${coldef})",
"INSERT INTO ${table}_backup SELECT ${colnames} FROM ${table}",
"DROP TABLE ${table}",
"CREATE TABLE ${table}(${coldef})",
"INSERT INTO ${table} SELECT ${colnames} FROM ${table}_backup",
"DROP TABLE ${table}_backup",
"COMMIT",
]
tmpl = string.Template(";\n".join(cmds))
sql = tmpl.substitute(table=table, coldef=coltypes, colnames=colnames)
sqls.extend(sql.split("\n"))
else:
sql = "ALTER TABLE %s DROP COLUMN %s" % (table, column)
sqls.append(f"ALTER TABLE {table} DROP COLUMN {column};")

try:
gs.write_command(
"db.execute", input="-", database=database, driver=driver, stdin=sql
)
db_handler.execute(sql=";".join(sqls))
except CalledModuleError:
gs.fatal(_("Cannot continue (problem deleting column)"))

Expand Down
41 changes: 8 additions & 33 deletions scripts/v.db.addcolumn/v.db.addcolumn.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,32 +40,14 @@
# % key_desc: name type
# %end

import atexit
import os
from pathlib import Path
import re

from grass.exceptions import CalledModuleError
import grass.script as gs

rm_files = []


def cleanup():
for file in rm_files:
if os.path.isfile(file):
try:
os.remove(file)
except Exception as e:
gs.warning(
_("Unable to remove file {file}: {message}").format(
file=file, message=e
)
)
from grass.exceptions import CalledModuleError
from grass.script.db import DBHandler


def main():
global rm_files
map = options["map"]
layer = options["layer"]
columns = options["columns"]
Expand Down Expand Up @@ -100,7 +82,9 @@ def main():
driver = f["driver"]
column_existing = gs.vector_columns(map, int(layer)).keys()

add_str = "BEGIN TRANSACTION\n"
db_handler = DBHandler(driver_name=driver, database=database)

sqls = []
pattern = re.compile(r"\s+")
for col in columns:
if not col:
Expand All @@ -120,19 +104,11 @@ def main():
)
continue
gs.verbose(_("Adding column <{}> to the table").format(col_name))
add_str += f'ALTER TABLE {table} ADD COLUMN "{col_name}" {col_type};\n'
add_str += "END TRANSACTION"
sql_file = gs.tempfile()
rm_files.append(sql_file)
sqls.append(f'ALTER TABLE {table} ADD COLUMN "{col_name}" {col_type};')

cols_add_str = ",".join([col[0] for col in columns])
Path(sql_file).write_text(add_str)
try:
gs.run_command(
"db.execute",
input=sql_file,
database=database,
driver=driver,
)
db_handler.execute(sql=sqls)
except CalledModuleError:
gs.fatal(_("Error adding columns {}").format(cols_add_str))
# write cmd history:
Expand All @@ -141,5 +117,4 @@ def main():

if __name__ == "__main__":
options, flags = gs.parser()
atexit.register(cleanup)
main()
Loading