ranges.py 11 KB


  1. import datetime
  2. import json
  3. from django.contrib.postgres import forms, lookups
  4. from django.db import models
  5. from django.db.backends.postgresql.psycopg_any import (
  6. DateRange,
  7. DateTimeTZRange,
  8. NumericRange,
  9. Range,
  10. )
  11. from django.db.models.functions import Cast
  12. from django.db.models.lookups import PostgresOperatorLookup
  13. from .utils import AttributeSetter
# Public names of this module, i.e. what a star-import of it exports.
__all__ = [
    "RangeField",
    "IntegerRangeField",
    "BigIntegerRangeField",
    "DecimalRangeField",
    "DateTimeRangeField",
    "DateRangeField",
    "RangeBoundary",
    "RangeOperators",
]
  24. class RangeBoundary(models.Expression):
  25. """A class that represents range boundaries."""
  26. def __init__(self, inclusive_lower=True, inclusive_upper=False):
  27. self.lower = "[" if inclusive_lower else "("
  28. self.upper = "]" if inclusive_upper else ")"
  29. def as_sql(self, compiler, connection):
  30. return "'%s%s'" % (self.lower, self.upper), []
class RangeOperators:
    """Symbolic names for the PostgreSQL range operator strings."""

    # https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
    EQUAL = "="
    NOT_EQUAL = "<>"
    CONTAINS = "@>"
    CONTAINED_BY = "<@"
    OVERLAPS = "&&"
    FULLY_LT = "<<"  # Strictly left of.
    FULLY_GT = ">>"  # Strictly right of.
    NOT_LT = "&>"  # Does not extend to the left of.
    NOT_GT = "&<"  # Does not extend to the right of.
    ADJACENT_TO = "-|-"
  43. class RangeField(models.Field):
  44. empty_strings_allowed = False
  45. def __init__(self, *args, **kwargs):
  46. if "default_bounds" in kwargs:
  47. raise TypeError(
  48. f"Cannot use 'default_bounds' with {self.__class__.__name__}."
  49. )
  50. # Initializing base_field here ensures that its model matches the model
  51. # for self.
  52. if hasattr(self, "base_field"):
  53. self.base_field = self.base_field()
  54. super().__init__(*args, **kwargs)
  55. @property
  56. def model(self):
  57. try:
  58. return self.__dict__["model"]
  59. except KeyError:
  60. raise AttributeError(
  61. "'%s' object has no attribute 'model'" % self.__class__.__name__
  62. )
  63. @model.setter
  64. def model(self, model):
  65. self.__dict__["model"] = model
  66. self.base_field.model = model
  67. @classmethod
  68. def _choices_is_value(cls, value):
  69. return isinstance(value, (list, tuple)) or super()._choices_is_value(value)
  70. def get_placeholder(self, value, compiler, connection):
  71. return "%s::{}".format(self.db_type(connection))
  72. def get_prep_value(self, value):
  73. if value is None:
  74. return None
  75. elif isinstance(value, Range):
  76. return value
  77. elif isinstance(value, (list, tuple)):
  78. return self.range_type(value[0], value[1])
  79. return value
  80. def to_python(self, value):
  81. if isinstance(value, str):
  82. # Assume we're deserializing
  83. vals = json.loads(value)
  84. for end in ("lower", "upper"):
  85. if end in vals:
  86. vals[end] = self.base_field.to_python(vals[end])
  87. value = self.range_type(**vals)
  88. elif isinstance(value, (list, tuple)):
  89. value = self.range_type(value[0], value[1])
  90. return value
  91. def set_attributes_from_name(self, name):
  92. super().set_attributes_from_name(name)
  93. self.base_field.set_attributes_from_name(name)
  94. def value_to_string(self, obj):
  95. value = self.value_from_object(obj)
  96. if value is None:
  97. return None
  98. if value.isempty:
  99. return json.dumps({"empty": True})
  100. base_field = self.base_field
  101. result = {"bounds": value._bounds}
  102. for end in ("lower", "upper"):
  103. val = getattr(value, end)
  104. if val is None:
  105. result[end] = None
  106. else:
  107. obj = AttributeSetter(base_field.attname, val)
  108. result[end] = base_field.value_to_string(obj)
  109. return json.dumps(result)
  110. def formfield(self, **kwargs):
  111. kwargs.setdefault("form_class", self.form_field)
  112. return super().formfield(**kwargs)
  113. CANONICAL_RANGE_BOUNDS = "[)"
  114. class ContinuousRangeField(RangeField):
  115. """
  116. Continuous range field. It allows specifying default bounds for list and
  117. tuple inputs.
  118. """
  119. def __init__(self, *args, default_bounds=CANONICAL_RANGE_BOUNDS, **kwargs):
  120. if default_bounds not in ("[)", "(]", "()", "[]"):
  121. raise ValueError("default_bounds must be one of '[)', '(]', '()', or '[]'.")
  122. self.default_bounds = default_bounds
  123. super().__init__(*args, **kwargs)
  124. def get_prep_value(self, value):
  125. if isinstance(value, (list, tuple)):
  126. return self.range_type(value[0], value[1], self.default_bounds)
  127. return super().get_prep_value(value)
  128. def formfield(self, **kwargs):
  129. kwargs.setdefault("default_bounds", self.default_bounds)
  130. return super().formfield(**kwargs)
  131. def deconstruct(self):
  132. name, path, args, kwargs = super().deconstruct()
  133. if self.default_bounds and self.default_bounds != CANONICAL_RANGE_BOUNDS:
  134. kwargs["default_bounds"] = self.default_bounds
  135. return name, path, args, kwargs
class IntegerRangeField(RangeField):
    """Range of ``IntegerField`` values with database type ``int4range``."""

    base_field = models.IntegerField
    range_type = NumericRange
    form_field = forms.IntegerRangeField

    def db_type(self, connection):
        return "int4range"
class BigIntegerRangeField(RangeField):
    """Range of ``BigIntegerField`` values with database type ``int8range``."""

    base_field = models.BigIntegerField
    range_type = NumericRange
    # Shares the integer range form field with IntegerRangeField.
    form_field = forms.IntegerRangeField

    def db_type(self, connection):
        return "int8range"
class DecimalRangeField(ContinuousRangeField):
    """
    Continuous range of ``DecimalField`` values with database type
    ``numrange``.
    """

    base_field = models.DecimalField
    range_type = NumericRange
    form_field = forms.DecimalRangeField

    def db_type(self, connection):
        return "numrange"
class DateTimeRangeField(ContinuousRangeField):
    """
    Continuous range of ``DateTimeField`` values with database type
    ``tstzrange``.
    """

    base_field = models.DateTimeField
    range_type = DateTimeTZRange
    form_field = forms.DateTimeRangeField

    def db_type(self, connection):
        return "tstzrange"
class DateRangeField(RangeField):
    """Range of ``DateField`` values with database type ``daterange``."""

    base_field = models.DateField
    range_type = DateRange
    form_field = forms.DateRangeField

    def db_type(self, connection):
        return "daterange"
  166. class RangeContains(lookups.DataContains):
  167. def get_prep_lookup(self):
  168. if not isinstance(self.rhs, (list, tuple, Range)):
  169. return Cast(self.rhs, self.lhs.field.base_field)
  170. return super().get_prep_lookup()
# Register the containment and overlap lookups on RangeField.
RangeField.register_lookup(RangeContains)
RangeField.register_lookup(lookups.ContainedBy)
RangeField.register_lookup(lookups.Overlap)
  174. class DateTimeRangeContains(PostgresOperatorLookup):
  175. """
  176. Lookup for Date/DateTimeRange containment to cast the rhs to the correct
  177. type.
  178. """
  179. lookup_name = "contains"
  180. postgres_operator = RangeOperators.CONTAINS
  181. def process_rhs(self, compiler, connection):
  182. # Transform rhs value for db lookup.
  183. if isinstance(self.rhs, datetime.date):
  184. value = models.Value(self.rhs)
  185. self.rhs = value.resolve_expression(compiler.query)
  186. return super().process_rhs(compiler, connection)
  187. def as_postgresql(self, compiler, connection):
  188. sql, params = super().as_postgresql(compiler, connection)
  189. # Cast the rhs if needed.
  190. cast_sql = ""
  191. if (
  192. isinstance(self.rhs, models.Expression)
  193. and self.rhs._output_field_or_none
  194. and
  195. # Skip cast if rhs has a matching range type.
  196. not isinstance(
  197. self.rhs._output_field_or_none, self.lhs.output_field.__class__
  198. )
  199. ):
  200. cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
  201. cast_sql = "::{}".format(connection.data_types.get(cast_internal_type))
  202. return "%s%s" % (sql, cast_sql), params
# Use the casting containment lookup for the date and datetime range fields.
DateRangeField.register_lookup(DateTimeRangeContains)
DateTimeRangeField.register_lookup(DateTimeRangeContains)
  205. class RangeContainedBy(PostgresOperatorLookup):
  206. lookup_name = "contained_by"
  207. type_mapping = {
  208. "smallint": "int4range",
  209. "integer": "int4range",
  210. "bigint": "int8range",
  211. "double precision": "numrange",
  212. "numeric": "numrange",
  213. "date": "daterange",
  214. "timestamp with time zone": "tstzrange",
  215. }
  216. postgres_operator = RangeOperators.CONTAINED_BY
  217. def process_rhs(self, compiler, connection):
  218. rhs, rhs_params = super().process_rhs(compiler, connection)
  219. # Ignore precision for DecimalFields.
  220. db_type = self.lhs.output_field.cast_db_type(connection).split("(")[0]
  221. cast_type = self.type_mapping[db_type]
  222. return "%s::%s" % (rhs, cast_type), rhs_params
  223. def process_lhs(self, compiler, connection):
  224. lhs, lhs_params = super().process_lhs(compiler, connection)
  225. if isinstance(self.lhs.output_field, models.FloatField):
  226. lhs = "%s::numeric" % lhs
  227. elif isinstance(self.lhs.output_field, models.SmallIntegerField):
  228. lhs = "%s::integer" % lhs
  229. return lhs, lhs_params
  230. def get_prep_lookup(self):
  231. return RangeField().get_prep_value(self.rhs)
# Make the contained_by lookup available on the scalar field types.
models.DateField.register_lookup(RangeContainedBy)
models.DateTimeField.register_lookup(RangeContainedBy)
models.IntegerField.register_lookup(RangeContainedBy)
models.FloatField.register_lookup(RangeContainedBy)
models.DecimalField.register_lookup(RangeContainedBy)
@RangeField.register_lookup
class FullyLessThan(PostgresOperatorLookup):
    """``fully_lt`` lookup using the ``<<`` range operator."""

    lookup_name = "fully_lt"
    postgres_operator = RangeOperators.FULLY_LT
@RangeField.register_lookup
class FullGreaterThan(PostgresOperatorLookup):
    """``fully_gt`` lookup using the ``>>`` range operator."""

    lookup_name = "fully_gt"
    postgres_operator = RangeOperators.FULLY_GT
@RangeField.register_lookup
class NotLessThan(PostgresOperatorLookup):
    """``not_lt`` lookup using the ``&>`` range operator."""

    lookup_name = "not_lt"
    postgres_operator = RangeOperators.NOT_LT
@RangeField.register_lookup
class NotGreaterThan(PostgresOperatorLookup):
    """``not_gt`` lookup using the ``&<`` range operator."""

    lookup_name = "not_gt"
    postgres_operator = RangeOperators.NOT_GT
@RangeField.register_lookup
class AdjacentToLookup(PostgresOperatorLookup):
    """``adjacent_to`` lookup using the ``-|-`` range operator."""

    lookup_name = "adjacent_to"
    postgres_operator = RangeOperators.ADJACENT_TO
  257. @RangeField.register_lookup
  258. class RangeStartsWith(models.Transform):
  259. lookup_name = "startswith"
  260. function = "lower"
  261. @property
  262. def output_field(self):
  263. return self.lhs.output_field.base_field
  264. @RangeField.register_lookup
  265. class RangeEndsWith(models.Transform):
  266. lookup_name = "endswith"
  267. function = "upper"
  268. @property
  269. def output_field(self):
  270. return self.lhs.output_field.base_field
@RangeField.register_lookup
class IsEmpty(models.Transform):
    """``isempty`` transform using the ``isempty`` function; boolean output."""

    lookup_name = "isempty"
    function = "isempty"
    output_field = models.BooleanField()
@RangeField.register_lookup
class LowerInclusive(models.Transform):
    """
    ``lower_inc`` transform using the ``LOWER_INC`` function; boolean output.
    """

    lookup_name = "lower_inc"
    function = "LOWER_INC"
    output_field = models.BooleanField()
@RangeField.register_lookup
class LowerInfinite(models.Transform):
    """
    ``lower_inf`` transform using the ``LOWER_INF`` function; boolean output.
    """

    lookup_name = "lower_inf"
    function = "LOWER_INF"
    output_field = models.BooleanField()
@RangeField.register_lookup
class UpperInclusive(models.Transform):
    """
    ``upper_inc`` transform using the ``UPPER_INC`` function; boolean output.
    """

    lookup_name = "upper_inc"
    function = "UPPER_INC"
    output_field = models.BooleanField()
@RangeField.register_lookup
class UpperInfinite(models.Transform):
    """
    ``upper_inf`` transform using the ``UPPER_INF`` function; boolean output.
    """

    lookup_name = "upper_inf"
    function = "UPPER_INF"
    output_field = models.BooleanField()