#
# Copyright (C) 2017 Intel Corporation
#
# This software and the related documents are Intel copyrighted materials, and your use of them
# is governed by the express license under which they were provided to you ("License"). Unless
# the License provides otherwise, you may not use, modify, copy, publish, distribute, disclose
# or transmit this software or the related documents without Intel's prior written permission.
#
# This software and the related documents are provided as is, with no express or implied
# warranties, other than those that are expressly stated in the License.
#


# ------------------------------------------------------------------------------
# This example shows how to traverse a topdown tree and aggregate data from
# children, using data from the bottomup table if absent in the topdown
# representation.
#
# Please note that the accuracy of the resulting data depends on each particular
# loop or function being present in only a single call chain. For example, this
# approach wouldn't work well for recursive functions.
# ------------------------------------------------------------------------------

import sys

# Import the Advisor Python API. As the message below explains, the module can only be
# resolved when Advisor's pythonapi directory is on PYTHONPATH.
try:

    import advisor

except ImportError:

    # Tell the user how to make the module importable, then stop with exit status 1.
    print(
        """Import error: Python could not resolve path to Advisor's pythonapi directory.
        To fix, either manually add path to the pythonapi directory into PYTHONPATH environment
        variable, or use advixe-vars.* scripts to set up product environment variables automatically."""
    )
    sys.exit(1)

# Check command-line arguments: the project directory is the single required argument.
if len(sys.argv) < 2:
    # The command line is shown inside double quotes; both the opening and the closing
    # quote are part of the message.
    print('Usage: "python {} path_to_project_dir"'.format(__file__))
    sys.exit(2)


class Accumulator(object):
    """Sum selected metrics over every subtree of a topdown tree.

    Totals are cached per row under the row's ``"key_column"`` value, so each
    row is computed at most once and can be looked up later with
    ``get_accumulated``.
    """

    def __init__(self, keys):
        """Create an accumulator for the metric names listed in *keys*."""
        self.keys = keys
        # Cached totals: row["key_column"] -> {metric name: float total}.
        self.values = {}
        # Template returned (as a copy) for rows that were never accumulated.
        self.zeros = {key: 0.0 for key in self.keys}

    def get_accumulated(self, row):
        """Return the totals cached for *row*.

        Rows that were never passed through ``accumulate`` yield a fresh dict
        holding 0.0 for every key, so callers cannot corrupt the shared
        template by mutating the result.
        """
        cached = self.values.get(row["key_column"])
        if cached is None:
            return dict(self.zeros)
        return cached

    def get_bottomup_data(self, row, key):
        """Return the sum of *key* over the rows in ``row.sync``.

        *row* is a topdown row. Rows in ``row.sync`` that lack *key*, or whose
        value cannot be converted to ``float``, contribute nothing.
        """
        total = 0.0
        for bottomup_row in row.sync:
            if key not in bottomup_row:
                continue
            try:
                total += float(bottomup_row[key])
            except ValueError:
                # Non-numeric cell (such as an empty string): best effort, skip it.
                pass
        return total

    def _own_value(self, row, key):
        """Return the value of *key* for *row* itself, children excluded."""
        if key in row:
            try:
                return float(row[key])
            except ValueError:
                # There is no such data in topdown, try to get data from bottomup.
                pass
        # There is no such field (or no usable value) in topdown: use bottomup.
        return self.get_bottomup_data(row, key)

    def accumulate(self, row):
        """Compute, cache and return the totals for *row* and all its descendants.

        Each total is the row's own value (taken from the bottomup rows when the
        topdown row has no usable value) plus the totals of every row in
        ``row.children``. The row's own value is counted exactly once, no matter
        how many children it has.

        Returns a dict mapping every key in ``self.keys`` to a float.
        """
        row_id = row["key_column"]
        if row_id in self.values:
            return self.values[row_id]

        totals = {key: self._own_value(row, key) for key in self.keys}
        for child in row.children:
            # Child totals are already floats with the bottomup fallback applied.
            child_totals = self.accumulate(child)
            for key in self.keys:
                totals[key] += child_totals[key]

        self.values[row_id] = totals
        return totals


# Row field printed as the row's name in both reports below.
name_key = "function_call_sites_and_loops"
# Metrics that are summed over each subtree of the topdown tree.
aggregate_keys = ["self_gflop", "self_memory_gb"]

# Open the Advisor Project and load the data.
project = advisor.open_project(sys.argv[1])
survey = project.load(advisor.SURVEY)

# Pre-compute the subtree totals, starting from the first row yielded by survey.topdown.
# NOTE(review): only that first row is accumulated, while the topdown report iterates over
# every row of survey.topdown; rows outside the first row's subtree would be reported with
# zero totals. Presumably the tree has a single root - confirm.
acc = Accumulator(aggregate_keys)
acc.accumulate(next(survey.topdown))


# Traverse topdown tree and print accumulated values.
# One 80-character name column followed by six 24-character value columns.
fmt_head = "{:^80}" + "| {:^24}" * 6
fmt_data = "{:80}" + "| {:<24}" * 6

print("TOPDOWN: call tree representation")
print("{:=<233}".format(""))
print(
    fmt_head.format(
        *(
            [name_key.upper() + " and TYPE"]
            + ["AGG_" + k.upper() for k in aggregate_keys]
            + ["TOTAL_ELAPSED_TIME", "AGGREGATED_GFLOPS", "AGGREGATED_AI", "SELF_MEMORY_GB"]
        )
    )
)
print("{:=<233}".format(""))
for row in survey.topdown:
    # Depth-first walk using an explicit stack of (row, depth) pairs.
    stack = [(row, 0)]
    while stack:
        v, level = stack.pop()
        # Children are queued before the filter below, so the subtree of a skipped row
        # is still visited.
        stack.extend((r, level + 1) for r in v.get_children())
        # Do not print non-executed binary parts.
        if "[Not Executed]" in v["type"]:
            continue
        agg = acc.get_accumulated(v)
        # A non-numeric elapsed time is treated as 0.0, which leaves the GFLOPS cell empty.
        total_elapsed_time = 0.0
        try:
            total_elapsed_time = float(v["total_elapsed_time"])
        except ValueError:
            pass
        # Both ratios need true division: floor division would truncate them to whole
        # numbers (for example 0.75 would be printed as 0.0).
        agg_gflops = str(agg["self_gflop"] / total_elapsed_time) if total_elapsed_time > 0.0 else ""
        agg_ai = str(agg["self_gflop"] / agg["self_memory_gb"]) if agg["self_memory_gb"] > 0.0 else ""
        print(
            fmt_data.format(
                *(
                    # Indent the name with one dash per tree level.
                    ["-" * level + v[name_key] + " : " + v.type]
                    + [agg[k] for k in aggregate_keys]
                    + [total_elapsed_time, agg_gflops, agg_ai, v.self_memory_gb]
                )
            )
        )


# Traverse bottomup flat dataset and print accumulated values.
print("\n\n" + "BOTTOMUP: flat representation")
# A 40-character name column, a 38-character type column and two 24-character value columns.
fmt_head = "{:^40}" + "| {:^38}" + "| {:^24}" * 2
fmt_data = "{:<40}" + "| {:<38}" + "| {:<24}" * 2

print("{:=<132}".format(""))
print(fmt_head.format(name_key.upper(), "TYPE", "AGGREGATED_GFLOPS", "AGGREGATED_AI"))
print("{:=<132}".format(""))
for row in survey.bottomup:
    # Do not print non-executed binary parts.
    if "[Not Executed]" in row.type:
        continue
    # A non-numeric elapsed time is treated as 0.0, which leaves the GFLOPS cell empty.
    total_elapsed_time = 0.0
    try:
        total_elapsed_time = float(row.total_elapsed_time)
    except ValueError:
        pass
    # Sum the cached subtree totals of every topdown row linked to this bottomup row.
    agg_gflop = 0.0
    agg_memory_gb = 0.0
    for topdownRow in row.sync:
        agg = acc.get_accumulated(topdownRow)
        agg_gflop += agg["self_gflop"]
        agg_memory_gb += agg["self_memory_gb"]

    # Both ratios need true division: floor division would truncate them to whole numbers.
    agg_ai = str(agg_gflop / agg_memory_gb) if agg_memory_gb > 0.0 else ""
    agg_gflops = str(agg_gflop / total_elapsed_time) if total_elapsed_time > 0.0 else ""

    print(fmt_data.format(row[name_key], row.type, agg_gflops, agg_ai))
