#!/usr/bin/env python3

import time

import numpy as np
from manager import QueueClient
from task import Task


def grid(n_data, rng=None):
    """
    Sample ``n_data`` points from a regular 3D grid, in random order.

    A cubic grid with ``int(cbrt(n_data)) + 1`` nodes per axis (so at least
    ``n_data`` nodes in total) is built over x in [-1, 0.9586),
    y in [-1, 0.9586] and z in [-pi, 0.9586 * pi]. Its nodes are shuffled and
    the first ``n_data`` are returned.

    Parameters
    ----------
    n_data : int
        Number of points to return; must be non-negative.
    rng : object with a ``shuffle`` method, optional
        For example ``np.random.Generator`` or ``np.random.RandomState``,
        for reproducible sampling. When omitted, NumPy's global random state
        is used (``np.random.shuffle``).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_data, 3)`` holding ``(x, y, z)`` coordinates.

    Raises
    ------
    ValueError
        If ``n_data`` is negative.
    """
    if n_data < 0:
        raise ValueError(f"n_data must be non-negative, got {n_data}")

    # One more node per axis than the truncated cube root guarantees
    # n_points**3 >= n_data.
    n_points = int(np.cbrt(n_data)) + 1
    symbreak = 0.9586  # Use it to break the symmetry of the grid
    # NOTE(review): x excludes its endpoint while y and z include theirs;
    # confirm this asymmetry is intentional.
    x = np.linspace(-1.0, symbreak, num=n_points, endpoint=False)
    y = np.linspace(-1.0, symbreak, num=n_points, endpoint=True)
    z = np.linspace(-np.pi, np.pi * symbreak, num=n_points, endpoint=True)

    # Vectorized Cartesian product. "ij" indexing plus the C-order reshape
    # keeps the x-outermost / z-innermost order of a nested x, y, z loop.
    positions = np.stack(np.meshgrid(x, y, z, indexing="ij"), axis=-1).reshape(-1, 3)

    # Shuffle whole rows (axis 0) in place.
    if rng is None:
        np.random.shuffle(positions)
    else:
        rng.shuffle(positions)
    return positions[:n_data]


class Boss(QueueClient):
    """Queue client that splits a sampled dataset into tasks and gathers the results."""

    def run(self, n_data, n_sub=100):
        """
        Sample ``n_data`` grid points, dispatch them in lots of ``n_sub``,
        and collect the results.

        Parameters
        ----------
        n_data : int
            Total number of points sampled with ``grid``; must be non-negative.
        n_sub : int, optional
            Number of points per task; must be positive and divide ``n_data``.

        Returns
        -------
        list
            Concatenation of ``result.results`` from every task result, in the
            order the results were retrieved from ``self.results``.

        Raises
        ------
        ValueError
            If ``n_data`` is negative, ``n_sub`` is not positive, or
            ``n_sub`` does not divide ``n_data``.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if n_data < 0 or n_sub <= 0 or n_data % n_sub != 0:
            raise ValueError(
                f"n_data ({n_data}) must be a non-negative multiple of a "
                f"positive n_sub ({n_sub})"
            )
        n_lots = n_data // n_sub

        # Sample the dataset
        x0s = grid(n_data)
        # np.split rejects 0 sections, so an empty dataset yields no lots.
        lots = np.split(x0s, n_lots) if n_lots else []
        print(f"Choose {len(x0s)} trajectories divided in {len(lots)} lots.")

        # Enqueue one task per lot.
        for lot in lots:
            self.tasks.put(Task(lot))

        # Retrieve one result per submitted lot. NOTE(review): results are
        # concatenated in retrieval order, which presumably need not match
        # submission order; confirm whether callers rely on ordering.
        data = []
        for _ in lots:
            result = self.results.get()
            print(f"got result {result.identifier} processed in {result.time:.3f}s")
            data.extend(result.results)
        print(f"Got all {len(lots)} results !")

        return data


if __name__ == "__main__":
    # Number of trajectories to process and how many go into each task.
    NDATA = 10
    NSUB = 1

    start = time.perf_counter()
    boss = Boss()
    data = boss.run(NDATA, NSUB)
    total = time.perf_counter() - start

    print(f"Completed in {total:.3f} secs")