#! /usr/bin/env python

# Copyright (C) 2023- The University of Notre Dame
# This software is distributed under the GNU General Public License.
# See the file COPYING for details.

# transactions log format:
# time manager_pid MANAGER manager_pid START|END
# time manager_pid WORKER worker_id CONNECTION host:port
# time manager_pid WORKER worker_id DISCONNECTION (UNKNOWN|IDLE_OUT|FAST_ABORT|FAILURE|STATUS_WORKER|EXPLICIT)
# time manager_pid WORKER worker_id RESOURCES {resources}
# time manager_pid WORKER worker_id CACHE_UPDATE filename size_in_mb wall_time_us start_time_us
# time manager_pid WORKER worker_id TRANSFER (INPUT|OUTPUT) filename size_in_mb wall_time_us start_time_us
# time manager_pid CATEGORY name MAX {resources_max_per_task}
# time manager_pid CATEGORY name MIN {resources_min_per_task_per_worker}
# time manager_pid CATEGORY name FIRST (FIXED|MAX|MIN_WASTE|MAX_THROUGHPUT) {resources_requested}
# time manager_pid TASK task_id READY category_name (FIRST_RESOURCES|MAX_RESOURCES) attempt_number {resources_requested}
# time manager_pid TASK task_id RUNNING worker_id (FIRST_RESOURCES|MAX_RESOURCES) {resources_allocated}
# time manager_pid TASK task_id WAITING_RETRIEVAL worker_id
# time manager_pid TASK task_id (RETRIEVED|DONE) (SUCCESS|SIGNAL|END_TIME|FORSAKEN|MAX_RETRIES|MAX_WALLTIME|UNKNOWN|RESOURCE_EXHAUSTION) exit_code {limits_exceeded} {resources_measured}
# time manager_pid LIBRARY task_id (WAITING|SENT|STARTED|FAILURE) worker_id
# time manager_pid APPLICATION message*


from collections import defaultdict
from pathlib import Path
import argparse
import heapq
import json
import math
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.pyplot as plticker
import pandas as pd
import sys
from functools import partial

# Resource kinds looked up in the JSON resource summaries of the log.
resources_names = ["cores", "memory", "disk", "gpus"]

# Task-report fields holding timestamps that the parser shifts by the
# manager's origin (see ParseTxn.arg_to_task_report).
task_report_abs = [
    "time_worker_start",
    "time_worker_end",
    "time_commit_start",
    "time_commit_end",
]

# Every task-report field extracted from a task's JSON report: the
# manager-side transfer times and sizes, followed by the timestamps above.
task_report_all = [
    "time_input_mgr",
    "time_output_mgr",
    "size_input_mgr",
    "size_output_mgr",
    *task_report_abs,
]


class Manager:
    def __init__(self, manager_pid, origin):
        self.origin = origin
        self.termination = 0  # yet unknown

        self.pid = manager_pid

        # [task_id][attempt_number] -> attempt_info
        self.tasks_attempts = defaultdict(
            lambda: defaultdict(lambda: defaultdict(lambda: pd.NA))
        )

        # [task_id] -> attempt_number
        self.task_last_attempt = {}

        # [(worker_id,hostport)] -> worker_info
        self.worker_lifetime = defaultdict(lambda: defaultdict(lambda: pd.NA))

        # [(worker_id,hostport)] -> transfer
        self.worker_transfers = defaultdict(lambda: [])

        # [worker_id][task_id] -> attempt_number
        self.tasks_on_worker = defaultdict(lambda: {})

        # [worker_id] -> hostport
        self.last_host_port = {}

        # [time] -> label
        self.phases = defaultdict(str)

        self.tasks = None
        self.workers = None
        self.transfers = None

    def write_tables(self, filename, order_in_log):
        """Write the three tables of this manager as CSV files.

        Files are named "<stem>_<order_in_log>_<table>.csv", where <stem> is
        the stem of filename, and are created relative to the current
        directory. The index is not written.
        """
        stem = Path(filename).stem
        frames = {
            "tasks": self.tasks,
            "workers": self.workers,
            "transfers": self.transfers,
        }
        for label, frame in frames.items():
            frame.to_csv(f"{stem}_{order_in_log}_{label}.csv", index=False)

    def make_tables(self, expand_waiting):
        """Build the tasks, workers and transfers tables, in that order.

        expand_waiting is forwarded to make_table_tasks().
        """
        builders = (
            ("tasks", lambda: self.make_table_tasks(expand_waiting)),
            ("workers", self.make_table_workers),
            ("transfers", self.make_table_transfers),
        )
        for attribute, build in builders:
            setattr(self, attribute, build())

    def make_table_tasks(self, expand_waiting=False):
        """Build the table with one row per task attempt.

        Adds the column "measured_wall_time" (time_worker_end minus
        time_worker_start) and returns the result of assign_slots().

        If the attempts carry no timing report at all (e.g. no task was ever
        dispatched), a message is printed and the program exits with status 1.
        """
        records = [
            attempt
            for attempts in self.tasks_attempts.values()
            for attempt in attempts.values()
        ]
        tasks_df = pd.DataFrame.from_records(records)

        try:
            # fix time_worker_start 0 bug: a task cannot start before its
            # commit ended. This lookup is kept inside the try because the
            # columns are missing when no attempt has a timing report.
            tasks_df["time_worker_start"] = tasks_df[
                ["time_worker_start", "time_commit_end"]
            ].max(axis=1)
            tasks_df["measured_wall_time"] = (
                tasks_df["time_worker_end"] - tasks_df["time_worker_start"]
            )
        except KeyError:
            print(
                f"No tasks that finished were found in the log for manager pid {self.pid}"
            )
            sys.exit(1)
        return self.assign_slots(tasks_df, expand_waiting)

    def assign_slots(self, ts, expand_waiting):
        ts["slot"] = pd.NA
        if expand_waiting:
            # keep attempts that were dispatched to some worker
            ts = ts[~ts["RUNNING"].isna()]
        else:
            # when showing only execution slots, we do not have enough
            # information to plot tasks lost on disconnection
            ts = ts[~ts["RETRIEVED"].isna()]

        for worker_id in ts["worker_id"].unique():
            self.assign_slots_worker(ts, worker_id, expand_waiting)
        return ts

    def assign_slots_worker(self, ts, worker_id, expand_waiting):
        free_slots = []
        used_slots = {}

        def next_slot(index):
            used_slots[index] = (
                heapq.heappop(free_slots) if free_slots else 1 + len(used_slots)
            )
            return used_slots[index]

        def free_slot(index):
            heapq.heappush(free_slots, used_slots.pop(index))

        if expand_waiting:
            ts["time_worker_end"] = ts["time_worker_end"].combine(
                ts["last_state_time"], lambda x, y: y if pd.isna(x) else x
            )

        events = []
        for index, row in ts[ts["worker_id"] == worker_id].iterrows():
            # event tuple is: (index, start, time)
            if expand_waiting:
                events.append((index, True, row["RUNNING"]))
            else:
                events.append((index, True, row["time_worker_start"]))
            events.append((index, False, row["time_worker_end"]))

        events.sort(key=lambda s: s[2])
        for index, start, time in events:
            if start:
                ts.loc[index, "slot"] = next_slot(index)
            else:
                free_slot(index)

    def make_table_workers(self):
        all_workers = []
        for worker in self.worker_lifetime.values():
            all_workers.append(worker)
        workers_df = pd.DataFrame.from_records(all_workers)

        # fill up disconnect times for workers that survive the manager
        workers_df.loc[
            workers_df["DISCONNECTION"].isna(), ("DISCONNECTION",)
        ] = self.termination
        return workers_df

    def make_table_transfers(self):
        """Build the table with one row per recorded transfer of any worker."""
        columns = [
            "worker_id",
            "hostport",
            "time",
            "direction",
            "filetype",
            "filename",
            "size",
            "wall_time",
            "start_time",
        ]
        rows = [
            transfer
            for per_worker in self.worker_transfers.values()
            for transfer in per_worker
        ]
        return pd.DataFrame(rows, columns=columns)


class ParseTxn:
    def __init__(self, logfile, expand_waiting=False, tasks_range_spec=None):
        """Parse the transactions log at logfile right away.

        logfile: path of the log to read.
        expand_waiting: forwarded to Manager.make_tables().
        tasks_range_spec: optional "start[,end[,step]]" string restricting
            which task ids are considered (see expand_range()).
        """
        self.expand_waiting = expand_waiting
        self.tasks_range = self.expand_range(tasks_range_spec)

        self._log = logfile
        self.cm = None  # manager whose records are currently being read
        self.managers = {}

        self._parse()

    def expand_range(self, tasks_range_spec):
        """Turn a "start[,end[,step]]" string into a (start, end, step) tuple.

        Missing or empty fields default to start=1, end=None (unbounded) and
        step=1. start and step are at least 1, and end is at least start + 1.
        Fields after the third are ignored. Raises ValueError when a field
        that is present is not an integer.
        """
        if not tasks_range_spec:
            return (1, None, 1)

        fields = tasks_range_spec.split(",")

        first = fields[0]
        start = max(1, int(first)) if first else 1

        end, step = None, 1
        if len(fields) > 1:
            limit = fields[1]
            end = max(start + 1, int(limit)) if limit else None
            step = max(1, int(fields[2])) if len(fields) > 2 else 1

        return (start, end, step)

    def check_task_range(self, task_id):
        """Return True when task_id is selected by self.tasks_range.

        The range is (start, end, step): ids start, start + step,
        start + 2*step, ... are selected, up to but not including end. An end
        of None means no upper bound.
        """
        (start, end, step) = self.tasks_range
        if start and task_id < start:
            return False
        if end and task_id >= end:
            return False
        # count steps from start, so that start itself is always selected
        return (task_id - start) % step == 0

    def write_tables(self, filename):
        """Ask every manager to write its tables.

        Each manager receives filename and its position among the managers
        found in the log, counting from 0.
        """
        for position, manager in enumerate(self.managers.values()):
            manager.write_tables(filename, position)

    def _next_line(self):
        """Yield (time, manager_pid, subject, target, event, arg) per record.

        Comment lines (starting with "#") and lines with fewer than five
        fields are skipped. arg is the stripped rest of the line after the
        fifth field, or None when there is nothing left. time is the first
        field divided by 1e6; when the record belongs to the current manager,
        it is made relative to that manager's origin.
        """
        with open(self._log) as log_file:
            for raw in log_file:
                if raw.startswith("#"):
                    continue

                fields = raw.split(maxsplit=5)
                if len(fields) < 5:
                    continue

                time, manager_pid, subject, target, event = fields[:5]
                arg = fields[5].strip() if len(fields) > 5 else None

                time = float(time) / 1000000
                # self.cm is read here, per record, because the consumer may
                # change the current manager between records
                current = self.cm
                if current and current.pid == manager_pid:
                    time -= current.origin

                yield (time, manager_pid, subject, target, event, arg)

    def arg_to_xfer(self, worker_id, hostport, time, direction, arg):
        """Turn the argument of a transfer record into a transfers-table row.

        arg is "filename size wall_time start_time"; the three numbers are
        divided by 1e6. The file type is the part of the filename before the
        first "-", or "file" when the name has no "-".

        Returns [worker_id, hostport, time, direction, filetype, filename,
        size, wall_time, start_time], with start_time relative to the current
        manager's origin.
        """
        filename, raw_size, raw_wall, raw_start = arg.split()
        size, wall_time, start_time = (
            float(raw) / 1e6 for raw in (raw_size, raw_wall, raw_start)
        )

        origin = self.cm.origin
        if start_time >= origin:
            start_time -= origin
        else:
            # guard for cases when start time was unknown
            start_time = time - wall_time

        prefix, dash, _ = filename.partition("-")
        filetype = prefix if dash else "file"

        return [
            worker_id,
            hostport,
            time,
            direction,
            filetype,
            filename,
            size,
            wall_time,
            start_time,
        ]

    def arg_to_values(self, ca, prefix, names, arg):
        # values of the form [value, "units"]
        def value(res):
            if not res:
                return pd.NA
            n = float(res[0])
            if math.isnan(n) or n < 0:
                return pd.NA
            else:
                return n

        from_json = json.loads(arg)
        for r in names:
            v = value(from_json.get(r, None))
            ca[f"{prefix}{r}"] = v

    def arg_to_resources(self, ca, prefix, arg):
        """Store every resource named in resources_names from the JSON arg
        into ca, under keys starting with prefix."""
        self.arg_to_values(ca=ca, prefix=prefix, names=resources_names, arg=arg)

    def arg_to_task_report(self, ca, arg):
        """Store the task-report fields of the JSON arg into ca.

        The fields listed in task_report_abs are then shifted by the current
        manager's origin.
        """
        self.arg_to_values(ca, "", task_report_all, arg)

        origin = self.cm.origin
        for field in task_report_abs:
            ca[field] = ca[field] - origin

    def _parse_manager(self, time, manager_pid, event):
        """Handle a MANAGER record.

        START creates a new Manager with time as its origin and makes it the
        current one. END builds the tables of the current manager and clears
        it. Other events are ignored.
        """
        if event == "END":
            self.cm.make_tables(self.expand_waiting)
            self.cm = None
            return

        if event == "START":
            manager = Manager(manager_pid, time)
            self.cm = manager
            self.managers[manager_pid] = manager

    def _parse_worker(self, time, manager_pid, worker_id, event, arg):
        """Handle a WORKER record of the current manager.

        Records whose target does not start with "worker-" are ignored.

        CONNECTION: remembers the worker's host:port and starts a lifetime
            record for the (worker_id, hostport) pair.
        DISCONNECTION: closes that lifetime record and marks every task still
            assigned to this worker as lost on disconnection.
        RESOURCES: stores the reported resources in the lifetime record.
        CACHE_UPDATE / TRANSFER: append a row to the worker's transfers.

        Raises KeyError when a record other than CONNECTION arrives for a
        worker with no previous CONNECTION.
        """
        if not worker_id.startswith("worker-"):
            return

        cm = self.cm

        if event == "CONNECTION":
            hostport = arg
            cm.last_host_port[worker_id] = hostport
            lifetime = cm.worker_lifetime[(worker_id, hostport)]
            lifetime["worker_id"] = worker_id
            lifetime["hostport"] = hostport
            lifetime[event] = time
            lifetime["DISCONNECTION"] = pd.NA
        elif event == "DISCONNECTION":
            hostport = cm.last_host_port[worker_id]
            lifetime = cm.worker_lifetime[(worker_id, hostport)]
            lifetime[event] = time
            lifetime["reason"] = arg

            lost = cm.tasks_on_worker[worker_id]
            for task_id, attempt_number in lost.items():
                attempt = cm.tasks_attempts[task_id][attempt_number]
                attempt["reason"] = "DISCONNECTION"
                attempt["DISCONNECTION"] = time
                attempt["last_state"] = "DISCONNECTION"
                attempt["last_state_time"] = time
            # forget only the tasks of the worker that disconnected; tasks
            # running on other workers must stay tracked
            lost.clear()
        elif event == "RESOURCES":
            hostport = cm.last_host_port[worker_id]
            self.arg_to_resources(cm.worker_lifetime[(worker_id, hostport)], "", arg)
        elif event == "CACHE_UPDATE":
            hostport = cm.last_host_port[worker_id]
            xfer = self.arg_to_xfer(worker_id, hostport, time, event, arg)
            cm.worker_transfers[(worker_id, hostport)].append(xfer)
        elif event == "TRANSFER":
            hostport = cm.last_host_port[worker_id]
            (direction, arg) = arg.split(maxsplit=1)
            xfer = self.arg_to_xfer(worker_id, hostport, time, direction, arg)
            cm.worker_transfers[(worker_id, hostport)].append(xfer)

    def _parse_category(self, time, manager_pid, category, event, arg):
        """Handle a CATEGORY record. Nothing is stored for categories.

        MAX and MIN records are ignored. For FIRST records the argument is
        split into its mode and the rest, and the result is discarded.
        """
        if event != "FIRST":
            return

        (mode, arg) = arg.split(maxsplit=1)

    def _parse_library(self, time, manager_pid, task_id, event, arg):
        task_id = int(task_id)

        if task_id in self.cm.task_last_attempt:
            la = self.cm.task_last_attempt[task_id]
            ca = self.cm.tasks_attempts[task_id][la]
        else:
            return

        if event == "SENT":
            ca["time_worker_start"] = time
            ca["time_commit_start"] = time
        elif event == "STARTED":
            ca["library"] = time
            ca["RETRIEVED"] = ca["DONE"]
            ca["time_worker_end"] = time
            ca["time_commit_end"] = time
            ca["reason"] = "SUCCESS"
            ca["exit_code"] = 0
        elif event == "FAILURE":
            pass

    def _parse_task(self, time, manager_pid, task_id, event, arg):
        """Handle a TASK record of the current manager.

        Tasks outside self.tasks_range are ignored, and so are tasks for
        which no READY record has been seen. Every handled record stores its
        time under the event name and updates the attempt's last state.

        READY: starts a new attempt and stores the requested resources.
        RUNNING: assigns the attempt to a worker and stores the allocation.
        RETRIEVED: stores the result, the task report and the measured
            resources, and removes the task from its worker.
        """
        task_id = int(task_id)
        if not self.check_task_range(task_id):
            return

        cm = self.cm

        if event == "READY":
            (category, allocation, attempt_number, arg) = arg.split()
            cm.task_last_attempt[task_id] = attempt_number

        if task_id not in cm.task_last_attempt:
            # this task did not have a corresponding READY event (e.g. it
            # could be a LIBRARY)
            return
        la = cm.task_last_attempt[task_id]
        ca = cm.tasks_attempts[task_id][la]

        ca.update({event: time, "last_state": event, "last_state_time": time})

        if event == "READY":
            ca.update(
                {
                    "task_id": task_id,
                    "attempt_number": attempt_number,
                    "category": category,
                }
            )
            self.arg_to_resources(ca, "requested_", arg)
            ca["library"] = pd.NA  # we do not know yet whether it is a library
        elif event == "RUNNING":
            (worker_id, allocation, arg) = arg.split()
            cm.tasks_on_worker[worker_id][task_id] = la
            ca["worker_id"] = worker_id
            self.arg_to_resources(ca, "allocated_", arg)
            self.arg_to_task_report(ca, arg)
            # field to keep the time in case of a worker disconnection
            ca["DISCONNECTION"] = pd.NA
        elif event == "RETRIEVED":
            (reason, exit_code, limits_exceeded, measured) = arg.split()
            ca["reason"] = reason
            ca["exit_code"] = int(exit_code)
            self.arg_to_task_report(ca, measured)
            self.arg_to_resources(ca, "measured_", measured)
            try:
                del cm.tasks_on_worker[ca["worker_id"]][task_id]
            except KeyError:
                # got the message of the worker disconnecting before the
                # retrieved transaction. Mark the task as not a disconnection
                ca["DISCONNECTION"] = pd.NA

    def _parse_application(self, time, manager_pid, *msg):
        """Record an APPLICATION message as a phase label at the given time.

        msg are the pieces of the message as split by the log reader. A
        trailing piece is None when the message is too short to fill it, so
        None pieces are dropped before joining.
        """
        message = " ".join(piece for piece in msg if piece is not None)
        self.cm.phases[time] = message

    def _parse(self):
        """Read the whole log, sending every record to its subject handler.

        Records with an unknown subject are skipped. The time of the last
        record seen for the current manager is kept as its termination time.
        """
        # these handlers all take (time, manager_pid, target, event, arg)
        handlers = {
            "WORKER": self._parse_worker,
            "CATEGORY": self._parse_category,
            "LIBRARY": self._parse_library,
            "TASK": self._parse_task,
            "APPLICATION": self._parse_application,
        }

        for time, manager_pid, subject, target, event, arg in self._next_line():
            current = self.cm
            if current and current.pid == manager_pid:
                # the last time known becomes the manager's termination time
                current.termination = time

            if subject == "MANAGER":
                self._parse_manager(time, manager_pid, event)
            elif subject in handlers:
                handlers[subject](time, manager_pid, target, event, arg)

        if self.cm:
            # if self.cm is still assigned, then the END record is missing.
            # Here we force the generation of tables in that case. This is
            # useful when plotting partial logs while the workflow is still
            # active.
            self.cm.make_tables(self.expand_waiting)


class TxnPlot:
    def __init__(self, managers, opts):
        """Set up the figure, with one subplot per manager, and draw it.

        managers: mapping whose values are the managers to plot.
        opts: options object. This method reads tex, width, height, dpi and
            dask_colors from it; display_plot() reads further options.

        The colour scheme is chosen from opts, the options object this class
        was given, rather than from a module-level global.
        """
        self.managers = managers
        self.opts = opts

        plt.style.use("tableau-colorblind10")
        if opts.tex:
            font = {
                "family": "serif",
                "serif": ["Computer Modern Roman"],
                "weight": "normal",
                "size": 11,
            }
            plt.rc("font", **font)
            plt.rc("text", usetex=True)

        mpl.style.use("fast")

        # without an explicit height, use a 3:2 aspect ratio
        height = opts.height if opts.height else opts.width * (2 / 3)

        self.fig = plt.figure(
            constrained_layout=True, figsize=(opts.width, height), dpi=opts.dpi
        )
        self.subs = self.fig.subplots(nrows=1, ncols=len(managers), squeeze=False)

        # default legend
        self.legend_cb = {
            "workers lifetime": "C9",
            "tasks waiting at worker": "C7",
            "tasks executing": "C0",
            "tasks lost on disconnection": "lightcoral",
            "results waiting retrieval": "C8",
            "tasks failures": "red",
            "cache updates": "C1",
            "inputs from manager": "C3",
            "outputs to manager": "C5",
        }

        # dask legend
        self.legend_dask = {
            "tasks executing": "#7fc96aff",
            "data transfers": "#e3170050",
        }

        self.legend = self.legend_dask if opts.dask_colors else self.legend_cb

        self.display_plot()

    def determine_origin(self, m):
        spec = self.opts.origin
        origin = 0
        try:
            if spec == "dispatched-first-task":
                origin = m.tasks["time_commit_start"].min()
            elif spec == "waiting-first-task":
                origin = m.tasks["READY"].min()
            elif spec == "connected-first-worker":
                origin = m.workers["CONNECTION"].min()
            elif spec == "start-manager":
                origin = 0
        except KeyError:
            origin = 0
        return origin

    def determine_end(self, m):
        spec = self.opts.end
        end = m.termination
        try:
            if spec == "done-last-task":
                end = m.tasks["last_state_time"].max()
            if spec == "disconnected-last-worker":
                end = m.workers["DISCONNECTION"].max()
            if spec == "end-manager":
                end = m.termination
        except KeyError:
            end = m.termination
        return end

    def rd(self, fact, x):
        """Round x up to the nearest multiple of fact."""
        multiples = math.ceil(x / fact)
        return fact * multiples

    def roundb(self, end, t):
        """Round t up to a granularity that depends on the magnitude of end.

        end of an hour or more: multiples of 5 minutes; five minutes or more:
        multiples of a minute; a minute or more: multiples of 5 seconds.
        Below a minute t is returned unchanged.
        """
        # (smallest end for which it applies, granularity in seconds)
        granularities = ((3600, 300), (300, 60), (60, 5))

        for threshold, granularity in granularities:
            if end >= threshold:
                return self.rd(granularity, t)
        return t

    def make_tick_time(self, origin, end, t, pos):
        """Format the label of the time tick at position t.

        The label is the time elapsed since origin, with three significant
        digits. When end is 300 or more it is given in minutes, after rounding
        up to 15 second increments; otherwise in seconds. The unit is printed
        only for the tick at end. pos is not used.
        """
        elapsed = t - origin

        if end >= 300:
            value, unit = self.rd(15, elapsed) / 60, "m"
        else:
            value, unit = elapsed, "s"

        # print units only for the last value
        suffix = unit if t == end else ""
        return f"{value:.3g}{suffix}"

    def bh(self, fig, slots, left, width, color, **kwargs):
        """Draw horizontal bars on fig with this script's default styling.

        slots, left and width are passed to fig.barh as the bar positions,
        left edges and widths. Extra keyword arguments are forwarded to
        fig.barh and override the defaults, except for "type", which only
        selects the edge colour ("cu_dask" or "te_dask"). A false edgecolor,
        such as None, means "same as the bar colour".
        """
        # edge colours used with the Dask colour scheme
        dask_edges = {
            "cu_dask": "#cc1f12ff",
            "te_dask": "#afdfa4ff",
        }
        bar_kind = kwargs.pop("type", None)
        edge_color = dask_edges.get(bar_kind, "#00000010")

        bar_options = {
            "left": left,
            "width": width,
            "color": color,
            "linewidth": 0.5,
            "align": "edge",
            "edgecolor": edge_color,
            **kwargs,
        }
        if not bar_options["edgecolor"]:
            bar_options["edgecolor"] = color

        fig.barh(slots, **bar_options)

    def generate_ticks(self, plot_origin, plot_end):
        # auto ticks
        step = self.roundb(plot_end, (plot_end - plot_origin) / self.opts.time_ticks_n)
        ticks = [
            plot_origin + self.roundb(plot_end, i * step)
            for i in range(self.opts.time_ticks_n - 1)
        ]

        # add plot end, remove last auto tick if it is too close
        while ticks and ticks[-1] > plot_end:
            ticks.pop()

        if len(ticks) > 1 and plot_end - ticks[-1] < 3 * step / 4:
            ticks[-1] = plot_end
        else:
            ticks.append(plot_end)

        ticks.extend(x - plot_origin for x in self.opts.time_tick)
        ticks = sorted(list(set(ticks)))
        return ticks

    def show_app_log(self, fig, manager):
        """Mark every application message of manager on fig.

        Each message gets a black vertical line at its time and its text,
        rotated 90 degrees, just to the right of the line.
        """
        label_offset = 0.005
        for when, message in manager.phases.items():
            fig.axvline(x=when, color="black", linestyle="-")
            label_x = when + label_offset
            fig.text(label_x, 0, message, rotation=90, fontsize=5)

    def display_plot(self):
        """Draw one subplot per manager, then save and/or show the figure.

        The figure is saved when opts.output is set. It is shown when
        opts.display is set, or when there is neither an output file nor
        opts.nograph.
        """
        opts = self.opts
        plt.title(opts.title)

        for column, manager in enumerate(self.managers.values()):
            axes = self.subs[0, column]

            plot_origin = self.determine_origin(manager)
            plot_end = self.determine_end(manager)
            axes.axis(xmin=plot_origin, xmax=plot_end)

            axes.set_xticks(self.generate_ticks(plot_origin, plot_end))
            tick_label = partial(self.make_tick_time, plot_origin, plot_end)
            axes.xaxis.set_major_formatter(plticker.FuncFormatter(tick_label))

            axes.set_xlabel("time")
            axes.tick_params(axis="both", which="major", labelsize=15)

            self.plot(axes, manager)
            if not opts.noapp:
                self.show_app_log(axes, manager)
            self.set_legend(axes)

        if opts.output:
            plt.savefig(opts.output)

        if opts.display or (not opts.output and not opts.nograph):
            plt.show()

    def set_legend(self, fig):
        """Add to fig a legend with one coloured patch per entry of
        self.legend, unless opts.legend is "none".

        The legend location and font size are taken from self.opts, the
        options this plot was created with, rather than from a module-level
        global.
        """
        if self.opts.legend == "none":
            return

        handles = [
            mpl.patches.Patch(color=color, label=label)
            for label, color in self.legend.items()
        ]
        fig.legend(
            loc=self.opts.legend, handles=handles, fontsize=self.opts.fontsize
        )

    def plot(self, fig, manager):
        """Draw manager on fig. Subclasses override this; the base class
        draws nothing and returns NotImplemented."""
        nothing_drawn = NotImplemented
        return nothing_drawn


class TxnPlotManager(TxnPlot):
    def line(self, fig, cu, xtime, color):
        norm = max(1, len(cu.index))
        fig.plot(cu["one"].cumsum() / norm, cu[xtime], color=color)

    def plot(self, fig, manager):
        ts = manager.tasks
        xs = manager.transfers

        xs["one"] = 1
        ts["one"] = 1

        xs["end_time"] = xs["start_time"] + xs["wall_time"]
        xs.sort_values("end_time", ignore_index=True, inplace=True)

        cu = xs[xs["direction"] == "CACHE_UPDATE"][["end_time", "one"]]
        self.line(fig, cu, "end_time", color=self.legend["cache updates"])

        cu = xs[xs["direction"] == "INPUT"][["end_time", "one"]]
        self.line(fig, cu, "end_time", color=self.legend["inputs from manager"])

        cu = xs[xs["direction"] == "OUTPUT"][["end_time", "one"]]
        self.line(fig, cu, "end_time", color=self.legend["outputs to manager"])

        fig.set_ylabel("Percentage")
        fig.set_yticks([0, 1])
        fig.set_yticklabels(["0", "100%"])


class TxnPlotTasks(TxnPlot):
    def plot(self, fig, manager):
        library_tasks = manager.tasks.loc[manager.tasks["library"] > 0]
        regular_tasks = manager.tasks.loc[manager.tasks["library"].isnull()].copy()
        regular_tasks.sort_values("time_commit_start", ignore_index=True, inplace=True)
        manager.tasks = pd.concat([regular_tasks, library_tasks])

        # divided tasks on whether we have a full report
        at = manager.tasks
        lt = manager.tasks[manager.tasks["reason"] == "DISCONNECTION"]
        dt = manager.tasks[~manager.tasks["RETRIEVED"].isna()]

        try:
            del self.legend["workers lifetime"]  # not show in this plot
        except KeyError:
            pass

        cu = at[at["size_input_mgr"] > 0]
        self.bh(
            fig,
            cu.index,
            cu["time_commit_start"],
            cu["time_input_mgr"],
            self.legend["inputs from manager"],
            edgecolor=None,
        )

        cu = at[at["size_output_mgr"] > 0]
        self.bh(
            fig,
            cu.index,
            cu["RETRIEVED"] - cu["time_output_mgr"],
            cu["time_output_mgr"],
            self.legend["outputs to manager"],
            edgecolor=None,
        )

        self.bh(
            fig,
            dt.index,
            dt["time_commit_end"],
            dt["time_worker_start"] - dt["time_commit_end"],
            self.legend["cache updates"],
            edgecolor=None,
        )

        try:
            del self.legend["tasks waiting at worker"]  # already covered above
        except KeyError:
            pass

        self.bh(
            fig,
            lt.index,
            lt["time_commit_end"],
            lt["last_state_time"] - lt["time_commit_end"],
            self.legend["tasks lost on disconnection"],
            edgecolor=None,
        )

        self.bh(
            fig,
            dt.index,
            dt["time_worker_start"],
            dt["time_worker_end"] - dt["time_worker_start"],
            self.legend["tasks executing"],
            edgecolor=None,
        )

        self.bh(
            fig,
            dt.index,
            dt["time_worker_end"],
            dt["RETRIEVED"] - dt["time_worker_end"],
            self.legend["results waiting retrieval"],
            edgecolor=None,
        )

        try:
            del self.legend["tasks failures"]  # already covered above
        except KeyError:
            pass

        fig.set_ylabel("tasks")
        fig.set_yticks([0, len(dt.index)])
        fig.set_yticklabels(["1", str(len(dt.index) + 1)])  # ytick number of workers


class TxnPlotWorkers(TxnPlot):
    def plot_by_worker(self, fig, worker, worker_times, worker_xfers, base_slot):
        slot_count = (
            worker["slot"].nunique() + 2
        )  # + 2 for the "slots" for cache update, and transfers
        worker["slot"] += base_slot + 1  # slot 0 for cache updates, 1 for transfers

        dt = worker[~worker["RETRIEVED"].isna()]
        lt = worker[worker["reason"] == "DISCONNECTION"]

        # plot times from this wid
        self.bh(
            fig,
            base_slot,
            worker_times["CONNECTION"],
            worker_times["DISCONNECTION"] - worker_times["CONNECTION"],
            "C9",
            height=slot_count,
        )

        cu = worker_xfers[worker_xfers["direction"] == "CACHE_UPDATE"]
        self.bh(
            fig,
            base_slot + 0.1,
            cu["start_time"],
            cu["wall_time"],
            self.legend["cache updates"],
            height=0.7,
        )

        # select tasks that needed inputs or outputs, and plot that.  we prefer this
        # information from tasks rather than transfer records, as there is only
        # at most one record per task.
        cu = worker[worker["size_input_mgr"] > 0]
        self.bh(
            fig,
            base_slot + 1,
            cu["time_commit_start"],
            cu["time_input_mgr"],
            self.legend["inputs from manager"],
        )

        cu = worker[worker["size_output_mgr"] > 0]
        self.bh(
            fig,
            base_slot + 1,
            cu["RETRIEVED"] - cu["time_output_mgr"],
            cu["time_output_mgr"],
            self.legend["outputs to manager"],
        )

        if self.opts.expand_waiting:
            self.bh(
                fig,
                dt["slot"],
                dt["RUNNING"],
                dt["time_worker_start"] - dt["RUNNING"],
                self.legend["tasks waiting at worker"],
            )
            self.bh(
                fig,
                lt["slot"],
                lt["time_commit_end"],
                lt["last_state_time"] - lt["time_commit_start"],
                color=self.legend["tasks lost on disconnection"],
            )
        else:
            pass

        self.bh(
            fig,
            dt["slot"],
            dt["time_worker_start"],
            dt["time_worker_end"] - dt["time_worker_start"],
            self.legend["tasks executing"],
        )
        self.bh(
            fig,
            dt["slot"],
            dt["time_worker_end"],
            dt["RETRIEVED"] - dt["time_worker_end"],
            self.legend["results waiting retrieval"],
        )

        ft = dt[
            (dt["reason"] != "SUCCESS")
            | ((dt["reason"] == "SUCCESS") & (dt["exit_code"] != 0))
        ]
        self.bh(
            fig, ft["slot"], ft["RETRIEVED"] - 0.5, 0.5, self.legend["tasks failures"]
        )

        return worker["slot"].max() + 1

    def plot_by_worker_dask(
        self, fig, worker_tasks, worker_times, worker_xfers, base_slot
    ):
        worker_tasks["slot"] += base_slot + 0.2
        retrieved_tasks = worker_tasks[~worker_tasks["RETRIEVED"].isna()]
        cache_updates = worker_xfers[worker_xfers["direction"] == "CACHE_UPDATE"]
        slot_height = 1

        # plot the red bar: cache updates
        self.bh(
            fig,
            base_slot,
            cache_updates["start_time"],
            cache_updates["wall_time"],
            self.legend_dask["data transfers"],
            height=slot_height,
            type="cu_dask",
        )

        # every following task's bar slot add 1
        worker_tasks["slot"] += 1

        # plot the green bar: tasks executing
        self.bh(
            fig,
            retrieved_tasks["slot"],
            retrieved_tasks["time_worker_start"],
            retrieved_tasks["time_worker_end"] - retrieved_tasks["time_worker_start"],
            self.legend_dask["tasks executing"],
            height=slot_height,
            type="te_dask",
        )

        # determine the following starting slot for tasks in the next worker
        return worker_tasks["slot"].max() + 0.2

    def plot(self, fig, manager):
        ts = manager.tasks
        xs = manager.transfers
        ws = manager.workers

        gs = ts.groupby("worker_id", sort=False)
        by_min_connects = (
            ws.groupby("worker_id")["CONNECTION"].min().sort_values().index
        )
        gsd = {wid: g for (wid, g) in gs}

        base_slot = 0
        for wid in by_min_connects:
            worker = gsd.get(wid, None)
            worker_times = ws[ws["worker_id"] == wid]
            worker_xfers = xs[xs["worker_id"] == wid]

            if worker is not None:
                if args.dask_colors:
                    base_slot = self.plot_by_worker_dask(
                        fig, worker, worker_times, worker_xfers, base_slot
                    )
                else:
                    base_slot = self.plot_by_worker(
                        fig, worker, worker_times, worker_xfers, base_slot
                    )

        fig.set_ylabel("workers")
        fig.set_yticks((0, base_slot))
        fig.set_yticklabels(["0", str(gs.ngroups)])  # ytick number of workers

        if not args.dask_colors and not self.opts.expand_waiting:
            del self.legend["tasks waiting at worker"]


class TxnStats(TxnPlot):
    """Text-only "plot" that prints summary statistics for a manager."""

    def convert_size(self, size_bytes):
        """Format a byte count with a 1024-based unit suffix.

        Args:
            size_bytes: number of bytes (int or float).

        Returns:
            "0B" for zero; otherwise "<value> <unit>" with the value rounded
            to two decimals, e.g. "1.5 KB" for 1536.

        Raises:
            ValueError: if size_bytes is negative or NaN (from math.log /
                math.floor).
        """
        if size_bytes == 0:
            return "0B"
        size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
        i = int(math.floor(math.log(size_bytes, 1024)))
        # Clamp the unit index to the table. Without this, a fractional byte
        # count (0 < size < 1) yields i == -1 and is mislabeled as "YB", and
        # values of 1024**9 or more raise IndexError.
        i = min(max(i, 0), len(size_name) - 1)
        p = math.pow(1024, i)
        s = round(size_bytes / p, 2)
        return "%s %s" % (s, size_name[i])

    def plot(self, fig, manager: Manager):
        """Print worker, task and transfer totals for `manager` to stdout.

        Args:
            fig: unused by this method.
            manager: provides the `workers`, `tasks` and `transfers` tables
                and the `worker_lifetime` mapping read below.
        """
        transfers = manager.transfers

        def summarize(rows):
            # Number of non-null sizes, plus their total formatted for display.
            # Sizes are scaled by 1e6 before formatting (the transaction log
            # format lists them as size_in_mb).
            return rows["size"].count(), self.convert_size(rows["size"].sum() * 1e6)

        num_workers = manager.workers["CONNECTION"].count()
        num_tasks = manager.tasks["DONE"].count()

        updates = transfers.loc[transfers["direction"] == "CACHE_UPDATE"]
        inputs = transfers.loc[transfers["direction"] == "INPUT"]
        outputs = transfers.loc[transfers["direction"] == "OUTPUT"]

        # cache updates excluding those whose filetype is CCTOOLS
        app_updates = updates.loc[updates["filetype"] != "CCTOOLS"]
        urls = app_updates.loc[app_updates["filetype"] == "url"]
        files = app_updates.loc[app_updates["filetype"] == "file"]
        tasks = app_updates.loc[app_updates["filetype"] == "task"]
        temps = app_updates.loc[app_updates["filetype"] == "temp"]

        num_urls, urls_size = summarize(urls)
        num_files, files_size = summarize(files)
        num_inputs, inputs_size = summarize(inputs)
        num_outputs, outputs_size = summarize(outputs)
        num_minitasks, minitasks_size = summarize(tasks)
        num_temps, temps_size = summarize(temps)

        # the average is only defined (and only reported) when there were
        # url transfers
        xfr_avg = None
        if num_urls > 0:
            xfr_avg = urls["wall_time"].sum() / urls["wall_time"].count()

        print("-----------------------------------------")
        print(f"Number of workers: {num_workers}")
        print(f"Number of tasks: {num_tasks}")
        print(f"Files created at worker: {num_files}")
        print(f"Bytes created: {files_size}")
        print(f"Number of inputs transferred by manager: {num_inputs}")
        print(f"Bytes transferred: {inputs_size}")
        print(f"Number of outputs returned: {num_outputs}")
        print(f"Size of outputs returned: {outputs_size}")
        print(f"Number of URL/Worker transfers: {num_urls}")
        print(f"Bytes transferred: {urls_size}")
        if xfr_avg is not None:
            print("Average transfer time: {:.2f}s".format(float(xfr_avg)))
        print(f"Number of mini tasks: {num_minitasks}")
        print(f"Bytes created: {minitasks_size}")
        print(f"Number of temps: {num_temps}")
        print(f"Bytes created: {temps_size}")

        # total connected time summed over all recorded worker lifetimes
        cpu_time = sum(
            lifetime["DISCONNECTION"] - lifetime["CONNECTION"]
            for lifetime in manager.worker_lifetime.values()
        )
        print(f"Worker CPU Time (h): {round(cpu_time/60/60, 2)}")
        print("-----------------------------------------")
        print()


# Command-line entry point: parse the options, load the transaction log, and
# run the plotter/exporter selected by --mode.
# NOTE: `args` is intentionally a module-level name; plotting code in this
# file reads it as a global.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="vine_plot_txn_log",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Plot TaskVine workflow information from a transaction log file.",
    )

    parser.add_argument("log", help="Path to transaction log file")
    parser.add_argument(
        "output",
        nargs="?",
        default=None,
        help="output name of the plot/csv generated. If not given, assumes --display",
    )
    parser.add_argument(
        "--mode",
        choices="workers tasks manager csv stats".split(),
        default="workers",
        help="information to plot. If csv, write dataframes for tasks, workers and transfers to disk instead.",
    )
    parser.add_argument(
        "--title", nargs="?", default="TaskVine", help="Title of the plot"
    )
    parser.add_argument(
        "--origin",
        nargs="?",
        help="change plot origin.",
        choices="dispatched-first-task waiting-first-task connected-first-worker start-manager".split(),
        default="connected-first-worker",
    )
    parser.add_argument(
        "--end",
        nargs="?",
        help="change plot end time.",
        choices="done-last-task disconnected-last-worker end-manager".split(),
        default="done-last-task",
    )
    parser.add_argument(
        "--tasks-range",
        nargs="?",
        help="range of tasks ids to plot as start[,[end][,step]] ",
        default="1,,1",
    )
    parser.add_argument(
        "--expand-waiting",
        action="store_true",
        help="show complete lifetime of waiting tasks at the worker",
    )
    parser.add_argument(
        "--display",
        action="store_true",
        help="show plot using matplotlib internal viewer",
    )
    parser.add_argument(
        "--legend",
        nargs="?",
        help='position of the legend. ("best" may take longer to render.) ',
        choices=[
            "upper right",
            "upper left",
            "lower right",
            "lower left",
            "best",
            "none",
        ],
        default="upper left",
    )
    parser.add_argument(
        "--time-ticks-n",
        nargs="?",
        type=int,
        help="adds a time tick every (origin-end)/time-ticks-n, rounded to units of time",
        default=5,
    )
    parser.add_argument(
        "--time-tick",
        nargs="?",
        type=int,
        action="append",
        help="tick value to add, in seconds from manager's start. May be specified multiple times.",
        default=[],
    )
    parser.add_argument(
        "--width", nargs="?", type=float, help="width in inches", default=12
    )
    parser.add_argument(
        "--height",
        nargs="?",
        type=float,
        help="height in inches. Default is 2/3 of width",
        default=None,
    )
    parser.add_argument(
        "--dpi", nargs="?", type=int, help="output image resolution", default=300
    )
    parser.add_argument("--tex", action="store_true", help="use tex fonts")
    parser.add_argument(
        "--noapp", action="store_true", help="do not show custom application lines"
    )
    parser.add_argument(
        "--nograph", action="store_true", help="do not show a matplotlib window"
    )
    parser.add_argument(
        "--dask-colors", action="store_true", help="color palette aligns with Dask's"
    )
    parser.add_argument(
        "--fontsize",
        nargs="?",
        type=str,
        help="set fontsize of figure legend",
        default="medium",
    )

    args = parser.parse_args()

    # Validate before the log is parsed, and report through the argument
    # parser (prints usage and exits with status 2). The value is None when
    # --time-ticks-n is given without a number, because of nargs="?".
    if args.time_ticks_n is None or args.time_ticks_n < 1:
        parser.error("Minimum --time-ticks-n is 1")

    p = ParseTxn(args.log, args.expand_waiting, args.tasks_range)

    if args.mode == "csv":
        p.write_tables(args.output)
    elif args.mode == "workers":
        TxnPlotWorkers(p.managers, args)
    elif args.mode == "tasks":
        TxnPlotTasks(p.managers, args)
    elif args.mode == "manager":
        TxnPlotManager(p.managers, args)
    elif args.mode == "stats":
        # stats mode only prints text, so never open a matplotlib window
        args.nograph = True
        TxnStats(p.managers, args)
