diff --git a/README.md b/README.md index 856bfa9b..684ae082 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,11 @@ in DATABASES control the behavior of the backend: - PASSWORD String. Database user password. + +- SCHEMA + String. Default schema to use. Not required. + - TOKEN String. Access token fetched as a user or service principal which diff --git a/mssql/compiler.py b/mssql/compiler.py index 3e97ae40..7788a3ab 100644 --- a/mssql/compiler.py +++ b/mssql/compiler.py @@ -18,6 +18,9 @@ if django.VERSION >= (4, 2): from django.core.exceptions import EmptyResultSet, FullResultSet +from .introspection import get_table_name, get_schema_name +from django.apps import apps + def _as_sql_agv(self, compiler, connection): return self.as_sql(compiler, connection, template='%(function)s(CONVERT(float, %(field)s))') @@ -196,7 +199,6 @@ def _cursor_iter(cursor, sentinel, col_count, itersize): compiler.cursor_iter = _cursor_iter - class SQLCompiler(compiler.SQLCompiler): def as_sql(self, with_limits=True, with_col_aliases=False): @@ -227,6 +229,7 @@ def as_sql(self, with_limits=True, with_col_aliases=False): do_offset_emulation = do_offset and not supports_offset_clause if combinator: + if not getattr(features, 'supports_select_{}'.format(combinator)): raise NotSupportedError('{} is not supported on this database backend.'.format(combinator)) result, params = self.get_combinator_sql(combinator, self.query.combinator_all) @@ -285,7 +288,8 @@ def as_sql(self, with_limits=True, with_col_aliases=False): if do_offset: meta = self.query.get_meta() qn = self.quote_name_unless_alias - offsetting_order_by = '%s.%s' % (qn(meta.db_table), qn(meta.pk.db_column or meta.pk.column)) + table = qn(get_table_name(self, meta.db_table, getattr(meta, "db_table_schema", False))) + offsetting_order_by = '%s.%s' % (table, qn(meta.pk.db_column or meta.pk.column)) if do_offset_emulation: if order_by: ordering = [] @@ -431,12 +435,53 @@ def as_sql(self, with_limits=True, with_col_aliases=False): ', 
'.join(sub_selects), ' '.join(result), ), tuple(sub_params + params) - return ' '.join(result), tuple(params) finally: # Finally do cleanup - get rid of the joins we created above. self.query.reset_refcounts(refcounts_before) - + def get_from_clause(self): + """ + Return a list of strings that are joined together to go after the + "FROM" part of the query, as well as a list any extra parameters that + need to be included. Subclasses, can override this to create a + from-clause via a "select". + + This should only be called after any SQL construction methods that + might change the tables that are needed. This means the select columns, + ordering, and distinct must be done first. + """ + result = [] + params = [] + for alias in tuple(self.query.alias_map): + if not self.query.alias_refcount[alias]: + continue + try: + from_clause = self.query.alias_map[alias] + except KeyError: + # Extra tables can end up in self.tables, but not in the + # alias_map if they aren't in a join. That's OK. We skip them. + continue + settings_dict = self.connection.settings_dict + clause_sql, clause_params = self.compile(from_clause) + model = next((m for m in apps.get_models() if m._meta.db_table == from_clause.table_name), None) + schema = getattr(getattr(model,"_meta", None), "db_table_schema", settings_dict.get('SCHEMA', False)) + if schema: + if 'JOIN' in clause_sql: + table_clause_sql = clause_sql.split('JOIN ')[1].split(' ON')[0] + table_clause_sql = f'[{schema}].{table_clause_sql}' + clause_sql = clause_sql.split('JOIN')[0] + 'JOIN ' + table_clause_sql + ' ON' + clause_sql.split('JOIN')[1].split('ON')[1] + else: + clause_sql = f'[{schema}].{clause_sql}' + result.append(clause_sql) + params.extend(clause_params) + for t in self.query.extra_tables: + alias, _ = self.query.table_alias(t) + # Only add the alias if it's not already present (the table_alias() + # call increments the refcount, so an alias refcount of one means + # this is the only reference). 
+ if alias not in self.query.alias_map or self.query.alias_refcount[alias] == 1: + result.append(', %s' % self.quote_name_unless_alias(alias)) + return result, params def compile(self, node, *args, **kwargs): node = self._as_microsoft(node) return super().compile(node, *args, **kwargs) @@ -550,7 +595,7 @@ def fix_auto(self, sql, opts, fields, qn): columns = [f.column for f in fields] if auto_field_column in columns: id_insert_sql = [] - table = qn(opts.db_table) + table = qn(get_table_name(self, opts.db_table, getattr(opts, "db_table_schema", False))) sql_format = 'SET IDENTITY_INSERT %s ON; %s; SET IDENTITY_INSERT %s OFF' for q, p in sql: id_insert_sql.append((sql_format % (table, q, table), p)) @@ -587,7 +632,8 @@ def as_sql(self): # going to be column names (so we can avoid the extra overhead). qn = self.connection.ops.quote_name opts = self.query.get_meta() - result = ['INSERT INTO %s' % qn(opts.db_table)] + table = qn(get_table_name(self, opts.db_table, getattr(opts, "db_table_schema", False))) + result = ['INSERT INTO %s' % table] if self.query.fields: fields = self.query.fields @@ -617,7 +663,7 @@ def as_sql(self): # There isn't really a single statement to bulk multiple DEFAULT VALUES insertions, # so we have to use a workaround: # https://dba.stackexchange.com/questions/254771/insert-multiple-rows-into-a-table-with-only-an-identity-column - result = [self.bulk_insert_default_values_sql(qn(opts.db_table))] + result = [self.bulk_insert_default_values_sql(qn(table))] r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.get_returned_fields()) if r_sql: result.append(r_sql) @@ -660,13 +706,82 @@ def as_sql(self): sql = '; '.join(['SET NOCOUNT OFF', sql]) return sql, params + def _as_sql(self, query): + opts = self.query.get_meta() + table = get_table_name(self, query.base_table, getattr(opts, "db_table_schema", False)) + delete = "DELETE FROM %s" % self.quote_name_unless_alias(table) + try: + where, params = self.compile(query.where) + 
except FullResultSet: + return delete, () + return f"{delete} WHERE {where}", tuple(params) + class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): def as_sql(self): - sql, params = super().as_sql() - if sql: - sql = '; '.join(['SET NOCOUNT OFF', sql]) - return sql, params + """ + Create the SQL for this query. Return the SQL string and list of + parameters. + """ + self.pre_sql_setup() + if not self.query.values: + return "", () + qn = self.quote_name_unless_alias + values, update_params = [], [] + for field, model, val in self.query.values: + if hasattr(val, "resolve_expression"): + val = val.resolve_expression( + self.query, allow_joins=False, for_save=True + ) + if val.contains_aggregate: + raise FieldError( + "Aggregate functions are not allowed in this query " + "(%s=%r)." % (field.name, val) + ) + if val.contains_over_clause: + raise FieldError( + "Window expressions are not allowed in this query " + "(%s=%r)." % (field.name, val) + ) + elif hasattr(val, "prepare_database_save"): + if field.remote_field: + val = val.prepare_database_save(field) + else: + raise TypeError( + "Tried to update field %s with a model instance, %r. " + "Use a value compatible with %s." + % (field, val, field.__class__.__name__) + ) + val = field.get_db_prep_save(val, connection=self.connection) + + # Getting the placeholder for the field. 
+ if hasattr(field, "get_placeholder"): + placeholder = field.get_placeholder(val, self, self.connection) + else: + placeholder = "%s" + name = field.column + if hasattr(val, "as_sql"): + sql, params = self.compile(val) + values.append("%s = %s" % (qn(name), placeholder % sql)) + update_params.extend(params) + elif val is not None: + values.append("%s = %s" % (qn(name), placeholder)) + update_params.append(val) + else: + values.append("%s = NULL" % qn(name)) + opts = self.query.get_meta() + table = get_table_name(self, self.query.base_table, getattr(opts, "db_table_schema", False)) + result = [ + "UPDATE %s SET" % qn(table), + ", ".join(values), + ] + try: + where, params = self.compile(self.query.where) + except FullResultSet: + params = [] + else: + result.append("WHERE %s" % where) + return " ".join(result), tuple(update_params + params) class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): diff --git a/mssql/introspection.py b/mssql/introspection.py index 96974ec4..76e969ea 100644 --- a/mssql/introspection.py +++ b/mssql/introspection.py @@ -17,13 +17,34 @@ SQL_BIGAUTOFIELD = -777444 SQL_SMALLAUTOFIELD = -777333 SQL_TIMESTAMP_WITH_TIMEZONE = -155 +from django.db import connection FieldInfo = namedtuple("FieldInfo", BaseFieldInfo._fields + ("comment",)) TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",)) def get_schema_name(): - return getattr(settings, 'SCHEMA_TO_INSPECT', 'SCHEMA_NAME()') + # get default schema choosen by user in settings.py else SCHEMA_NAME() + settings_dict = connection.settings_dict + schema = settings_dict.get('SCHEMA', False) + return f"'{schema}'" if schema else 'SCHEMA_NAME()' +def get_table_name(object, table_name, custom_schema): + """ + get the name of the table on this format schema].[table_name + if + schema = custom schema defined in medels meta (db_table_schema) + else + schema choosen by user in settings.py + else + return the name of the table without schema (defalut one will be used) + 
""" + if custom_schema: + return f'{custom_schema}].[{table_name}' + settings_dict = object.connection.settings_dict + schema_name = settings_dict.get('SCHEMA', False) + if schema_name: + return f'{schema_name}].[{table_name}' + return table_name class DatabaseIntrospection(BaseDatabaseIntrospection): # Map type codes to Django Field types. @@ -237,7 +258,7 @@ def get_key_columns(self, cursor, table_name): key_columns.extend([tuple(row) for row in cursor.fetchall()]) return key_columns - def get_constraints(self, cursor, table_name): + def get_constraints(self, cursor, table_name, table_name_schema='SCHEMA_NAME()'): """ Retrieves any constraints or keys (unique, pk, fk, check, index) across one or more columns. @@ -296,12 +317,12 @@ def get_constraints(self, cursor, table_name): kc.table_name = fk.table_name AND kc.column_name = fk.column_name WHERE - kc.table_schema = {get_schema_name()} AND + kc.table_schema = {table_name_schema} AND kc.table_name = %s ORDER BY kc.constraint_name ASC, kc.ordinal_position ASC - """, [table_name]) + """ , [table_name]) for constraint, column, kind, ref_table, ref_column in cursor.fetchall(): # If we're the first column, make the record if constraint not in constraints: @@ -331,7 +352,7 @@ def get_constraints(self, cursor, table_name): kc.constraint_name = c.constraint_name WHERE c.constraint_type = 'CHECK' AND - kc.table_schema = {get_schema_name()} AND + kc.table_schema = {table_name_schema} AND kc.table_name = %s """, [table_name]) for constraint, column in cursor.fetchall(): @@ -398,7 +419,7 @@ def get_constraints(self, cursor, table_name): ic.object_id = c.object_id AND ic.column_id = c.column_id WHERE - t.schema_id = SCHEMA_ID({get_schema_name()}) AND + t.schema_id = SCHEMA_ID({table_name_schema}) AND t.name = %s ORDER BY i.index_id ASC, diff --git a/mssql/schema.py b/mssql/schema.py index 87cb4cfa..e2165fda 100644 --- a/mssql/schema.py +++ b/mssql/schema.py @@ -15,6 +15,7 @@ from django.db.backends.ddl_references import ( 
Columns, IndexName, + ForeignKeyName, Statement as DjStatement, Table, ) @@ -30,6 +31,8 @@ from django.db.models.sql import Query from django.db.backends.ddl_references import Expressions +from .introspection import get_table_name, get_schema_name + class Statement(DjStatement): def __hash__(self): @@ -92,7 +95,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): DROP TABLE %(table)s """ sql_rename_column = "EXEC sp_rename '%(table)s.%(old_column)s', %(new_column)s, 'COLUMN'" - sql_rename_table = "EXEC sp_rename %(old_table)s, %(new_table)s" + sql_rename_table = "EXEC sp_rename '%(old_table)s', '%(new_table)s'" sql_create_unique_null = "CREATE UNIQUE INDEX %(name)s ON %(table)s(%(columns)s) " \ "WHERE %(columns)s IS NOT NULL" sql_alter_table_comment= """ @@ -146,9 +149,10 @@ def _alter_column_default_sql(self, model, old_field, new_field, drop=False): if drop: params = [] # SQL Server requires the name of the default constraint + db_table = model._meta.db_table result = self.execute( self._sql_select_default_constraint_name % { - "table": self.quote_value(model._meta.db_table), + "table": self.quote_value(db_table), "column": self.quote_value(new_field.column), }, has_result=True @@ -226,7 +230,18 @@ def _alter_column_comment_sql(self, model, new_field, new_type, new_db_comment): }, [], ) - + + def _alter_column_comment_sql(self, model, new_field, new_type, new_db_comment): + return ( + self.sql_alter_column_comment + % { + "table": self.quote_name(model._meta.db_table), + "column": new_field.column, + "comment": self._comment_sql(new_db_comment), + }, + [], + ) + def _alter_column_null_sql(self, model, old_field, new_field): """ Hook to specialize column null alteration. 
@@ -327,7 +342,7 @@ def _model_indexes_sql(self, model): output.append(index.create_sql(model, self)) return output - def _db_table_constraint_names(self, db_table, column_names=None, column_match_any=False, + def _db_table_constraint_names(self, model, column_names=None, column_match_any=False, unique=None, primary_key=None, index=None, foreign_key=None, check=None, type_=None, exclude=None, unique_constraint=None): """ @@ -338,13 +353,16 @@ def _db_table_constraint_names(self, db_table, column_names=None, column_match_a False: (default) only return constraints covering exactly `column_names` True : return any constraints which include at least 1 of `column_names` """ + db_table = model._meta.db_table + db_table_schema = getattr(model._meta, "db_table_schema", False) + db_table_schema = f"'{db_table_schema}'" if db_table_schema else get_schema_name() if column_names is not None: column_names = [ self.connection.introspection.identifier_converter(name) for name in column_names ] with self.connection.cursor() as cursor: - constraints = self.connection.introspection.get_constraints(cursor, db_table) + constraints = self.connection.introspection.get_constraints(cursor, db_table, db_table_schema) result = [] for name, infodict in constraints.items(): if column_names is None or column_names == infodict['columns'] or ( @@ -400,6 +418,9 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, # the backend doesn't support altering a column to/from AutoField as # SQL Server cannot alter columns to add and remove IDENTITY properties + + db_table = model._meta.db_table + table = get_table_name(self, model._meta.db_table, getattr(model._meta, "db_table_schema", False)) old_is_auto = False new_is_auto = False for t in (AutoField, BigAutoField): @@ -446,7 +467,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, if strict and len(fk_names) != 1: raise ValueError("Found wrong number (%s) of foreign key constraints for %s.%s" % ( 
len(fk_names), - model._meta.db_table, + table, old_field.column, )) for fk_name in fk_names: @@ -520,7 +541,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, if strict and len(constraint_names) != 1: raise ValueError("Found wrong number (%s) of check constraints for %s.%s" % ( len(constraint_names), - model._meta.db_table, + table, old_field.column, )) for constraint_name in constraint_names: @@ -530,7 +551,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, sql_restore_index = '' # Drop any unique indexes which include the column to be renamed index_names = self._db_table_constraint_names( - db_table=model._meta.db_table, column_names=[old_field.column], column_match_any=True, + model=model, column_names=[old_field.column], column_match_any=True, index=True, unique=True, ) for index_name in index_names: @@ -542,24 +563,24 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, FROM sys.indexes AS i INNER JOIN sys.index_columns AS ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id - WHERE i.object_id = OBJECT_ID('{model._meta.db_table}') + WHERE i.object_id = OBJECT_ID('{db_table}') and i.name = '{index_name}' """) result = cursor.fetchall() columns_to_recreate_index = ', '.join(['%s' % self.quote_name(column[0]) for column in result]) filter_definition = result[0][1] sql_restore_index += 'CREATE UNIQUE INDEX %s ON %s (%s) WHERE %s;' % ( - index_name, model._meta.db_table, columns_to_recreate_index, filter_definition) + index_name, '[' + table + ']', columns_to_recreate_index, filter_definition) self.execute(self._db_table_delete_constraint_sql( - self.sql_delete_index, model._meta.db_table, index_name)) - self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type)) + self.sql_delete_index, table, index_name)) + self.execute(self._rename_field_sql(table, old_field, new_field, new_type)) # Restore index(es) now the column has been renamed if 
sql_restore_index: self.execute(sql_restore_index.replace(f'[{old_field.column}]', f'[{new_field.column}]')) # Rename all references to the renamed column. for sql in self.deferred_sql: if isinstance(sql, DjStatement): - sql.rename_column_references(model._meta.db_table, old_field.column, new_field.column) + sql.rename_column_references(table, old_field.column, new_field.column) # Next, start accumulating actions to do actions = [] @@ -629,7 +650,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, indexes_dropped = self._delete_indexes(model, old_field, new_field) auto_index_names = [] for index_from_meta in model._meta.indexes: - auto_index_names.append(self._create_index_name(model._meta.db_table, index_from_meta.fields)) + auto_index_names.append(self._create_index_name(db_table, index_from_meta.fields)) if ( new_field.get_internal_type() not in ("JSONField", "TextField") and @@ -659,7 +680,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, for sql, params in actions: self.execute( self.sql_alter_column % { - "table": self.quote_name(model._meta.db_table), + "table": self.quote_name(table), "changes": sql, }, params, @@ -673,7 +694,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, # Update existing rows with default value self.execute( self.sql_update_with_default % { - "table": self.quote_name(model._meta.db_table), + "table": self.quote_name(table), "column": self.quote_name(new_field.column), "default": default_sql, }, @@ -684,7 +705,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, for sql, params in null_actions: self.execute( self.sql_alter_column % { - "table": self.quote_name(model._meta.db_table), + "table": self.quote_name(table), "changes": sql, }, params, @@ -762,9 +783,9 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, if old_field.primary_key and new_field.primary_key: self.execute( self.sql_create_pk % { - "table": 
self.quote_name(model._meta.db_table), + "table": self.quote_name(table), "name": self.quote_name( - self._create_index_name(model._meta.db_table, [new_field.column], suffix="_pk") + self._create_index_name(db_table, [new_field.column], suffix="_pk") ), "columns": self.quote_name(new_field.column), } @@ -825,9 +846,9 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, # Make the new one self.execute( self.sql_create_pk % { - "table": self.quote_name(model._meta.db_table), + "table": self.quote_name(table), "name": self.quote_name( - self._create_index_name(model._meta.db_table, [new_field.column], suffix="_pk") + self._create_index_name(db_table, [new_field.column], suffix="_pk") ), "columns": self.quote_name(new_field.column), } @@ -847,7 +868,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, new_rel.related_model, old_rel.field, new_rel.field, rel_type ) # Drop related_model indexes, so it can be altered - index_names = self._db_table_constraint_names(old_rel.related_model._meta.db_table, index=True) + index_names = self._db_table_constraint_names(old_rel.related_model, index=True) for index_name in index_names: self.execute(self._db_table_delete_constraint_sql( self.sql_delete_index, old_rel.related_model._meta.db_table, index_name)) @@ -899,9 +920,9 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, ): self.execute( self.sql_create_check % { - "table": self.quote_name(model._meta.db_table), + "table": self.quote_name(table), "name": self.quote_name( - self._create_index_name(model._meta.db_table, [new_field.column], suffix="_check") + self._create_index_name(db_table, [new_field.column], suffix="_check") ), "column": self.quote_name(new_field.column), "check": new_db_params['check'], @@ -912,7 +933,7 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type, if needs_database_default: changes_sql, params = self._alter_column_default_sql(model, old_field, new_field, 
drop=True) sql = self.sql_alter_column % { - "table": self.quote_name(model._meta.db_table), + "table": self.quote_name(table), "changes": changes_sql, } self.execute(sql, params) @@ -985,12 +1006,12 @@ def _delete_unique_constraints(self, model, old_field, new_field, strict=False): def _delete_unique_constraint_for_columns(self, model, columns, strict=False, **constraint_names_kwargs): constraint_names_unique = self._db_table_constraint_names( - model._meta.db_table, columns, unique=True, unique_constraint=True, **constraint_names_kwargs) + model, columns, unique=True, unique_constraint=True, **constraint_names_kwargs) constraint_names_primary = self._db_table_constraint_names( - model._meta.db_table, columns, unique=True, primary_key=True, **constraint_names_kwargs) + model, columns, unique=True, primary_key=True, **constraint_names_kwargs) constraint_names_normal = constraint_names_unique + constraint_names_primary constraint_names_index = self._db_table_constraint_names( - model._meta.db_table, columns, unique=True, unique_constraint=False, primary_key=False, + model, columns, unique=True, unique_constraint=False, primary_key=False, **constraint_names_kwargs) constraint_names = constraint_names_normal + constraint_names_index if django_version >= (4, 1): @@ -1035,6 +1056,9 @@ def add_field(self, model, field): Create a field on a model. Usually involves adding a column, but may involve adding a table instead (for M2M fields). 
""" + + db_table = model._meta.db_table + table = get_table_name(self, model._meta.db_table, getattr(model._meta, "db_table_schema", False)) # Special-case implicit M2M tables if field.many_to_many and field.remote_field.through._meta.auto_created: return self.create_model(field.remote_field.through) @@ -1068,7 +1092,7 @@ def add_field(self, model, field): definition += " CHECK (%s)" % db_params['check'] # Build the SQL and run it sql = self.sql_create_column % { - "table": self.quote_name(model._meta.db_table), + "table": self.quote_name(table), "column": self.quote_name(field.column), "definition": definition, } @@ -1082,7 +1106,7 @@ def add_field(self, model, field): ): changes_sql, params = self._alter_column_default_sql(model, None, field, drop=True) sql = self.sql_alter_column % { - "table": self.quote_name(model._meta.db_table), + "table": self.quote_name(table), "changes": changes_sql, } self.execute(sql, params) @@ -1135,8 +1159,7 @@ def create_unique_name(*args, **kwargs): compiler = Query(model, alias_cols=False).get_compiler(connection=self.connection) columns = [field.column for field in fields] - table = model._meta.db_table - + table = get_table_name(self, model._meta.db_table, getattr(model._meta, "db_table_schema", False)) if name is None: name = IndexName(table, columns, '_uniq', create_unique_name) else: @@ -1184,10 +1207,10 @@ def _create_unique_sql(self, model, columns, def create_unique_name(*args, **kwargs): return self.quote_name(self._create_index_name(*args, **kwargs)) - - table = Table(model._meta.db_table, self.quote_name) + db_table = get_table_name(self, model._meta.db_table, getattr(model._meta, "db_table_schema", False)) + table = Table(db_table, self.quote_name) if name is None: - name = IndexName(model._meta.db_table, columns, '_uniq', create_unique_name) + name = IndexName(db_table, columns, '_uniq', create_unique_name) else: name = self.quote_name(name) columns = Columns(table, columns, self.quote_name) @@ -1225,16 +1248,58 @@ 
def _create_index_sql(self, model, fields, *, name=None, suffix='', using='', indexes, ...). """ if django_version >= (3, 2): - return super()._create_index_sql( - model, fields=fields, name=name, suffix=suffix, using=using, - db_tablespace=db_tablespace, col_suffixes=col_suffixes, sql=sql, - opclasses=opclasses, condition=condition, include=include, - expressions=expressions, + fields = fields or [] + expressions = expressions or [] + compiler = Query(model, alias_cols=False).get_compiler( + connection=self.connection, + ) + tablespace_sql = self._get_index_tablespace_sql( + model, fields, db_tablespace=db_tablespace + ) + columns = [field.column for field in fields] + sql_create_index = sql or self.sql_create_index + table_name = model._meta.db_table + table = get_table_name(self, model._meta.db_table, getattr(model._meta, "db_table_schema", False)) + def create_index_name(*args, **kwargs): + nonlocal name + if name is None: + name = self._create_index_name(*args, **kwargs) + return self.quote_name(name) + + return Statement( + sql_create_index, + table=Table(table, self.quote_name), + name=IndexName(table_name, columns, suffix, create_index_name), + using=using, + columns=( + self._index_columns(table, columns, col_suffixes, opclasses) + if columns + else Expressions(table, expressions, compiler, self.quote_value) + ), + extra=tablespace_sql, + condition=self._index_condition_sql(condition), + include=self._index_include_sql(model, include), ) - return super()._create_index_sql( - model, fields=fields, name=name, suffix=suffix, using=using, - db_tablespace=db_tablespace, col_suffixes=col_suffixes, sql=sql, - opclasses=opclasses, condition=condition, + table_name = model._meta.db_table + table = get_table_name(self, model._meta.db_table, getattr(model._meta, "db_table_schema", False)) + tablespace_sql = self._get_index_tablespace_sql(model, fields, db_tablespace=db_tablespace) + columns = [field.column for field in fields] + sql_create_index = sql or 
self.sql_create_index + + def create_index_name(*args, **kwargs): + nonlocal name + if name is None: + name = self._create_index_name(*args, **kwargs) + return self.quote_name(name) + + return Statement( + sql_create_index, + table=Table(table, self.quote_name), + name=IndexName(table_name, columns, suffix, create_index_name), + using=using, + columns=self._index_columns(table, columns, col_suffixes, opclasses), + extra=tablespace_sql, + condition=(' WHERE ' + condition) if condition else '', ) def create_model(self, model): @@ -1243,6 +1308,8 @@ def create_model(self, model): Will also create any accompanying indexes or unique constraints. """ # Create column SQL, add FK deferreds if needed + table_name = model._meta.db_table + db_table = get_table_name(self, model._meta.db_table, getattr(model._meta, "db_table_schema", False)) column_sqls = [] params = [] for field in model._meta.local_fields: @@ -1269,7 +1336,7 @@ def create_model(self, model): if db_params['check']: # SQL Server requires a name for the check constraint definition += self._sql_check_constraint % { - "name": self._create_index_name(model._meta.db_table, [field.column], suffix="_check"), + "name": self._create_index_name(table_name, [field.column], suffix="_check"), "check": db_params['check'] } # Autoincrement SQL (for backends with inline variant) @@ -1295,7 +1362,7 @@ def create_model(self, model): )) # Autoincrement SQL (for backends with post table definition variant) if field.get_internal_type() in ("AutoField", "BigAutoField", "SmallAutoField"): - autoinc_sql = self.connection.ops.autoinc_sql(model._meta.db_table, field.column) + autoinc_sql = self.connection.ops.autoinc_sql(db_table, field.column) if autoinc_sql: self.deferred_sql.extend(autoinc_sql) @@ -1313,7 +1380,7 @@ def create_model(self, model): constraints = [constraint.constraint_sql(model, self) for constraint in model._meta.constraints] # Make the table sql = self.sql_create_table % { - "table": 
self.quote_name(model._meta.db_table), + "table": self.quote_name(db_table), 'definition': ', '.join(constraint for constraint in (*column_sqls, *constraints) if constraint), } if model._meta.db_tablespace: @@ -1381,7 +1448,24 @@ def _delete_unique_sql( return self._delete_constraint_sql(sql, model, name) def delete_model(self, model): - super().delete_model(model) + """Delete a model from the database.""" + # Handle auto-created intermediary models + db_table = get_table_name(self, model._meta.db_table, getattr(model._meta, "db_table_schema", False)) + for field in model._meta.local_many_to_many: + if field.remote_field.through._meta.auto_created: + self.delete_model(field.remote_field.through) + + # Delete the table + self.execute( + self.sql_delete_table + % { + "table": self.quote_name(db_table), + } + ) + # Remove all deferred statements referencing the deleted table. + for sql in list(self.deferred_sql): + if isinstance(sql, Statement) and sql.references_table(db_table): + self.deferred_sql.remove(sql) def execute(self, sql, params=(), has_result=False): """ @@ -1438,7 +1522,38 @@ def quote_value(self, value): return "1" if value else "0" else: return str(value) - + def _create_fk_sql(self, model, field, suffix): + table_name = model._meta.db_table + db_table = get_table_name(self, model._meta.db_table, getattr(model._meta, "db_table_schema", False)) + to_db_table = get_table_name(self, field.target_field.model._meta.db_table, getattr(field.target_field.model._meta, "db_table_schema", False)) + table = Table(db_table, self.quote_name) + name = self._fk_constraint_name(model, field, suffix) + column = Columns(table_name, [field.column], self.quote_name) + to_table = Table(to_db_table, self.quote_name) + to_column = Columns(field.target_field.model._meta.db_table, [field.target_field.column], self.quote_name) + deferrable = self.connection.ops.deferrable_sql() + return Statement( + self.sql_create_fk, + table=table, + name=name, + column=column, + 
to_table=to_table, + to_column=to_column, + deferrable=deferrable, + ) + def _fk_constraint_name(self, model, field, suffix): + table_name = model._meta.db_table + to_db_table = get_table_name(self, field.target_field.model._meta.db_table, getattr(field.target_field.model._meta, "db_table_schema", False)) + def create_fk_name(*args, **kwargs): + return self.quote_name(self._create_index_name(*args, **kwargs)) + return ForeignKeyName( + table_name, + [field.column], + to_db_table, + [field.target_field.column], + suffix, + create_fk_name, + ) def remove_field(self, model, field): """ Removes a field from a model. Usually involves deleting a column, @@ -1453,6 +1568,7 @@ def remove_field(self, model, field): # Drop any FK constraints, SQL Server requires explicit deletion with self.connection.cursor() as cursor: constraints = self.connection.introspection.get_constraints(cursor, model._meta.db_table) + db_table = get_table_name(self, model._meta.db_table, getattr(model._meta, "db_table_schema", False)) for name, infodict in constraints.items(): if field.column in infodict['columns'] and infodict['foreign_key']: self.execute(self._delete_constraint_sql(self.sql_delete_fk, model, name)) @@ -1460,21 +1576,21 @@ def remove_field(self, model, field): for name, infodict in constraints.items(): if field.column in infodict['columns'] and infodict['index']: self.execute(self.sql_delete_index % { - "table": self.quote_name(model._meta.db_table), + "table": self.quote_name(db_table), "name": self.quote_name(name), }) # Drop primary key constraint, SQL Server requires explicit deletion for name, infodict in constraints.items(): if field.column in infodict['columns'] and infodict['primary_key']: self.execute(self.sql_delete_pk % { - "table": self.quote_name(model._meta.db_table), + "table": self.quote_name(db_table), "name": self.quote_name(name), }) # Drop check constraints, SQL Server requires explicit deletion for name, infodict in constraints.items(): if field.column in 
infodict['columns'] and infodict['check']: self.execute(self.sql_delete_check % { - "table": self.quote_name(model._meta.db_table), + "table": self.quote_name(db_table), "name": self.quote_name(name), }) # Drop unique constraints, SQL Server requires explicit deletion @@ -1482,7 +1598,7 @@ def remove_field(self, model, field): if (field.column in infodict['columns'] and infodict['unique'] and not infodict['primary_key'] and not infodict['index']): self.execute(self.sql_delete_unique % { - "table": self.quote_name(model._meta.db_table), + "table": self.quote_name(db_table), "name": self.quote_name(name), }) # Drop default constraint, SQL Server requires explicit deletion @@ -1494,7 +1610,7 @@ def remove_field(self, model, field): }) # Delete the column sql = self.sql_delete_column % { - "table": self.quote_name(model._meta.db_table), + "table": self.quote_name(db_table), "column": self.quote_name(field.column), } self.execute(sql) @@ -1503,7 +1619,7 @@ def remove_field(self, model, field): self.connection.close() # Remove all deferred statements referencing the deleted column. 
for sql in list(self.deferred_sql): - if isinstance(sql, Statement) and sql.references_column(model._meta.db_table, field.column): + if isinstance(sql, Statement) and sql.references_column(db_table, field.column): self.deferred_sql.remove(sql) def add_constraint(self, model, constraint): @@ -1526,6 +1642,14 @@ def _create_index_name(self, table_name, column_names, suffix=""): new_index_name = index_name.replace('[', '').replace(']', '').replace('.', '_') return new_index_name return index_name + + def _delete_constraint_sql(self, template, model, name): + table = get_table_name(self, model._meta.db_table, getattr(model._meta, "db_table_schema", False)) + return Statement( + template, + table=Table(table, self.quote_name), + name=self.quote_name(name), + ) def _unique_supported( self, diff --git a/testapp/migrations/0001_initial.py b/testapp/migrations/0001_initial.py index 6d898a4d..4feac24e 100644 --- a/testapp/migrations/0001_initial.py +++ b/testapp/migrations/0001_initial.py @@ -1,10 +1,15 @@ # Generated by Django 2.2.8.dev20191112211527 on 2019-11-15 01:38 import uuid - from django.db import migrations, models import django +def forwards(apps, schema_editor): + #create schema for testing purpose + if not schema_editor.connection.vendor == 'microsoft': + return + schema_editor.execute("CREATE SCHEMA test_schema;") + class Migration(migrations.Migration): @@ -12,7 +17,6 @@ class Migration(migrations.Migration): dependencies = [ ] - operations = [ migrations.CreateModel( name='Author', @@ -21,12 +25,17 @@ class Migration(migrations.Migration): fields=[ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('name', models.CharField(max_length=100)), ], ), + migrations.RunPython(forwards), migrations.CreateModel( name='Editor', fields=[ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('name', models.CharField(max_length=100)), ], + options={ + 'db_table': 'editor', + 'db_table_schema': 'test_schema' + }, ), migrations.CreateModel( name='Post', diff --git a/testapp/models.py 
b/testapp/models.py index 2b874da7..0d0a0fa4 100644 --- a/testapp/models.py +++ b/testapp/models.py @@ -19,11 +19,12 @@ class Meta: class Author(models.Model): name = models.CharField(max_length=100) - class Editor(BigAutoFieldMixin, models.Model): name = models.CharField(max_length=100) - - + class Meta: + db_table = 'editor' + db_table_schema = 'test_schema' + class Post(BigAutoFieldMixin, models.Model): title = models.CharField('title', max_length=255) author = models.ForeignKey(Author, models.CASCADE) @@ -76,7 +77,6 @@ class TestUniqueNullableModel(models.Model): # but for a unique index (not db_index) y_renamed = models.IntegerField(null=True, unique=True) - class TestNullableUniqueTogetherModel(models.Model): class Meta: unique_together = (('a', 'b', 'c'),) diff --git a/testapp/settings.py b/testapp/settings.py index 23d6696d..d756eaf6 100644 --- a/testapp/settings.py +++ b/testapp/settings.py @@ -5,6 +5,8 @@ from django import VERSION +import django.db.models.options as options + BASE_DIR = Path(__file__).resolve().parent.parent DATABASES = { @@ -15,6 +17,7 @@ "PASSWORD": "MyPassword42", "HOST": "localhost", "PORT": "1433", + "SCHEMA" :"dbo", "OPTIONS": {"driver": "ODBC Driver 17 for SQL Server", "return_rows_bulk_insert": True}, }, 'other': { @@ -305,3 +308,6 @@ 'model_fields.test_jsonfield.TestQuerying.test_key_iregex', 'model_fields.test_jsonfield.TestQuerying.test_key_regex', ] + +if not 'db_table_schema' in options.DEFAULT_NAMES : + options.DEFAULT_NAMES = options.DEFAULT_NAMES + ('db_table_schema',) diff --git a/testapp/tests/test_queries.py b/testapp/tests/test_queries.py index 54267233..b24a64ce 100644 --- a/testapp/tests/test_queries.py +++ b/testapp/tests/test_queries.py @@ -13,7 +13,7 @@ def test_insert_into_table_with_trigger(self): ON [testapp_author] FOR INSERT AS - INSERT INTO [testapp_editor]([name]) VALUES ('Bar') + INSERT INTO [test_schema].[editor]([name]) VALUES ('Bar') """) try: