sync.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645
  1. import asyncio
  2. import asyncio.coroutines
  3. import contextvars
  4. import functools
  5. import inspect
  6. import os
  7. import sys
  8. import threading
  9. import warnings
  10. import weakref
  11. from concurrent.futures import Future, ThreadPoolExecutor
  12. from typing import (
  13. TYPE_CHECKING,
  14. Any,
  15. Awaitable,
  16. Callable,
  17. Coroutine,
  18. Dict,
  19. Generic,
  20. List,
  21. Optional,
  22. TypeVar,
  23. Union,
  24. overload,
  25. )
  26. from .current_thread_executor import CurrentThreadExecutor
  27. from .local import Local
  28. if sys.version_info >= (3, 10):
  29. from typing import ParamSpec
  30. else:
  31. from typing_extensions import ParamSpec
  32. if TYPE_CHECKING:
  33. # This is not available to import at runtime
  34. from _typeshed import OptExcInfo
  35. _F = TypeVar("_F", bound=Callable[..., Any])
  36. _P = ParamSpec("_P")
  37. _R = TypeVar("_R")
  def _restore_context(context: contextvars.Context) -> None:
      """Propagate context-variable values from ``context`` into the current context.

      Every variable present in ``context`` is set in the running context
      unless it already holds the very same object, so downstream consumers
      observe the values produced while running inside ``context``.

      Identity (``is``) rather than equality is used for the comparison:
      equality may invoke an arbitrary ``__eq__``/``__ne__`` that raises or
      returns a non-boolean (e.g. array-like values), and an equal-but-distinct
      object must still be propagated so callers see the exact object.
      """
      for cvar, cvalue in context.items():
          try:
              unchanged = cvar.get() is cvalue
          except LookupError:
              # No value (and no default) in the current context: must set it.
              unchanged = False
          if not unchanged:
              cvar.set(cvalue)
  # Python 3.12 deprecates asyncio.iscoroutinefunction() (an alias for
  # inspect.iscoroutinefunction()) and removes the private _is_coroutine
  # marker, replacing it with the inspect.markcoroutinefunction() decorator.
  # Until 3.12 is the minimum supported version, bind whichever pair of
  # helpers the running interpreter provides.
  if getattr(inspect, "markcoroutinefunction", None) is not None:
      # Newer interpreters: use the public inspect API directly.
      iscoroutinefunction = inspect.iscoroutinefunction
      markcoroutinefunction: Callable[[_F], _F] = inspect.markcoroutinefunction
  else:
      # Older interpreters: asyncio recognises its private marker attribute.
      iscoroutinefunction = asyncio.iscoroutinefunction  # type: ignore[assignment]

      def markcoroutinefunction(func: _F) -> _F:
          """Tag ``func`` with asyncio's coroutine marker and return it unchanged."""
          func._is_coroutine = asyncio.coroutines._is_coroutine  # type: ignore
          return func
  if sys.version_info < (3, 8):

      def _iscoroutinefunction_or_partial(func: Any) -> bool:
          """Return True if ``func`` is a coroutine function, looking through
          bound methods and functools.partial wrappers.

          Python < 3.8 does not treat partially wrapped coroutine functions as
          coroutine functions, so the wrappers are peeled off by hand here
          (logic taken from CPython).
          """
          target = func
          while inspect.ismethod(target):
              target = target.__func__
          while isinstance(target, functools.partial):
              target = target.func
          return iscoroutinefunction(target)

  else:
      # From Python 3.8 the detection already handles partials.
      _iscoroutinefunction_or_partial = iscoroutinefunction
  class ThreadSensitiveContext:
      """Async context manager to manage context for thread sensitive mode.

      This context manager controls which thread pool executor is used when in
      thread sensitive mode. By default, a single thread pool executor is shared
      within a process.

      In Python 3.7+, the ThreadSensitiveContext() context manager may be used to
      specify a thread pool per context.

      This context manager is re-entrant, so only the outer-most call to
      ThreadSensitiveContext will set the context.

      Usage::

          import time

          async with ThreadSensitiveContext():
              await sync_to_async(time.sleep, 1)()
      """

      def __init__(self):
          # Token from ContextVar.set() while this instance is the active
          # thread-sensitive context; None when an outer context was already
          # active on entry, or once this instance has been exited.
          self.token = None

      async def __aenter__(self):
          try:
              SyncToAsync.thread_sensitive_context.get()
          except LookupError:
              # No outer context is active: this instance becomes the context.
              self.token = SyncToAsync.thread_sensitive_context.set(self)
          return self

      async def __aexit__(self, exc, value, tb):
          if self.token is None:
              # An outer ThreadSensitiveContext owns the context; nothing to undo.
              return
          # Clear the token before using it so a later re-entry of this instance
          # (e.g. nested inside another context) never resets a stale token.
          token, self.token = self.token, None
          try:
              # Tear down the per-context executor, if one was created for us.
              executor = SyncToAsync.context_to_thread_executor.pop(self, None)
              if executor is not None:
                  executor.shutdown()
          finally:
              # Always restore the context var, even if shutdown fails.
              SyncToAsync.thread_sensitive_context.reset(token)
  class AsyncToSync(Generic[_P, _R]):
      """
      Utility class which turns an awaitable that only works on the thread with
      the event loop into a synchronous callable that works in a subthread.

      If the call stack contains an async loop, the code runs there.
      Otherwise, the code runs in a new loop in a new thread.

      Either way, this thread then pauses and waits to run any thread_sensitive
      code called from further down the call stack using SyncToAsync, before
      finally exiting once the async task returns.
      """

      # Maps launched Tasks to the threads that launched them (for locals impl)
      launch_map: "Dict[asyncio.Task[object], threading.Thread]" = {}

      # Keeps track of which CurrentThreadExecutor to use. This uses an asgiref
      # Local, not a threadlocal, so that tasks can work out what their parent used.
      executors = Local()

      # When we can't find a CurrentThreadExecutor from the context, such as
      # inside create_task, we'll look it up here from the running event loop.
      loop_thread_executors: "Dict[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = {}

      def __init__(
          self,
          awaitable: Union[
              Callable[_P, Coroutine[Any, Any, _R]],
              Callable[_P, Awaitable[_R]],
          ],
          force_new_loop: bool = False,
      ):
          """
          Wrap ``awaitable`` (an async callable) for synchronous use.

          Only a warning is issued if ``awaitable`` does not look async. Unless
          ``force_new_loop`` is set, the loop to run on is captured now: the
          loop running in this thread, or the loop recorded in the threadlocal
          by an enclosing SyncToAsync call made from this same process.
          """
          if not callable(awaitable) or (
              not _iscoroutinefunction_or_partial(awaitable)
              and not _iscoroutinefunction_or_partial(
                  getattr(awaitable, "__call__", awaitable)
              )
          ):
              # Python does not have very reliable detection of async functions
              # (lots of false negatives) so this is just a warning.
              warnings.warn(
                  "async_to_sync was passed a non-async-marked callable", stacklevel=2
              )
          self.awaitable = awaitable
          try:
              self.__self__ = self.awaitable.__self__  # type: ignore[union-attr]
          except AttributeError:
              pass
          if force_new_loop:
              # They have asked that we always run in a new sub-loop.
              self.main_event_loop = None
          else:
              try:
                  self.main_event_loop = asyncio.get_running_loop()
              except RuntimeError:
                  # There's no event loop in this thread. Look for the threadlocal if
                  # we're inside SyncToAsync
                  main_event_loop_pid = getattr(
                      SyncToAsync.threadlocal, "main_event_loop_pid", None
                  )
                  # We make sure the parent loop is from the same process - if
                  # they've forked, this is not going to be valid any more (#194)
                  if main_event_loop_pid and main_event_loop_pid == os.getpid():
                      self.main_event_loop = getattr(
                          SyncToAsync.threadlocal, "main_event_loop", None
                      )
                  else:
                      self.main_event_loop = None

      def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
          """
          Run the wrapped awaitable to completion and return its result.

          Blocks the calling thread. While blocked, it services work submitted
          to this frame's CurrentThreadExecutor (thread-sensitive SyncToAsync
          calls made further down the stack).

          Raises:
              RuntimeError: if called from a thread with a running event loop.
          """
          __traceback_hide__ = True  # noqa: F841

          # You can't call AsyncToSync from a thread with a running event loop
          try:
              event_loop = asyncio.get_running_loop()
          except RuntimeError:
              pass
          else:
              if event_loop.is_running():
                  raise RuntimeError(
                      "You cannot use AsyncToSync in the same thread as an async event loop - "
                      "just await the async function directly."
                  )

          # Wrapping context in list so it can be reassigned from within
          # `main_wrap`.
          context = [contextvars.copy_context()]

          # Make a future for the return information
          call_result: "Future[_R]" = Future()
          # Get the source thread
          source_thread = threading.current_thread()
          # Make a CurrentThreadExecutor we'll use to idle in this thread - we
          # need one for every sync frame, even if there's one above us in the
          # same thread.
          if hasattr(self.executors, "current"):
              old_current_executor = self.executors.current
          else:
              old_current_executor = None
          current_executor = CurrentThreadExecutor()
          self.executors.current = current_executor
          loop = None
          # Single-worker pool hosting a private event loop; only created when
          # there is no usable main event loop.
          loop_executor = None
          # Use call_soon_threadsafe to schedule a synchronous callback on the
          # main event loop's thread if it's there, otherwise run a new loop
          # in a dedicated new thread.
          try:
              awaitable = self.main_wrap(
                  call_result,
                  source_thread,
                  sys.exc_info(),
                  context,
                  *args,
                  **kwargs,
              )

              if not (self.main_event_loop and self.main_event_loop.is_running()):
                  # Make our own event loop - in a new thread - and run inside that.
                  loop = asyncio.new_event_loop()
                  self.loop_thread_executors[loop] = current_executor
                  loop_executor = ThreadPoolExecutor(max_workers=1)
                  loop_future = loop_executor.submit(
                      self._run_event_loop, loop, awaitable
                  )
                  if current_executor:
                      # Run the CurrentThreadExecutor until the future is done
                      current_executor.run_until_future(loop_future)
                  # Wait for future and/or allow for exception propagation
                  loop_future.result()
              else:
                  # Call it inside the existing loop
                  self.main_event_loop.call_soon_threadsafe(
                      self.main_event_loop.create_task, awaitable
                  )
                  if current_executor:
                      # Run the CurrentThreadExecutor until the future is done
                      current_executor.run_until_future(call_result)
          finally:
              # Release the private loop's worker thread explicitly rather than
              # relying on garbage collection of the executor to stop it.
              # wait=False so an abnormal exit can never block here.
              if loop_executor is not None:
                  loop_executor.shutdown(wait=False)
              # Clean up any executor we were running
              if loop is not None:
                  del self.loop_thread_executors[loop]
              if hasattr(self.executors, "current"):
                  del self.executors.current
              if old_current_executor:
                  self.executors.current = old_current_executor
              _restore_context(context[0])

          # Wait for results from the future.
          return call_result.result()

      def _run_event_loop(self, loop, coro):
          """
          Runs the given event loop (designed to be called in a thread).

          After ``coro`` completes, remaining tasks are cancelled and awaited,
          async generators are shut down and the loop is closed.
          """
          asyncio.set_event_loop(loop)
          try:
              loop.run_until_complete(coro)
          finally:
              try:
                  # mimic asyncio.run() behavior:
                  # cancel any tasks still pending on the loop, wait for them,
                  # and report errors other than cancellation
                  tasks = asyncio.all_tasks(loop)
                  for task in tasks:
                      task.cancel()

                  async def gather():
                      await asyncio.gather(*tasks, return_exceptions=True)

                  loop.run_until_complete(gather())
                  for task in tasks:
                      if task.cancelled():
                          continue
                      if task.exception() is not None:
                          loop.call_exception_handler(
                              {
                                  "message": "unhandled exception during loop shutdown",
                                  "exception": task.exception(),
                                  "task": task,
                              }
                          )
                  # then finalize any unexhausted async generators
                  if hasattr(loop, "shutdown_asyncgens"):
                      loop.run_until_complete(loop.shutdown_asyncgens())
              finally:
                  loop.close()
                  asyncio.set_event_loop(self.main_event_loop)

      def __get__(self, parent: Any, objtype: Any) -> Callable[_P, _R]:
          """
          Include self for methods
          """
          func = functools.partial(self.__call__, parent)
          return functools.update_wrapper(func, self.awaitable)

      async def main_wrap(
          self,
          call_result: "Future[_R]",
          source_thread: threading.Thread,
          exc_info: "OptExcInfo",
          context: List[contextvars.Context],
          *args: _P.args,
          **kwargs: _P.kwargs,
      ) -> None:
          """
          Wraps the awaitable with something that puts the result into the
          result/exception future.

          Also copies the caller's contextvars in on entry, and stores the final
          context back into ``context[0]`` so __call__ can restore it.
          """

          __traceback_hide__ = True  # noqa: F841

          if context is not None:
              _restore_context(context[0])

          current_task = SyncToAsync.get_current_task()
          assert current_task is not None
          self.launch_map[current_task] = source_thread
          try:
              # If we have an exception, run the function inside the except block
              # after raising it so exc_info is correctly populated.
              if exc_info[1]:
                  try:
                      raise exc_info[1]
                  except BaseException:
                      result = await self.awaitable(*args, **kwargs)
              else:
                  result = await self.awaitable(*args, **kwargs)
          except BaseException as e:
              call_result.set_exception(e)
          else:
              call_result.set_result(result)
          finally:
              del self.launch_map[current_task]

              context[0] = contextvars.copy_context()
  class SyncToAsync(Generic[_P, _R]):
      """
      Utility class which turns a synchronous callable into an awaitable that
      runs in a threadpool. It also sets a threadlocal inside the thread so
      calls to AsyncToSync can escape it.

      If thread_sensitive is passed, the code will run in the same thread as any
      outer code. This is needed for underlying Python code that is not
      threadsafe (for example, code which handles SQLite database connections).

      If the outermost program is async (i.e. SyncToAsync is outermost), then
      this will be a dedicated single sub-thread that all sync code runs in,
      one after the other. If the outermost program is sync (i.e. AsyncToSync is
      outermost), this will just be the main thread. This is achieved by idling
      with a CurrentThreadExecutor while AsyncToSync is blocking its sync parent,
      rather than just blocking.

      If executor is passed in, that will be used instead of the loop's default executor.
      In order to pass in an executor, thread_sensitive must be set to False, otherwise
      a TypeError will be raised.
      """

      # Maps launched threads to the coroutines that spawned them
      launch_map: "Dict[threading.Thread, asyncio.Task[object]]" = {}

      # Storage for main event loop references
      # (main_event_loop / main_event_loop_pid are written by thread_handler)
      threadlocal = threading.local()

      # Single-thread executor for thread-sensitive code
      single_thread_executor = ThreadPoolExecutor(max_workers=1)

      # Maintain a contextvar for the current execution context. Optionally used
      # for thread sensitive mode.
      thread_sensitive_context: "contextvars.ContextVar[ThreadSensitiveContext]" = (
          contextvars.ContextVar("thread_sensitive_context")
      )

      # Contextvar that is used to detect if the single thread executor
      # would be awaited on while already being used in the same context
      deadlock_context: "contextvars.ContextVar[bool]" = contextvars.ContextVar(
          "deadlock_context"
      )

      # Maintaining a weak reference to the context ensures that thread pools are
      # erased once the context goes out of scope. This terminates the thread pool.
      context_to_thread_executor: "weakref.WeakKeyDictionary[ThreadSensitiveContext, ThreadPoolExecutor]" = (
          weakref.WeakKeyDictionary()
      )

      def __init__(
          self,
          func: Callable[_P, _R],
          thread_sensitive: bool = True,
          executor: Optional["ThreadPoolExecutor"] = None,
      ) -> None:
          """
          Wrap the synchronous callable ``func`` for use from async code.

          Raises:
              TypeError: if ``func`` is not callable or looks like an async
                  callable, or if ``executor`` is given while
                  ``thread_sensitive`` is True.
          """
          if (
              not callable(func)
              or _iscoroutinefunction_or_partial(func)
              or _iscoroutinefunction_or_partial(getattr(func, "__call__", func))
          ):
              raise TypeError("sync_to_async can only be applied to sync functions.")
          self.func = func
          functools.update_wrapper(self, func)
          self._thread_sensitive = thread_sensitive
          # Mark this instance so coroutine-function checks report it as async.
          markcoroutinefunction(self)
          if thread_sensitive and executor is not None:
              raise TypeError("executor must not be set when thread_sensitive is True")
          self._executor = executor
          try:
              self.__self__ = func.__self__  # type: ignore
          except AttributeError:
              pass

      async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
          """
          Run the wrapped sync callable in an executor and return its result.

          Executor choice when thread_sensitive is True, in this order:
          1. the CurrentThreadExecutor of an enclosing AsyncToSync frame;
          2. a single-worker pool per active ThreadSensitiveContext (created
             on first use);
          3. the CurrentThreadExecutor registered for the running loop;
          4. the shared single_thread_executor, raising RuntimeError instead if
             the deadlock flag is already set in this context.
          Otherwise the executor given to __init__ is used (None means the
          loop's default executor).
          """
          __traceback_hide__ = True  # noqa: F841
          loop = asyncio.get_running_loop()

          # Work out what thread to run the code in
          if self._thread_sensitive:
              if hasattr(AsyncToSync.executors, "current"):
                  # If we have a parent sync thread above somewhere, use that
                  executor = AsyncToSync.executors.current
              elif self.thread_sensitive_context.get(None):
                  # If we have a way of retrieving the current context, attempt
                  # to use a per-context thread pool executor
                  thread_sensitive_context = self.thread_sensitive_context.get()

                  if thread_sensitive_context in self.context_to_thread_executor:
                      # Re-use thread executor in current context
                      executor = self.context_to_thread_executor[thread_sensitive_context]
                  else:
                      # Create new thread executor in current context
                      executor = ThreadPoolExecutor(max_workers=1)
                      self.context_to_thread_executor[thread_sensitive_context] = executor
              elif loop in AsyncToSync.loop_thread_executors:
                  # Re-use thread executor for running loop
                  executor = AsyncToSync.loop_thread_executors[loop]
              elif self.deadlock_context.get(False):
                  raise RuntimeError(
                      "Single thread executor already being used, would deadlock"
                  )
              else:
                  # Otherwise, we run it in a fixed single thread
                  executor = self.single_thread_executor
                  self.deadlock_context.set(True)
          else:
              # Use the passed in executor, or the loop's default if it is None
              executor = self._executor

          # The sync code runs inside a copy of the current context (via
          # context.run), so its contextvar changes can be read back afterwards.
          context = contextvars.copy_context()
          child = functools.partial(self.func, *args, **kwargs)
          func = context.run

          try:
              # Run the code in the right thread
              ret: _R = await loop.run_in_executor(
                  executor,
                  functools.partial(
                      self.thread_handler,
                      loop,
                      self.get_current_task(),
                      sys.exc_info(),
                      func,
                      child,
                  ),
              )

          finally:
              # Copy contextvar values set by the sync code back into this context.
              _restore_context(context)
              # NOTE(review): the flag is cleared unconditionally, including for
              # calls that did not set it above - confirm this is intended.
              self.deadlock_context.set(False)

          return ret

      def __get__(
          self, parent: Any, objtype: Any
      ) -> Callable[_P, Coroutine[Any, Any, _R]]:
          """
          Include self for methods
          """
          func = functools.partial(self.__call__, parent)
          return functools.update_wrapper(func, self.func)

      def thread_handler(self, loop, source_task, exc_info, func, *args, **kwargs):
          """
          Wraps the sync application with exception handling.

          Runs in the executor thread. It records ``loop`` and this process's pid
          in the threadlocal so a nested AsyncToSync can find the parent loop,
          registers this thread in launch_map, and returns
          ``func(*args, **kwargs)``.
          """

          __traceback_hide__ = True  # noqa: F841

          # Set the threadlocal for AsyncToSync
          self.threadlocal.main_event_loop = loop
          self.threadlocal.main_event_loop_pid = os.getpid()
          # Set the task mapping (used for the locals module)
          current_thread = threading.current_thread()
          if AsyncToSync.launch_map.get(source_task) == current_thread:
              # Our parent task was launched from this same thread, so don't make
              # a launch map entry - let it shortcut over us! (and stop infinite loops)
              parent_set = False
          else:
              self.launch_map[current_thread] = source_task
              parent_set = True
          source_task = (
              None  # allow the task to be garbage-collected in case of exceptions
          )
          # Run the function
          try:
              # If we have an exception, run the function inside the except block
              # after raising it so exc_info is correctly populated.
              if exc_info[1]:
                  try:
                      raise exc_info[1]
                  except BaseException:
                      return func(*args, **kwargs)
              else:
                  return func(*args, **kwargs)
          finally:
              # Only delete the launch_map parent if we set it, otherwise it is
              # from someone else.
              if parent_set:
                  del self.launch_map[current_thread]

      @staticmethod
      def get_current_task() -> Optional["asyncio.Task[Any]"]:
          """
          Implementation of asyncio.current_task()
          that returns None if there is no task.
          """
          try:
              return asyncio.current_task()
          except RuntimeError:
              # asyncio.current_task() raises when no event loop is running.
              return None
  @overload
  def async_to_sync(
      *,
      force_new_loop: bool = False,
  ) -> Callable[
      [Union[Callable[_P, Coroutine[Any, Any, _R]], Callable[_P, Awaitable[_R]]]],
      Callable[_P, _R],
  ]:
      ...


  @overload
  def async_to_sync(
      awaitable: Union[
          Callable[_P, Coroutine[Any, Any, _R]],
          Callable[_P, Awaitable[_R]],
      ],
      *,
      force_new_loop: bool = False,
  ) -> Callable[_P, _R]:
      ...


  def async_to_sync(
      awaitable: Optional[
          Union[
              Callable[_P, Coroutine[Any, Any, _R]],
              Callable[_P, Awaitable[_R]],
          ]
      ] = None,
      *,
      force_new_loop: bool = False,
  ) -> Union[
      Callable[
          [Union[Callable[_P, Coroutine[Any, Any, _R]], Callable[_P, Awaitable[_R]]]],
          Callable[_P, _R],
      ],
      Callable[_P, _R],
  ]:
      """Wrap an async callable in an AsyncToSync so it can be called synchronously.

      Called with a callable, it returns the wrapper directly. Called with only
      keyword arguments, it returns a decorator that applies the same options.
      """
      if awaitable is not None:
          return AsyncToSync(awaitable, force_new_loop=force_new_loop)

      def decorator(f):
          return AsyncToSync(f, force_new_loop=force_new_loop)

      return decorator
  @overload
  def sync_to_async(
      *,
      thread_sensitive: bool = True,
      executor: Optional["ThreadPoolExecutor"] = None,
  ) -> Callable[[Callable[_P, _R]], Callable[_P, Coroutine[Any, Any, _R]]]:
      ...


  @overload
  def sync_to_async(
      func: Callable[_P, _R],
      *,
      thread_sensitive: bool = True,
      executor: Optional["ThreadPoolExecutor"] = None,
  ) -> Callable[_P, Coroutine[Any, Any, _R]]:
      ...


  def sync_to_async(
      func: Optional[Callable[_P, _R]] = None,
      *,
      thread_sensitive: bool = True,
      executor: Optional["ThreadPoolExecutor"] = None,
  ) -> Union[
      Callable[[Callable[_P, _R]], Callable[_P, Coroutine[Any, Any, _R]]],
      Callable[_P, Coroutine[Any, Any, _R]],
  ]:
      """Wrap a sync callable in a SyncToAsync so it can be awaited.

      Called with a callable, it returns the wrapper directly. Called with only
      keyword arguments, it returns a decorator that applies the same options.
      """
      if func is not None:
          return SyncToAsync(
              func,
              thread_sensitive=thread_sensitive,
              executor=executor,
          )

      def decorator(f):
          return SyncToAsync(
              f,
              thread_sensitive=thread_sensitive,
              executor=executor,
          )

      return decorator