# array.py
  1. import json
  2. from django.contrib.postgres import lookups
  3. from django.contrib.postgres.forms import SimpleArrayField
  4. from django.contrib.postgres.validators import ArrayMaxLengthValidator
  5. from django.core import checks, exceptions
  6. from django.db.models import Field, Func, IntegerField, Transform, Value
  7. from django.db.models.fields.mixins import CheckFieldDefaultMixin
  8. from django.db.models.lookups import Exact, In
  9. from django.utils.translation import gettext_lazy as _
  10. from ..utils import prefix_validation_error
  11. from .utils import AttributeSetter
  12. __all__ = ["ArrayField"]
  13. class ArrayField(CheckFieldDefaultMixin, Field):
  14. empty_strings_allowed = False
  15. default_error_messages = {
  16. "item_invalid": _("Item %(nth)s in the array did not validate:"),
  17. "nested_array_mismatch": _("Nested arrays must have the same length."),
  18. }
  19. _default_hint = ("list", "[]")
  20. def __init__(self, base_field, size=None, **kwargs):
  21. self.base_field = base_field
  22. self.db_collation = getattr(self.base_field, "db_collation", None)
  23. self.size = size
  24. if self.size:
  25. self.default_validators = [
  26. *self.default_validators,
  27. ArrayMaxLengthValidator(self.size),
  28. ]
  29. # For performance, only add a from_db_value() method if the base field
  30. # implements it.
  31. if hasattr(self.base_field, "from_db_value"):
  32. self.from_db_value = self._from_db_value
  33. super().__init__(**kwargs)
  34. @property
  35. def model(self):
  36. try:
  37. return self.__dict__["model"]
  38. except KeyError:
  39. raise AttributeError(
  40. "'%s' object has no attribute 'model'" % self.__class__.__name__
  41. )
  42. @model.setter
  43. def model(self, model):
  44. self.__dict__["model"] = model
  45. self.base_field.model = model
  46. @classmethod
  47. def _choices_is_value(cls, value):
  48. return isinstance(value, (list, tuple)) or super()._choices_is_value(value)
  49. def check(self, **kwargs):
  50. errors = super().check(**kwargs)
  51. if self.base_field.remote_field:
  52. errors.append(
  53. checks.Error(
  54. "Base field for array cannot be a related field.",
  55. obj=self,
  56. id="postgres.E002",
  57. )
  58. )
  59. else:
  60. # Remove the field name checks as they are not needed here.
  61. base_checks = self.base_field.check()
  62. if base_checks:
  63. error_messages = "\n ".join(
  64. "%s (%s)" % (base_check.msg, base_check.id)
  65. for base_check in base_checks
  66. if isinstance(base_check, checks.Error)
  67. )
  68. if error_messages:
  69. errors.append(
  70. checks.Error(
  71. "Base field for array has errors:\n %s" % error_messages,
  72. obj=self,
  73. id="postgres.E001",
  74. )
  75. )
  76. warning_messages = "\n ".join(
  77. "%s (%s)" % (base_check.msg, base_check.id)
  78. for base_check in base_checks
  79. if isinstance(base_check, checks.Warning)
  80. )
  81. if warning_messages:
  82. errors.append(
  83. checks.Warning(
  84. "Base field for array has warnings:\n %s"
  85. % warning_messages,
  86. obj=self,
  87. id="postgres.W004",
  88. )
  89. )
  90. return errors
  91. def set_attributes_from_name(self, name):
  92. super().set_attributes_from_name(name)
  93. self.base_field.set_attributes_from_name(name)
  94. @property
  95. def description(self):
  96. return "Array of %s" % self.base_field.description
  97. def db_type(self, connection):
  98. size = self.size or ""
  99. return "%s[%s]" % (self.base_field.db_type(connection), size)
  100. def cast_db_type(self, connection):
  101. size = self.size or ""
  102. return "%s[%s]" % (self.base_field.cast_db_type(connection), size)
  103. def db_parameters(self, connection):
  104. db_params = super().db_parameters(connection)
  105. db_params["collation"] = self.db_collation
  106. return db_params
  107. def get_placeholder(self, value, compiler, connection):
  108. return "%s::{}".format(self.db_type(connection))
  109. def get_db_prep_value(self, value, connection, prepared=False):
  110. if isinstance(value, (list, tuple)):
  111. return [
  112. self.base_field.get_db_prep_value(i, connection, prepared=False)
  113. for i in value
  114. ]
  115. return value
  116. def deconstruct(self):
  117. name, path, args, kwargs = super().deconstruct()
  118. if path == "django.contrib.postgres.fields.array.ArrayField":
  119. path = "django.contrib.postgres.fields.ArrayField"
  120. kwargs.update(
  121. {
  122. "base_field": self.base_field.clone(),
  123. "size": self.size,
  124. }
  125. )
  126. return name, path, args, kwargs
  127. def to_python(self, value):
  128. if isinstance(value, str):
  129. # Assume we're deserializing
  130. vals = json.loads(value)
  131. value = [self.base_field.to_python(val) for val in vals]
  132. return value
  133. def _from_db_value(self, value, expression, connection):
  134. if value is None:
  135. return value
  136. return [
  137. self.base_field.from_db_value(item, expression, connection)
  138. for item in value
  139. ]
  140. def value_to_string(self, obj):
  141. values = []
  142. vals = self.value_from_object(obj)
  143. base_field = self.base_field
  144. for val in vals:
  145. if val is None:
  146. values.append(None)
  147. else:
  148. obj = AttributeSetter(base_field.attname, val)
  149. values.append(base_field.value_to_string(obj))
  150. return json.dumps(values)
  151. def get_transform(self, name):
  152. transform = super().get_transform(name)
  153. if transform:
  154. return transform
  155. if "_" not in name:
  156. try:
  157. index = int(name)
  158. except ValueError:
  159. pass
  160. else:
  161. index += 1 # postgres uses 1-indexing
  162. return IndexTransformFactory(index, self.base_field)
  163. try:
  164. start, end = name.split("_")
  165. start = int(start) + 1
  166. end = int(end) # don't add one here because postgres slices are weird
  167. except ValueError:
  168. pass
  169. else:
  170. return SliceTransformFactory(start, end)
  171. def validate(self, value, model_instance):
  172. super().validate(value, model_instance)
  173. for index, part in enumerate(value):
  174. try:
  175. self.base_field.validate(part, model_instance)
  176. except exceptions.ValidationError as error:
  177. raise prefix_validation_error(
  178. error,
  179. prefix=self.error_messages["item_invalid"],
  180. code="item_invalid",
  181. params={"nth": index + 1},
  182. )
  183. if isinstance(self.base_field, ArrayField):
  184. if len({len(i) for i in value}) > 1:
  185. raise exceptions.ValidationError(
  186. self.error_messages["nested_array_mismatch"],
  187. code="nested_array_mismatch",
  188. )
  189. def run_validators(self, value):
  190. super().run_validators(value)
  191. for index, part in enumerate(value):
  192. try:
  193. self.base_field.run_validators(part)
  194. except exceptions.ValidationError as error:
  195. raise prefix_validation_error(
  196. error,
  197. prefix=self.error_messages["item_invalid"],
  198. code="item_invalid",
  199. params={"nth": index + 1},
  200. )
  201. def formfield(self, **kwargs):
  202. return super().formfield(
  203. **{
  204. "form_class": SimpleArrayField,
  205. "base_field": self.base_field.formfield(),
  206. "max_length": self.size,
  207. **kwargs,
  208. }
  209. )
  210. class ArrayRHSMixin:
  211. def __init__(self, lhs, rhs):
  212. # Don't wrap arrays that contains only None values, psycopg doesn't
  213. # allow this.
  214. if isinstance(rhs, (tuple, list)) and any(self._rhs_not_none_values(rhs)):
  215. expressions = []
  216. for value in rhs:
  217. if not hasattr(value, "resolve_expression"):
  218. field = lhs.output_field
  219. value = Value(field.base_field.get_prep_value(value))
  220. expressions.append(value)
  221. rhs = Func(
  222. *expressions,
  223. function="ARRAY",
  224. template="%(function)s[%(expressions)s]",
  225. )
  226. super().__init__(lhs, rhs)
  227. def process_rhs(self, compiler, connection):
  228. rhs, rhs_params = super().process_rhs(compiler, connection)
  229. cast_type = self.lhs.output_field.cast_db_type(connection)
  230. return "%s::%s" % (rhs, cast_type), rhs_params
  231. def _rhs_not_none_values(self, rhs):
  232. for x in rhs:
  233. if isinstance(x, (list, tuple)):
  234. yield from self._rhs_not_none_values(x)
  235. elif x is not None:
  236. yield True
  237. @ArrayField.register_lookup
  238. class ArrayContains(ArrayRHSMixin, lookups.DataContains):
  239. pass
  240. @ArrayField.register_lookup
  241. class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy):
  242. pass
  243. @ArrayField.register_lookup
  244. class ArrayExact(ArrayRHSMixin, Exact):
  245. pass
  246. @ArrayField.register_lookup
  247. class ArrayOverlap(ArrayRHSMixin, lookups.Overlap):
  248. pass
  249. @ArrayField.register_lookup
  250. class ArrayLenTransform(Transform):
  251. lookup_name = "len"
  252. output_field = IntegerField()
  253. def as_sql(self, compiler, connection):
  254. lhs, params = compiler.compile(self.lhs)
  255. # Distinguish NULL and empty arrays
  256. return (
  257. "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE "
  258. "coalesce(array_length(%(lhs)s, 1), 0) END"
  259. ) % {"lhs": lhs}, params * 2
  260. @ArrayField.register_lookup
  261. class ArrayInLookup(In):
  262. def get_prep_lookup(self):
  263. values = super().get_prep_lookup()
  264. if hasattr(values, "resolve_expression"):
  265. return values
  266. # In.process_rhs() expects values to be hashable, so convert lists
  267. # to tuples.
  268. prepared_values = []
  269. for value in values:
  270. if hasattr(value, "resolve_expression"):
  271. prepared_values.append(value)
  272. else:
  273. prepared_values.append(tuple(value))
  274. return prepared_values
  275. class IndexTransform(Transform):
  276. def __init__(self, index, base_field, *args, **kwargs):
  277. super().__init__(*args, **kwargs)
  278. self.index = index
  279. self.base_field = base_field
  280. def as_sql(self, compiler, connection):
  281. lhs, params = compiler.compile(self.lhs)
  282. if not lhs.endswith("]"):
  283. lhs = "(%s)" % lhs
  284. return "%s[%%s]" % lhs, (*params, self.index)
  285. @property
  286. def output_field(self):
  287. return self.base_field
  288. class IndexTransformFactory:
  289. def __init__(self, index, base_field):
  290. self.index = index
  291. self.base_field = base_field
  292. def __call__(self, *args, **kwargs):
  293. return IndexTransform(self.index, self.base_field, *args, **kwargs)
  294. class SliceTransform(Transform):
  295. def __init__(self, start, end, *args, **kwargs):
  296. super().__init__(*args, **kwargs)
  297. self.start = start
  298. self.end = end
  299. def as_sql(self, compiler, connection):
  300. lhs, params = compiler.compile(self.lhs)
  301. if not lhs.endswith("]"):
  302. lhs = "(%s)" % lhs
  303. return "%s[%%s:%%s]" % lhs, (*params, self.start, self.end)
  304. class SliceTransformFactory:
  305. def __init__(self, start, end):
  306. self.start = start
  307. self.end = end
  308. def __call__(self, *args, **kwargs):
  309. return SliceTransform(self.start, self.end, *args, **kwargs)