choices.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. from collections.abc import Callable, Iterable, Iterator, Mapping
  2. from itertools import islice, tee, zip_longest
  3. from django.utils.functional import Promise
  4. __all__ = [
  5. "BaseChoiceIterator",
  6. "BlankChoiceIterator",
  7. "CallableChoiceIterator",
  8. "flatten_choices",
  9. "normalize_choices",
  10. ]
  11. class BaseChoiceIterator:
  12. """Base class for lazy iterators for choices."""
  13. def __eq__(self, other):
  14. if isinstance(other, Iterable):
  15. return all(a == b for a, b in zip_longest(self, other, fillvalue=object()))
  16. return super().__eq__(other)
  17. def __getitem__(self, index):
  18. if index < 0:
  19. # Suboptimally consume whole iterator to handle negative index.
  20. return list(self)[index]
  21. try:
  22. return next(islice(self, index, index + 1))
  23. except StopIteration:
  24. raise IndexError("index out of range") from None
  25. def __iter__(self):
  26. raise NotImplementedError(
  27. "BaseChoiceIterator subclasses must implement __iter__()."
  28. )
  29. class BlankChoiceIterator(BaseChoiceIterator):
  30. """Iterator to lazily inject a blank choice."""
  31. def __init__(self, choices, blank_choice):
  32. self.choices = choices
  33. self.blank_choice = blank_choice
  34. def __iter__(self):
  35. choices, other = tee(self.choices)
  36. if not any(value in ("", None) for value, _ in flatten_choices(other)):
  37. yield from self.blank_choice
  38. yield from choices
  39. class CallableChoiceIterator(BaseChoiceIterator):
  40. """Iterator to lazily normalize choices generated by a callable."""
  41. def __init__(self, func):
  42. self.func = func
  43. def __iter__(self):
  44. yield from normalize_choices(self.func())
  45. def flatten_choices(choices):
  46. """Flatten choices by removing nested values."""
  47. for value_or_group, label_or_nested in choices or ():
  48. if isinstance(label_or_nested, (list, tuple)):
  49. yield from label_or_nested
  50. else:
  51. yield value_or_group, label_or_nested
  52. def normalize_choices(value, *, depth=0):
  53. """Normalize choices values consistently for fields and widgets."""
  54. # Avoid circular import when importing django.forms.
  55. from django.db.models.enums import ChoicesType
  56. match value:
  57. case BaseChoiceIterator() | Promise() | bytes() | str():
  58. # Avoid prematurely normalizing iterators that should be lazy.
  59. # Because string-like types are iterable, return early to avoid
  60. # iterating over them in the guard for the Iterable case below.
  61. return value
  62. case ChoicesType():
  63. # Choices enumeration helpers already output in canonical form.
  64. return value.choices
  65. case Mapping() if depth < 2:
  66. value = value.items()
  67. case Iterator() if depth < 2:
  68. # Although Iterator would be handled by the Iterable case below,
  69. # the iterator would be consumed prematurely while checking that
  70. # its elements are not string-like in the guard, so we handle it
  71. # separately.
  72. pass
  73. case Iterable() if depth < 2 and not any(
  74. isinstance(x, (Promise, bytes, str)) for x in value
  75. ):
  76. # String-like types are iterable, so the guard above ensures that
  77. # they're handled by the default case below.
  78. pass
  79. case Callable() if depth == 0:
  80. # If at the top level, wrap callables to be evaluated lazily.
  81. return CallableChoiceIterator(value)
  82. case Callable() if depth < 2:
  83. value = value()
  84. case _:
  85. return value
  86. try:
  87. # Recursive call to convert any nested values to a list of 2-tuples.
  88. return [(k, normalize_choices(v, depth=depth + 1)) for k, v in value]
  89. except (TypeError, ValueError):
  90. # Return original value for the system check to raise if it has items
  91. # that are not iterable or not 2-tuples:
  92. # - TypeError: cannot unpack non-iterable <type> object
  93. # - ValueError: <not enough / too many> values to unpack
  94. return value