from __future__ import annotations

import asyncio
import pickle
from datetime import timedelta
from time import sleep

import pytest

from distributed import Client, Nanny, Queue, TimeoutError, wait, worker_client
from distributed.metrics import time
from distributed.utils import open_port, wait_for
from distributed.utils_test import div, gen_cluster, inc, popen


@gen_cluster(client=True)
async def test_queue(c, s, a, b):
    x = await Queue("x")
    y = await Queue("y")
    xx = await Queue("x")
    assert x.client is c

    future = c.submit(inc, 1)

    await x.put(future)
    await y.put(future)
    future2 = await xx.get()
    assert future.key == future2.key

    with pytest.raises(TimeoutError):
        await x.get(timeout=0.1)

    del future, future2

    await asyncio.sleep(0.1)
    assert s.tasks  # future still present in y's queue
    await y.get()  # burn future

    start = time()
    while s.tasks:
        await asyncio.sleep(0.01)
        assert time() < start + 5


@gen_cluster(client=True)
async def test_queue_with_data(c, s, a, b):
    x = await Queue("x")
    xx = await Queue("x")
    assert x.client is c

    await x.put((1, "hello"))
    data = await xx.get()

    assert data == (1, "hello")

    with pytest.raises(TimeoutError):
        await x.get(timeout=0.1)


def test_sync(client):
    future = client.submit(lambda x: x + 1, 10)
    x = Queue("x")
    xx = Queue("x")
    x.put(future)
    assert x.qsize() == 1
    assert xx.qsize() == 1
    future2 = xx.get()

    assert future2.result() == 11


@gen_cluster()
async def test_hold_futures(s, a, b):
    async with Client(s.address, asynchronous=True) as c1:
        future = c1.submit(lambda x: x + 1, 10)
        q1 = await Queue("q")
        await q1.put(future)
        del q1

    await asyncio.sleep(0.1)

    async with Client(s.address, asynchronous=True) as c1:
        q2 = await Queue("q")
        future2 = await q2.get()
        result = await future2

        assert result == 11


@gen_cluster(client=True)
async def test_picklability(c, s, a, b):
    q = Queue()

    def f(x):
        q.put(x + 1)

    await c.submit(f, 10)
    result = await q.get()
    assert result == 11


def test_picklability_sync(client):
    q = Queue()

    def f(x):
        q.put(x + 1)

    client.submit(f, 10).result()

    assert q.get() == 11


@pytest.mark.slow
@gen_cluster(client=True, nthreads=[("127.0.0.1", 2)] * 5, Worker=Nanny, timeout=60)
async def test_race(c, s, *workers):
    def f(i):
        with worker_client() as c:
            q = Queue("x", client=c)
            for _ in range(100):
                future = q.get()
                x = future.result()
                y = c.submit(inc, x)
                q.put(y)
                sleep(0.01)
            result = q.get().result()
            return result

    q = Queue("x", client=c)
    L = await c.scatter(range(5))
    for future in L:
        await q.put(future)

    futures = c.map(f, range(5))
    results = await c.gather(futures)
    assert all(r > 50 for r in results)
    assert sum(results) == 510
    qsize = await q.qsize()
    assert not qsize


@gen_cluster(client=True)
async def test_same_futures(c, s, a, b):
    q = Queue("x")
    future = await c.scatter(123)

    for _ in range(5):
        await q.put(future)

    assert {ts.key for ts in s.clients["queue-x"].wants_what} == {future.key}

    for _ in range(4):
        future2 = await q.get()
        assert {ts.key for ts in s.clients["queue-x"].wants_what} == {future.key}
        await asyncio.sleep(0.05)
        assert {ts.key for ts in s.clients["queue-x"].wants_what} == {future.key}

    await q.get()

    start = time()
    while "queue-x" in s.clients and s.clients["queue-x"].wants_what:
        await asyncio.sleep(0.01)
        assert time() - start < 2


@gen_cluster(client=True)
async def test_get_many(c, s, a, b):
    x = await Queue("x")
    xx = await Queue("x")

    await x.put(1)
    await x.put(2)
    await x.put(3)

    data = await xx.get(batch=True)
    assert data == [1, 2, 3]

    await x.put(1)
    await x.put(2)
    await x.put(3)

    data = await xx.get(batch=2)
    assert data == [1, 2]

    with pytest.raises(TimeoutError):
        await wait_for(xx.get(batch=2), 0.1)


@gen_cluster(client=True)
async def test_Future_knows_status_immediately(c, s, a, b):
    x = await c.scatter(123)
    q = await Queue("q")
    await q.put(x)

    async with Client(s.address, asynchronous=True) as c2:
        q2 = await Queue("q", client=c2)
        future = await q2.get()
        assert future.status == "finished"

        x = c.submit(div, 1, 0)
        await wait(x)
        await q.put(x)

        future2 = await q2.get()
        assert future2.status == "error"
        with pytest.raises(ZeroDivisionError):
            await future2

        start = time()
        while True:  # we learn about the true error eventually
            try:
                await future2
            except ZeroDivisionError:
                break
            except Exception:
                assert time() < start + 5
                await asyncio.sleep(0.05)


@gen_cluster(client=True)
async def test_erred_future(c, s, a, b):
    future = c.submit(div, 1, 0)
    q = await Queue()
    await q.put(future)
    await asyncio.sleep(0.1)
    future2 = await q.get()
    with pytest.raises(ZeroDivisionError):
        await future2.result()

    exc = await future2.exception()
    assert isinstance(exc, ZeroDivisionError)


@gen_cluster(client=True)
async def test_close(c, s, a, b):
    q = await Queue()

    q.close()
    q.close()

    while q.name in s.extensions["queues"].queues:
        await asyncio.sleep(0.01)


@gen_cluster(client=True)
async def test_timeout(c, s, a, b):
    q = await Queue("v", maxsize=1)

    start = time()
    with pytest.raises(TimeoutError):
        await q.get(timeout="300ms")
    stop = time()
    assert 0.2 < stop - start < 2.0

    await q.put(1)

    start = time()
    with pytest.raises(TimeoutError):
        await q.put(2, timeout=timedelta(seconds=0.3))
    stop = time()
    assert 0.1 < stop - start < 2.0


@gen_cluster(client=True)
async def test_2220(c, s, a, b):
    q = Queue()

    def put():
        q.put(55)

    def get():
        print(q.get())

    fut = c.submit(put)
    res = c.submit(get)

    await c.gather([res, fut])


def test_queue_in_task(loop):
    port = open_port()
    # Ensure that we can create a Queue inside a task on a
    # worker in a separate Python process than the client
    with popen(
        [
            "dask",
            "scheduler",
            "--no-dashboard",
            f"--port={port}",
        ]
    ):
        with popen(["dask", "worker", f"127.0.0.1:{port}"]):
            with Client(f"tcp://127.0.0.1:{port}", loop=loop) as c:
                c.wait_for_workers(1)

                x = Queue("x")
                x.put(123)

                def foo():
                    y = Queue("x")
                    return y.get()

                result = c.submit(foo).result()
                assert result == 123


@gen_cluster(nthreads=[])
async def test_unpickle_without_client(s):
    """Ensure that the object properly pickle roundtrips even if no client, worker, etc. is active in the given context.

    This typically happens if the object is being deserialized on the scheduler.
    """
    async with Client(s.address, asynchronous=True) as c:
        q = await Queue()
        pickled = pickle.dumps(q)

    # We do not want to initialize a client during unpickling
    with pytest.raises(ValueError):
        Client.current()

    q2 = pickle.loads(pickled)

    with pytest.raises(ValueError):
        Client.current()

    assert q2.client is None
    await asyncio.sleep(0)

    with pytest.raises(RuntimeError, match="not properly initialized"):
        await q2.put(1)

    async with Client(s.address, asynchronous=True):
        q3 = pickle.loads(pickled)
        await q3.put(1)
        assert await q3.get() == 1


@gen_cluster(client=True, nthreads=[])
async def test_set_cancelled_future(c, s):
    x = c.submit(inc, 1)
    await x.cancel()
    q = Queue("x")
    # FIXME: This is a TimeoutError but pytest doesn't appear to recognize it as
    # such
    with pytest.raises(Exception, match="unknown to scheduler"):
        await q.put(x, timeout="100ms")
