constraints.py 9.4 KB


from types import NoneType

from django.contrib.postgres.indexes import OpClass
from django.core.exceptions import FieldError, ValidationError
from django.db import DEFAULT_DB_ALIAS, NotSupportedError
from django.db.backends.ddl_references import Expressions, Statement, Table
from django.db.models import BaseConstraint, Deferrable, F, Q
from django.db.models.expressions import Exists, ExpressionList
from django.db.models.indexes import IndexExpression
from django.db.models.lookups import PostgresOperatorLookup
from django.db.models.sql import Query
# Public API of this module: only the constraint class is exported; the
# expression helper below is an implementation detail.
__all__ = ["ExclusionConstraint"]
  12. class ExclusionConstraintExpression(IndexExpression):
  13. template = "%(expressions)s WITH %(operator)s"
  14. class ExclusionConstraint(BaseConstraint):
  15. template = (
  16. "CONSTRAINT %(name)s EXCLUDE USING %(index_type)s "
  17. "(%(expressions)s)%(include)s%(where)s%(deferrable)s"
  18. )
  19. def __init__(
  20. self,
  21. *,
  22. name,
  23. expressions,
  24. index_type=None,
  25. condition=None,
  26. deferrable=None,
  27. include=None,
  28. violation_error_code=None,
  29. violation_error_message=None,
  30. ):
  31. if index_type and index_type.lower() not in {"gist", "spgist"}:
  32. raise ValueError(
  33. "Exclusion constraints only support GiST or SP-GiST indexes."
  34. )
  35. if not expressions:
  36. raise ValueError(
  37. "At least one expression is required to define an exclusion "
  38. "constraint."
  39. )
  40. if not all(
  41. isinstance(expr, (list, tuple)) and len(expr) == 2 for expr in expressions
  42. ):
  43. raise ValueError("The expressions must be a list of 2-tuples.")
  44. if not isinstance(condition, (NoneType, Q)):
  45. raise ValueError("ExclusionConstraint.condition must be a Q instance.")
  46. if not isinstance(deferrable, (NoneType, Deferrable)):
  47. raise ValueError(
  48. "ExclusionConstraint.deferrable must be a Deferrable instance."
  49. )
  50. if not isinstance(include, (NoneType, list, tuple)):
  51. raise ValueError("ExclusionConstraint.include must be a list or tuple.")
  52. self.expressions = expressions
  53. self.index_type = index_type or "GIST"
  54. self.condition = condition
  55. self.deferrable = deferrable
  56. self.include = tuple(include) if include else ()
  57. super().__init__(
  58. name=name,
  59. violation_error_code=violation_error_code,
  60. violation_error_message=violation_error_message,
  61. )
  62. def _get_expressions(self, schema_editor, query):
  63. expressions = []
  64. for idx, (expression, operator) in enumerate(self.expressions):
  65. if isinstance(expression, str):
  66. expression = F(expression)
  67. expression = ExclusionConstraintExpression(expression, operator=operator)
  68. expression.set_wrapper_classes(schema_editor.connection)
  69. expressions.append(expression)
  70. return ExpressionList(*expressions).resolve_expression(query)
  71. def _get_condition_sql(self, compiler, schema_editor, query):
  72. if self.condition is None:
  73. return None
  74. where = query.build_where(self.condition)
  75. sql, params = where.as_sql(compiler, schema_editor.connection)
  76. return sql % tuple(schema_editor.quote_value(p) for p in params)
  77. def constraint_sql(self, model, schema_editor):
  78. query = Query(model, alias_cols=False)
  79. compiler = query.get_compiler(connection=schema_editor.connection)
  80. expressions = self._get_expressions(schema_editor, query)
  81. table = model._meta.db_table
  82. condition = self._get_condition_sql(compiler, schema_editor, query)
  83. include = [
  84. model._meta.get_field(field_name).column for field_name in self.include
  85. ]
  86. return Statement(
  87. self.template,
  88. table=Table(table, schema_editor.quote_name),
  89. name=schema_editor.quote_name(self.name),
  90. index_type=self.index_type,
  91. expressions=Expressions(
  92. table, expressions, compiler, schema_editor.quote_value
  93. ),
  94. where=" WHERE (%s)" % condition if condition else "",
  95. include=schema_editor._index_include_sql(model, include),
  96. deferrable=schema_editor._deferrable_constraint_sql(self.deferrable),
  97. )
  98. def create_sql(self, model, schema_editor):
  99. self.check_supported(schema_editor)
  100. return Statement(
  101. "ALTER TABLE %(table)s ADD %(constraint)s",
  102. table=Table(model._meta.db_table, schema_editor.quote_name),
  103. constraint=self.constraint_sql(model, schema_editor),
  104. )
  105. def remove_sql(self, model, schema_editor):
  106. return schema_editor._delete_constraint_sql(
  107. schema_editor.sql_delete_check,
  108. model,
  109. schema_editor.quote_name(self.name),
  110. )
  111. def check_supported(self, schema_editor):
  112. if (
  113. self.include
  114. and self.index_type.lower() == "spgist"
  115. and not schema_editor.connection.features.supports_covering_spgist_indexes
  116. ):
  117. raise NotSupportedError(
  118. "Covering exclusion constraints using an SP-GiST index "
  119. "require PostgreSQL 14+."
  120. )
  121. def deconstruct(self):
  122. path, args, kwargs = super().deconstruct()
  123. kwargs["expressions"] = self.expressions
  124. if self.condition is not None:
  125. kwargs["condition"] = self.condition
  126. if self.index_type.lower() != "gist":
  127. kwargs["index_type"] = self.index_type
  128. if self.deferrable:
  129. kwargs["deferrable"] = self.deferrable
  130. if self.include:
  131. kwargs["include"] = self.include
  132. return path, args, kwargs
  133. def __eq__(self, other):
  134. if isinstance(other, self.__class__):
  135. return (
  136. self.name == other.name
  137. and self.index_type == other.index_type
  138. and self.expressions == other.expressions
  139. and self.condition == other.condition
  140. and self.deferrable == other.deferrable
  141. and self.include == other.include
  142. and self.violation_error_code == other.violation_error_code
  143. and self.violation_error_message == other.violation_error_message
  144. )
  145. return super().__eq__(other)
  146. def __repr__(self):
  147. return "<%s: index_type=%s expressions=%s name=%s%s%s%s%s%s>" % (
  148. self.__class__.__qualname__,
  149. repr(self.index_type),
  150. repr(self.expressions),
  151. repr(self.name),
  152. "" if self.condition is None else " condition=%s" % self.condition,
  153. "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
  154. "" if not self.include else " include=%s" % repr(self.include),
  155. (
  156. ""
  157. if self.violation_error_code is None
  158. else " violation_error_code=%r" % self.violation_error_code
  159. ),
  160. (
  161. ""
  162. if self.violation_error_message is None
  163. or self.violation_error_message == self.default_violation_error_message
  164. else " violation_error_message=%r" % self.violation_error_message
  165. ),
  166. )
  167. def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
  168. queryset = model._default_manager.using(using)
  169. replacement_map = instance._get_field_value_map(
  170. meta=model._meta, exclude=exclude
  171. )
  172. replacements = {F(field): value for field, value in replacement_map.items()}
  173. lookups = []
  174. for idx, (expression, operator) in enumerate(self.expressions):
  175. if isinstance(expression, str):
  176. expression = F(expression)
  177. if exclude:
  178. if isinstance(expression, F):
  179. if expression.name in exclude:
  180. return
  181. else:
  182. for expr in expression.flatten():
  183. if isinstance(expr, F) and expr.name in exclude:
  184. return
  185. rhs_expression = expression.replace_expressions(replacements)
  186. # Remove OpClass because it only has sense during the constraint
  187. # creation.
  188. if isinstance(expression, OpClass):
  189. expression = expression.get_source_expressions()[0]
  190. if isinstance(rhs_expression, OpClass):
  191. rhs_expression = rhs_expression.get_source_expressions()[0]
  192. lookup = PostgresOperatorLookup(lhs=expression, rhs=rhs_expression)
  193. lookup.postgres_operator = operator
  194. lookups.append(lookup)
  195. queryset = queryset.filter(*lookups)
  196. model_class_pk = instance._get_pk_val(model._meta)
  197. if not instance._state.adding and model_class_pk is not None:
  198. queryset = queryset.exclude(pk=model_class_pk)
  199. if not self.condition:
  200. if queryset.exists():
  201. raise ValidationError(
  202. self.get_violation_error_message(), code=self.violation_error_code
  203. )
  204. else:
  205. if (self.condition & Exists(queryset.filter(self.condition))).check(
  206. replacement_map, using=using
  207. ):
  208. raise ValidationError(
  209. self.get_violation_error_message(), code=self.violation_error_code
  210. )