operations.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. from django.contrib.postgres.signals import (
  2. get_citext_oids,
  3. get_hstore_oids,
  4. register_type_handlers,
  5. )
  6. from django.db import NotSupportedError, router
  7. from django.db.migrations import AddConstraint, AddIndex, RemoveIndex
  8. from django.db.migrations.operations.base import Operation
  9. from django.db.models.constraints import CheckConstraint
  10. class CreateExtension(Operation):
  11. reversible = True
  12. def __init__(self, name):
  13. self.name = name
  14. def state_forwards(self, app_label, state):
  15. pass
  16. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  17. if schema_editor.connection.vendor != "postgresql" or not router.allow_migrate(
  18. schema_editor.connection.alias, app_label
  19. ):
  20. return
  21. if not self.extension_exists(schema_editor, self.name):
  22. schema_editor.execute(
  23. "CREATE EXTENSION IF NOT EXISTS %s"
  24. % schema_editor.quote_name(self.name)
  25. )
  26. # Clear cached, stale oids.
  27. get_hstore_oids.cache_clear()
  28. get_citext_oids.cache_clear()
  29. # Registering new type handlers cannot be done before the extension is
  30. # installed, otherwise a subsequent data migration would use the same
  31. # connection.
  32. register_type_handlers(schema_editor.connection)
  33. if hasattr(schema_editor.connection, "register_geometry_adapters"):
  34. schema_editor.connection.register_geometry_adapters(
  35. schema_editor.connection.connection, True
  36. )
  37. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  38. if not router.allow_migrate(schema_editor.connection.alias, app_label):
  39. return
  40. if self.extension_exists(schema_editor, self.name):
  41. schema_editor.execute(
  42. "DROP EXTENSION IF EXISTS %s" % schema_editor.quote_name(self.name)
  43. )
  44. # Clear cached, stale oids.
  45. get_hstore_oids.cache_clear()
  46. get_citext_oids.cache_clear()
  47. def extension_exists(self, schema_editor, extension):
  48. with schema_editor.connection.cursor() as cursor:
  49. cursor.execute(
  50. "SELECT 1 FROM pg_extension WHERE extname = %s",
  51. [extension],
  52. )
  53. return bool(cursor.fetchone())
  54. def describe(self):
  55. return "Creates extension %s" % self.name
  56. @property
  57. def migration_name_fragment(self):
  58. return "create_extension_%s" % self.name
  59. class BloomExtension(CreateExtension):
  60. def __init__(self):
  61. self.name = "bloom"
  62. class BtreeGinExtension(CreateExtension):
  63. def __init__(self):
  64. self.name = "btree_gin"
  65. class BtreeGistExtension(CreateExtension):
  66. def __init__(self):
  67. self.name = "btree_gist"
  68. class CITextExtension(CreateExtension):
  69. def __init__(self):
  70. self.name = "citext"
  71. class CryptoExtension(CreateExtension):
  72. def __init__(self):
  73. self.name = "pgcrypto"
  74. class HStoreExtension(CreateExtension):
  75. def __init__(self):
  76. self.name = "hstore"
  77. class TrigramExtension(CreateExtension):
  78. def __init__(self):
  79. self.name = "pg_trgm"
  80. class UnaccentExtension(CreateExtension):
  81. def __init__(self):
  82. self.name = "unaccent"
  83. class NotInTransactionMixin:
  84. def _ensure_not_in_transaction(self, schema_editor):
  85. if schema_editor.connection.in_atomic_block:
  86. raise NotSupportedError(
  87. "The %s operation cannot be executed inside a transaction "
  88. "(set atomic = False on the migration)." % self.__class__.__name__
  89. )
  90. class AddIndexConcurrently(NotInTransactionMixin, AddIndex):
  91. """Create an index using PostgreSQL's CREATE INDEX CONCURRENTLY syntax."""
  92. atomic = False
  93. def describe(self):
  94. return "Concurrently create index %s on field(s) %s of model %s" % (
  95. self.index.name,
  96. ", ".join(self.index.fields),
  97. self.model_name,
  98. )
  99. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  100. self._ensure_not_in_transaction(schema_editor)
  101. model = to_state.apps.get_model(app_label, self.model_name)
  102. if self.allow_migrate_model(schema_editor.connection.alias, model):
  103. schema_editor.add_index(model, self.index, concurrently=True)
  104. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  105. self._ensure_not_in_transaction(schema_editor)
  106. model = from_state.apps.get_model(app_label, self.model_name)
  107. if self.allow_migrate_model(schema_editor.connection.alias, model):
  108. schema_editor.remove_index(model, self.index, concurrently=True)
  109. class RemoveIndexConcurrently(NotInTransactionMixin, RemoveIndex):
  110. """Remove an index using PostgreSQL's DROP INDEX CONCURRENTLY syntax."""
  111. atomic = False
  112. def describe(self):
  113. return "Concurrently remove index %s from %s" % (self.name, self.model_name)
  114. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  115. self._ensure_not_in_transaction(schema_editor)
  116. model = from_state.apps.get_model(app_label, self.model_name)
  117. if self.allow_migrate_model(schema_editor.connection.alias, model):
  118. from_model_state = from_state.models[app_label, self.model_name_lower]
  119. index = from_model_state.get_index_by_name(self.name)
  120. schema_editor.remove_index(model, index, concurrently=True)
  121. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  122. self._ensure_not_in_transaction(schema_editor)
  123. model = to_state.apps.get_model(app_label, self.model_name)
  124. if self.allow_migrate_model(schema_editor.connection.alias, model):
  125. to_model_state = to_state.models[app_label, self.model_name_lower]
  126. index = to_model_state.get_index_by_name(self.name)
  127. schema_editor.add_index(model, index, concurrently=True)
  128. class CollationOperation(Operation):
  129. def __init__(self, name, locale, *, provider="libc", deterministic=True):
  130. self.name = name
  131. self.locale = locale
  132. self.provider = provider
  133. self.deterministic = deterministic
  134. def state_forwards(self, app_label, state):
  135. pass
  136. def deconstruct(self):
  137. kwargs = {"name": self.name, "locale": self.locale}
  138. if self.provider and self.provider != "libc":
  139. kwargs["provider"] = self.provider
  140. if self.deterministic is False:
  141. kwargs["deterministic"] = self.deterministic
  142. return (
  143. self.__class__.__qualname__,
  144. [],
  145. kwargs,
  146. )
  147. def create_collation(self, schema_editor):
  148. args = {"locale": schema_editor.quote_name(self.locale)}
  149. if self.provider != "libc":
  150. args["provider"] = schema_editor.quote_name(self.provider)
  151. if self.deterministic is False:
  152. args["deterministic"] = "false"
  153. schema_editor.execute(
  154. "CREATE COLLATION %(name)s (%(args)s)"
  155. % {
  156. "name": schema_editor.quote_name(self.name),
  157. "args": ", ".join(
  158. f"{option}={value}" for option, value in args.items()
  159. ),
  160. }
  161. )
  162. def remove_collation(self, schema_editor):
  163. schema_editor.execute(
  164. "DROP COLLATION %s" % schema_editor.quote_name(self.name),
  165. )
  166. class CreateCollation(CollationOperation):
  167. """Create a collation."""
  168. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  169. if schema_editor.connection.vendor != "postgresql" or not router.allow_migrate(
  170. schema_editor.connection.alias, app_label
  171. ):
  172. return
  173. self.create_collation(schema_editor)
  174. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  175. if not router.allow_migrate(schema_editor.connection.alias, app_label):
  176. return
  177. self.remove_collation(schema_editor)
  178. def describe(self):
  179. return f"Create collation {self.name}"
  180. @property
  181. def migration_name_fragment(self):
  182. return "create_collation_%s" % self.name.lower()
  183. class RemoveCollation(CollationOperation):
  184. """Remove a collation."""
  185. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  186. if schema_editor.connection.vendor != "postgresql" or not router.allow_migrate(
  187. schema_editor.connection.alias, app_label
  188. ):
  189. return
  190. self.remove_collation(schema_editor)
  191. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  192. if not router.allow_migrate(schema_editor.connection.alias, app_label):
  193. return
  194. self.create_collation(schema_editor)
  195. def describe(self):
  196. return f"Remove collation {self.name}"
  197. @property
  198. def migration_name_fragment(self):
  199. return "remove_collation_%s" % self.name.lower()
  200. class AddConstraintNotValid(AddConstraint):
  201. """
  202. Add a table constraint without enforcing validation, using PostgreSQL's
  203. NOT VALID syntax.
  204. """
  205. def __init__(self, model_name, constraint):
  206. if not isinstance(constraint, CheckConstraint):
  207. raise TypeError(
  208. "AddConstraintNotValid.constraint must be a check constraint."
  209. )
  210. super().__init__(model_name, constraint)
  211. def describe(self):
  212. return "Create not valid constraint %s on model %s" % (
  213. self.constraint.name,
  214. self.model_name,
  215. )
  216. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  217. model = from_state.apps.get_model(app_label, self.model_name)
  218. if self.allow_migrate_model(schema_editor.connection.alias, model):
  219. constraint_sql = self.constraint.create_sql(model, schema_editor)
  220. if constraint_sql:
  221. # Constraint.create_sql returns interpolated SQL which makes
  222. # params=None a necessity to avoid escaping attempts on
  223. # execution.
  224. schema_editor.execute(str(constraint_sql) + " NOT VALID", params=None)
  225. @property
  226. def migration_name_fragment(self):
  227. return super().migration_name_fragment + "_not_valid"
  228. class ValidateConstraint(Operation):
  229. """Validate a table NOT VALID constraint."""
  230. def __init__(self, model_name, name):
  231. self.model_name = model_name
  232. self.name = name
  233. def describe(self):
  234. return "Validate constraint %s on model %s" % (self.name, self.model_name)
  235. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  236. model = from_state.apps.get_model(app_label, self.model_name)
  237. if self.allow_migrate_model(schema_editor.connection.alias, model):
  238. schema_editor.execute(
  239. "ALTER TABLE %s VALIDATE CONSTRAINT %s"
  240. % (
  241. schema_editor.quote_name(model._meta.db_table),
  242. schema_editor.quote_name(self.name),
  243. )
  244. )
  245. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  246. # PostgreSQL does not provide a way to make a constraint invalid.
  247. pass
  248. def state_forwards(self, app_label, state):
  249. pass
  250. @property
  251. def migration_name_fragment(self):
  252. return "%s_validate_%s" % (self.model_name.lower(), self.name.lower())
  253. def deconstruct(self):
  254. return (
  255. self.__class__.__name__,
  256. [],
  257. {
  258. "model_name": self.model_name,
  259. "name": self.name,
  260. },
  261. )