# boss.py - author: Guilhem Saurel
import time

import numpy as np

from manager import QueueClient
from task import Task


def grid(n_data, rng=None):
    """
    Return ``n_data`` distinct points drawn at random from a regular 3d grid.

    The grid has ``n_points ** 3`` nodes with
    ``n_points = int(cbrt(n_data)) + 1``, which is always >= ``n_data``.
    All nodes are shuffled and the first ``n_data`` of them are returned.

    Args:
        n_data: number of points to return; must be non-negative.
        rng: optional object with a ``shuffle`` method (e.g. a
            ``numpy.random.Generator``) used to shuffle the nodes. When
            ``None`` (the default) the global ``np.random`` state is used,
            exactly as before.

    Returns:
        np.ndarray of shape ``(n_data, 3)``. Column 0 lies in
        ``[-1, symbreak)``, column 1 in ``[-1, symbreak]`` and column 2 in
        ``[-pi, pi * symbreak]``.

    Raises:
        ValueError: if ``n_data`` is negative.
    """
    if n_data < 0:
        raise ValueError(f'n_data must be non-negative, got {n_data}')

    n_points = int(np.cbrt(n_data)) + 1
    symbreak = .9586  # Use it to break the symmetry of the grid

    # NOTE(review): x excludes its upper bound (endpoint=False) while y and z
    # include it; this looks like extra symmetry breaking -- confirm intent.
    x = np.linspace(-1., symbreak, num=n_points, endpoint=False)
    y = np.linspace(-1., symbreak, num=n_points, endpoint=True)
    z = np.linspace(-np.pi, np.pi * symbreak, num=n_points, endpoint=True)

    # Cartesian product in x-major order (z varies fastest).
    positions = [[a, b, c] for a in x for b in y for c in z]

    shuffle = np.random.shuffle if rng is None else rng.shuffle
    shuffle(positions)
    return np.array(positions)[0:n_data]


class Boss(QueueClient):
    """Split a sampled dataset into tasks, enqueue them and gather the results.

    Uses ``self.tasks`` (``put``) and ``self.results`` (``get``) as queues;
    presumably these are provided by ``QueueClient`` from ``manager`` --
    not visible here, TODO confirm.
    """

    def run(self, n_data, n_sub=100):
        """Sample ``n_data`` points and process them in lots of ``n_sub``.

        Args:
            n_data: total number of points sampled with :func:`grid`;
                must be non-negative.
            n_sub: number of points per task; must be a positive divisor
                of ``n_data``.

        Returns:
            list: the concatenation of ``result.results`` for every task, in
            the order results are taken from ``self.results`` (presumably
            this can differ from submission order if workers run
            concurrently).

        Raises:
            ValueError: if ``n_data`` is negative, or ``n_sub`` is not a
                positive divisor of ``n_data``.
        """
        # Explicit check rather than ``assert``: asserts vanish under ``python -O``.
        if n_data < 0 or n_sub <= 0 or n_data % n_sub != 0:
            raise ValueError(
                f'n_sub ({n_sub}) must be a positive divisor of '
                f'the non-negative n_data ({n_data})')

        # Sample the dataset
        x0s = grid(n_data)

        # np.split rejects 0 sections, so an empty dataset simply has no lots.
        n_lots = n_data // n_sub
        lots = np.split(x0s, n_lots) if n_lots else []
        print(f'Choose {len(x0s)} trajectories divided in {len(lots)} lots.')

        # Create one task per lot.
        for lot in lots:
            self.tasks.put(Task(lot))

        # Wait for exactly one result per submitted task.
        data = []
        for _ in lots:
            result = self.results.get()
            print(f'got result {result.identifier} of processed in {result.time:.3f}s : ')
            data.extend(result.results)
        print(f'Got all {len(lots)} results !')

        return data


if __name__ == '__main__':
    # Demo run: number of points to sample and number of points per task.
    # (The previous unused NDATA = 20 and a discarded grid(1000) call were
    # removed; these values match the call that was actually made.)
    N_DATA = 10
    N_SUB = 1

    # NOTE(review): the timed region also covers Boss() construction, which
    # presumably connects to the queue manager -- confirm that is intended.
    start = time.perf_counter()
    boss = Boss()
    boss.run(N_DATA, N_SUB)
    total = time.perf_counter() - start

    print(f'Completed in {total:.3f} secs')