123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365 |
- import json
- from django.contrib.postgres import lookups
- from django.contrib.postgres.forms import SimpleArrayField
- from django.contrib.postgres.validators import ArrayMaxLengthValidator
- from django.core import checks, exceptions
- from django.db.models import Field, Func, IntegerField, Transform, Value
- from django.db.models.fields.mixins import CheckFieldDefaultMixin
- from django.db.models.lookups import Exact, In
- from django.utils.translation import gettext_lazy as _
- from ..utils import prefix_validation_error
- from .utils import AttributeSetter
# Only ArrayField is exported by a star-import of this module; the lookup and
# transform classes defined below are not listed here.
__all__ = ["ArrayField"]
class ArrayField(CheckFieldDefaultMixin, Field):
    """
    Model field holding a list of values that all share one ``base_field``.

    The column type is the base field's type followed by ``[size]`` (see
    ``db_type()``). Per-item work such as validation, database preparation
    and serialization is delegated to ``base_field``.
    """

    empty_strings_allowed = False
    default_error_messages = {
        "item_invalid": _("Item %(nth)s in the array did not validate:"),
        "nested_array_mismatch": _("Nested arrays must have the same length."),
    }
    _default_hint = ("list", "[]")

    def __init__(self, base_field, size=None, **kwargs):
        """
        Store ``base_field`` and ``size``, then defer to ``Field.__init__()``.

        A truthy ``size`` appends an ``ArrayMaxLengthValidator`` to this
        instance's default validators.
        """
        self.base_field = base_field
        self.db_collation = getattr(self.base_field, "db_collation", None)
        self.size = size
        if self.size:
            size_validator = ArrayMaxLengthValidator(self.size)
            self.default_validators = list(self.default_validators) + [
                size_validator
            ]
        # For performance, from_db_value() is only exposed on the instance
        # when the base field provides its own implementation.
        if hasattr(self.base_field, "from_db_value"):
            self.from_db_value = self._from_db_value
        super().__init__(**kwargs)

    @property
    def model(self):
        """Return the model stored on the instance, if one was ever set."""
        try:
            return self.__dict__["model"]
        except KeyError:
            raise AttributeError(
                "'%s' object has no attribute 'model'" % self.__class__.__name__
            )

    @model.setter
    def model(self, model):
        """Record ``model`` on this field and mirror it onto ``base_field``."""
        self.__dict__["model"] = model
        self.base_field.model = model

    @classmethod
    def _choices_is_value(cls, value):
        """Accept lists and tuples as values, besides what the parent accepts."""
        if isinstance(value, (list, tuple)):
            return True
        return super()._choices_is_value(value)

    def check(self, **kwargs):
        """
        Run the parent's checks and fold in problems found on ``base_field``.

        A related base field yields ``postgres.E002``. Otherwise the base
        field's own errors and warnings are aggregated into a single
        ``postgres.E001`` and a single ``postgres.W004`` respectively.
        """
        errors = super().check(**kwargs)
        if self.base_field.remote_field:
            errors.append(
                checks.Error(
                    "Base field for array cannot be a related field.",
                    obj=self,
                    id="postgres.E002",
                )
            )
            return errors

        # base_field.check() is called without arguments: the field name
        # checks are not needed here.
        base_checks = self.base_field.check()
        summaries = (
            (checks.Error, "Base field for array has errors:\n %s", "postgres.E001"),
            (
                checks.Warning,
                "Base field for array has warnings:\n %s",
                "postgres.W004",
            ),
        )
        for message_class, template, check_id in summaries:
            joined = "\n ".join(
                "%s (%s)" % (found.msg, found.id)
                for found in base_checks
                if isinstance(found, message_class)
            )
            if joined:
                errors.append(message_class(template % joined, obj=self, id=check_id))
        return errors

    def set_attributes_from_name(self, name):
        """Apply ``name`` to this field and to ``base_field`` as well."""
        super().set_attributes_from_name(name)
        self.base_field.set_attributes_from_name(name)

    @property
    def description(self):
        """Human-readable description built from the base field's own."""
        return "Array of %s" % self.base_field.description

    def db_type(self, connection):
        """Return the base field's ``db_type`` followed by ``[size]``."""
        return "%s[%s]" % (self.base_field.db_type(connection), self.size or "")

    def cast_db_type(self, connection):
        """Return the base field's ``cast_db_type`` followed by ``[size]``."""
        return "%s[%s]" % (self.base_field.cast_db_type(connection), self.size or "")

    def db_parameters(self, connection):
        """Return the parent's parameters with ``collation`` overridden."""
        parameters = super().db_parameters(connection)
        parameters["collation"] = self.db_collation
        return parameters

    def get_placeholder(self, value, compiler, connection):
        """Return a placeholder that casts the parameter to this column type."""
        column_type = self.db_type(connection)
        return "%s::{}".format(column_type)

    def get_db_prep_value(self, value, connection, prepared=False):
        """
        Prepare each item of a list/tuple through ``base_field``.

        Any other value is passed through untouched.
        """
        if not isinstance(value, (list, tuple)):
            return value
        return [
            self.base_field.get_db_prep_value(item, connection, prepared=False)
            for item in value
        ]

    def deconstruct(self):
        """
        Return the deconstructed field, adding ``base_field`` and ``size``.

        The module path of this class is shortened to the package-level
        ``django.contrib.postgres.fields.ArrayField``.
        """
        name, path, args, kwargs = super().deconstruct()
        if path == "django.contrib.postgres.fields.array.ArrayField":
            path = "django.contrib.postgres.fields.ArrayField"
        kwargs["base_field"] = self.base_field.clone()
        kwargs["size"] = self.size
        return name, path, args, kwargs

    def to_python(self, value):
        """
        Decode a JSON string into a list of base-field Python values.

        Non-string values are returned as they are.
        """
        if not isinstance(value, str):
            return value
        # A string is assumed to come from deserialization.
        decoded = json.loads(value)
        return [self.base_field.to_python(item) for item in decoded]

    def _from_db_value(self, value, expression, connection):
        """Convert each database item with ``base_field.from_db_value()``."""
        if value is None:
            return value
        return [
            self.base_field.from_db_value(item, expression, connection)
            for item in value
        ]

    def value_to_string(self, obj):
        """
        Serialize the array stored on ``obj`` as a JSON list.

        ``None`` items are kept as ``None``; every other item is converted
        by ``base_field.value_to_string()``.
        """
        items = self.value_from_object(obj)
        base_field = self.base_field
        serialized = [
            None
            if item is None
            else base_field.value_to_string(AttributeSetter(base_field.attname, item))
            for item in items
        ]
        return json.dumps(serialized)

    def get_transform(self, name):
        """
        Resolve ``name`` to a transform, supporting indexes and slices.

        Transforms known to the parent win. Otherwise a purely numeric name
        gives an index transform and a ``start_end`` name gives a slice
        transform. ``None`` is returned when nothing matches.
        """
        registered = super().get_transform(name)
        if registered:
            return registered
        if "_" not in name:
            try:
                position = int(name)
            except ValueError:
                return None
            # postgres uses 1-indexing
            return IndexTransformFactory(position + 1, self.base_field)
        try:
            lower, upper = name.split("_")
            lower = int(lower) + 1
            # The end is not shifted because postgres slices are weird.
            upper = int(upper)
        except ValueError:
            return None
        return SliceTransformFactory(lower, upper)

    def validate(self, value, model_instance):
        """
        Validate the array as a whole, then each item with ``base_field``.

        Item failures are re-raised prefixed with the item's 1-based
        position. When the base field is itself an ``ArrayField``, all
        nested arrays must have the same length.
        """
        super().validate(value, model_instance)
        for position, item in enumerate(value, start=1):
            try:
                self.base_field.validate(item, model_instance)
            except exceptions.ValidationError as item_error:
                raise prefix_validation_error(
                    item_error,
                    prefix=self.error_messages["item_invalid"],
                    code="item_invalid",
                    params={"nth": position},
                )
        is_nested = isinstance(self.base_field, ArrayField)
        if is_nested and len({len(inner) for inner in value}) > 1:
            raise exceptions.ValidationError(
                self.error_messages["nested_array_mismatch"],
                code="nested_array_mismatch",
            )

    def run_validators(self, value):
        """
        Run this field's validators, then the base field's on each item.

        Item failures are re-raised prefixed with the item's 1-based
        position.
        """
        super().run_validators(value)
        for position, item in enumerate(value, start=1):
            try:
                self.base_field.run_validators(item)
            except exceptions.ValidationError as item_error:
                raise prefix_validation_error(
                    item_error,
                    prefix=self.error_messages["item_invalid"],
                    code="item_invalid",
                    params={"nth": position},
                )

    def formfield(self, **kwargs):
        """
        Build the form field, defaulting to ``SimpleArrayField``.

        Caller-supplied ``kwargs`` take precedence over the defaults.
        """
        options = {
            "form_class": SimpleArrayField,
            "base_field": self.base_field.formfield(),
            "max_length": self.size,
        }
        options.update(kwargs)
        return super().formfield(**options)
class ArrayRHSMixin:
    """
    Lookup mixin that adapts list/tuple right-hand sides for array columns.

    A list or tuple containing at least one non-``None`` leaf is rewritten
    as an ``ARRAY[...]`` expression, and the compiled right-hand side is
    always cast to the left-hand side's array type.
    """

    def __init__(self, lhs, rhs):
        """Wrap a list/tuple ``rhs`` in ``ARRAY[...]`` before initializing."""
        # Arrays that contain only None values are left unwrapped because
        # psycopg doesn't allow wrapping them.
        wrap = isinstance(rhs, (tuple, list)) and any(self._rhs_not_none_values(rhs))
        if wrap:
            # Plain values are prepared by the base field and turned into
            # Value() nodes; expressions are used as given.
            members = [
                item
                if hasattr(item, "resolve_expression")
                else Value(lhs.output_field.base_field.get_prep_value(item))
                for item in rhs
            ]
            rhs = Func(
                *members,
                function="ARRAY",
                template="%(function)s[%(expressions)s]",
            )
        super().__init__(lhs, rhs)

    def process_rhs(self, compiler, connection):
        """Compile the right-hand side and append a cast to the array type."""
        rhs_sql, rhs_params = super().process_rhs(compiler, connection)
        array_type = self.lhs.output_field.cast_db_type(connection)
        return "%s::%s" % (rhs_sql, array_type), rhs_params

    def _rhs_not_none_values(self, rhs):
        """
        Yield ``True`` once per non-``None`` leaf found in ``rhs``.

        Nested lists and tuples are descended into recursively.
        """
        for item in rhs:
            if item is None:
                continue
            if isinstance(item, (list, tuple)):
                yield from self._rhs_not_none_values(item)
            else:
                yield True
@ArrayField.register_lookup
class ArrayContains(ArrayRHSMixin, lookups.DataContains):
    """``lookups.DataContains`` with ``ArrayRHSMixin``'s right-hand-side handling."""
@ArrayField.register_lookup
class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy):
    """``lookups.ContainedBy`` with ``ArrayRHSMixin``'s right-hand-side handling."""
@ArrayField.register_lookup
class ArrayExact(ArrayRHSMixin, Exact):
    """``Exact`` lookup with ``ArrayRHSMixin``'s right-hand-side handling."""
@ArrayField.register_lookup
class ArrayOverlap(ArrayRHSMixin, lookups.Overlap):
    """``lookups.Overlap`` with ``ArrayRHSMixin``'s right-hand-side handling."""
@ArrayField.register_lookup
class ArrayLenTransform(Transform):
    """
    ``len`` transform producing the length of an array as an integer.

    The SQL yields NULL for a NULL array and 0 when ``array_length()``
    itself returns NULL.
    """

    lookup_name = "len"
    output_field = IntegerField()

    def as_sql(self, compiler, connection):
        """Return the length SQL and the parameters for both ``lhs`` uses."""
        array_sql, array_params = compiler.compile(self.lhs)
        # Distinguish NULL and empty arrays
        template = (
            "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE "
            "coalesce(array_length(%(lhs)s, 1), 0) END"
        )
        # The left-hand side appears twice in the SQL, so its parameters
        # are repeated.
        return template % {"lhs": array_sql}, array_params * 2
@ArrayField.register_lookup
class ArrayInLookup(In):
    """``In`` lookup whose candidate arrays are converted to tuples."""

    def get_prep_lookup(self):
        """
        Return the prepared right-hand side with plain values as tuples.

        An expression right-hand side is returned as is, and so are
        individual candidates that are expressions.
        """
        candidates = super().get_prep_lookup()
        if hasattr(candidates, "resolve_expression"):
            return candidates
        # In.process_rhs() expects values to be hashable, so convert lists
        # to tuples.
        return [
            candidate
            if hasattr(candidate, "resolve_expression")
            else tuple(candidate)
            for candidate in candidates
        ]
class IndexTransform(Transform):
    """Transform that subscripts an array expression at a fixed index."""

    def __init__(self, index, base_field, *args, **kwargs):
        """Keep ``index`` and ``base_field``; the rest goes to ``Transform``."""
        super().__init__(*args, **kwargs)
        self.index = index
        self.base_field = base_field

    def as_sql(self, compiler, connection):
        """Return ``<lhs>[%s]`` with the index appended to the parameters."""
        array_sql, array_params = compiler.compile(self.lhs)
        # SQL that does not already end in a subscript is parenthesized
        # before the subscript is appended.
        target = array_sql if array_sql.endswith("]") else "(%s)" % array_sql
        return "%s[%%s]" % target, (*array_params, self.index)

    @property
    def output_field(self):
        """The result is typed as the array's base field."""
        return self.base_field
class IndexTransformFactory:
    """Callable that builds ``IndexTransform`` objects for one fixed index."""

    def __init__(self, index, base_field):
        """Remember the index and base field to use for every transform."""
        self.index, self.base_field = index, base_field

    def __call__(self, *args, **kwargs):
        """Create an ``IndexTransform``, forwarding any extra arguments."""
        index, base_field = self.index, self.base_field
        return IndexTransform(index, base_field, *args, **kwargs)
class SliceTransform(Transform):
    """Transform that takes a fixed ``start:end`` slice of an array expression."""

    def __init__(self, start, end, *args, **kwargs):
        """Keep the slice bounds; the rest goes to ``Transform``."""
        super().__init__(*args, **kwargs)
        self.start = start
        self.end = end

    def as_sql(self, compiler, connection):
        """Return ``<lhs>[%s:%s]`` with both bounds appended to the parameters."""
        array_sql, array_params = compiler.compile(self.lhs)
        # SQL that does not already end in a subscript is parenthesized
        # before the slice is appended.
        target = array_sql if array_sql.endswith("]") else "(%s)" % array_sql
        return "%s[%%s:%%s]" % target, (*array_params, self.start, self.end)
class SliceTransformFactory:
    """Callable that builds ``SliceTransform`` objects for one fixed slice."""

    def __init__(self, start, end):
        """Remember the slice bounds to use for every transform."""
        self.start, self.end = start, end

    def __call__(self, *args, **kwargs):
        """Create a ``SliceTransform``, forwarding any extra arguments."""
        start, end = self.start, self.end
        return SliceTransform(start, end, *args, **kwargs)
|