# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: MIT

import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Optional, Dict, Union

import pandas as pd

from mpp import ViewData, ViewAttributes, ViewType, ViewAggregationLevel, Partition, Device
from mpp.core.types import EventInfoDataFrame


@dataclass(frozen=True)
class DataViewsDescriptor:
    """
    Immutable handle produced by `DataViewsSerializer.write_views` that identifies the HDF file holding the views.
    """

    # Filesystem path of the HDF file the views were persisted to
    file_path: str


def _get_hdf_view_name_identifier(view_name) -> str:
    # Although to_hdf claims to take a string for key, pytables throws a warning
    # when the key doesn't qualify as a legitimate python identifier (THORS-264)
    # Replace '-' (from core types like p-core, e-core, etc) with '_' for the
    # key only. This doesn't affect view name attributes in the views themselves.
    return view_name.replace('-', '_')


class DataViewsSerializer:
    """
    Utility class for serializing `DataView` objects using an efficient binary format (HDF).

    Every call to `write_views` creates a new HDF file inside a temporary directory owned by this serializer.
    Call `cleanup` to delete that directory together with all files written to it.
    """

    def __init__(self, parent_dir: Optional[Path] = None):
        """
        Initialize the serializer
        :param parent_dir: directory in which the serializer creates its own temporary sub-directory. If None, the
                           default temporary directory is used (OS dependent). A `str` path is accepted as well.
        :raises ValueError: if `parent_dir` is given but is not an existing directory.
        """
        if parent_dir is not None:
            # Path() lets callers pass a plain string in addition to a Path object
            parent_dir = Path(parent_dir).resolve()
            if not parent_dir.is_dir():
                raise ValueError(f'Invalid parent_dir: {parent_dir}')

        self.parent_dir = tempfile.TemporaryDirectory(dir=parent_dir)

    def cleanup(self):
        """Delete the serializer's temporary directory and every file written to it."""
        self.parent_dir.cleanup()

    def write_views(self, views: List[ViewData], partition: Partition = None) -> \
            DataViewsDescriptor:
        """
        Writes post processor data views to a file.
        :param views: list of `ViewData` objects to write.
        :param partition: an optional partition object describing the sample range stored in `views`.
        :return: a descriptor object that stores information about the file to which the data was persisted.
        """
        hdf_file_handle, hdf_file_path = tempfile.mkstemp(dir=self.parent_dir.name,
                                                          prefix=self.format_temp_file_prefix(partition, views))
        # mkstemp is only used to reserve a unique file name; pandas/PyTables reopen the file by path. Release the
        # low-level descriptor right away instead of holding it open while HDF5 truncates and rewrites the file.
        os.close(hdf_file_handle)

        # Header: view order and attributes, used by the deserializer to rebuild each view
        view_attr_data = [
            [va.view_name, va.view_type.value, va.aggregation_level.value, va.device.type_name, va.show_modules]
            for va in (v.attributes for v in views)
        ]
        view_attr_df = pd.DataFrame(view_attr_data, columns=['view_name', 'view_type', 'aggregation_level',
                                                             'type_name', 'show_modules'])
        # mode='w' starts from an empty HDF file; the later to_hdf calls add further keys to that same file
        view_attr_df.to_hdf(hdf_file_path, mode='w', key='views')

        # Partition information is optional.
        # NOTE(review): assumes the partition's fields live in its instance __dict__ (vars()) - confirm for Partition
        if partition is not None:
            partition_df = pd.DataFrame(vars(partition), index=[0])
            partition_df.to_hdf(hdf_file_path, key='partition')

        # View data: event columns and data for every view, plus optional count tables written only when non-empty
        for view in views:
            view_name_id = _get_hdf_view_name_identifier(view.attributes.view_name)
            view.attributes.required_events.to_hdf(hdf_file_path, key=f'{view_name_id}_columns')
            view.data.to_hdf(hdf_file_path, key=view_name_id)
            if not view.retire_latency_counts.empty:
                view.retire_latency_counts.to_hdf(hdf_file_path, key=f'{view_name_id}_rl_counts')
            if not view.static_msr_counts.empty:
                view.static_msr_counts.to_hdf(hdf_file_path, key=f'{view_name_id}_s_msr_counts')

        return DataViewsDescriptor(hdf_file_path)

    @staticmethod
    def format_temp_file_prefix(partition, views):
        """
        Build a human-readable prefix for the temporary HDF file name.

        With both inputs available the format is '<first_sample>__<last_sample>__<view type name>__'.
        Components that are unavailable (no partition, or an empty `views` list) are omitted, so the optional
        `partition` argument of `write_views` no longer crashes file creation.
        :param partition: optional partition object providing `first_sample` / `last_sample`.
        :param views: list of `ViewData` objects; the view type of the first one is used.
        :return: the prefix string (possibly empty).
        """
        parts = []
        if partition is not None:
            parts += [str(partition.first_sample), str(partition.last_sample)]
        if views:
            parts.append(views[0].attributes.view_type.name)
        return ''.join(f'{part}__' for part in parts)


class DataViewsDeserializer:
    """
    Utility class for efficiently deserializing post processor `DataView` objects from a binary file (HDF).
    """

    def __init__(self):
        pass

    @staticmethod
    def read_views(descriptor: Union[DataViewsDescriptor, Path, str],
                   delete_after_read=False) -> Tuple[Dict[str, ViewData], Optional[Partition]]:
        """
        Reads post processor data views from a file.
        :param descriptor: identifies the file from which to read data: either the `DataViewsDescriptor` returned by
                           `DataViewsSerializer.write_views`, or the path (`Path` or `str`) of such a file.
        :param delete_after_read: set to True to delete the storage file after it was read.
        :return: a tuple:
                   First element: a dict mapping each view name to its `ViewData` object, in the order in which
                                  the views are listed in the file header.
                   Second element: a partition object describing the range of samples in the data views.
                                   This element will be None if there is no partition information.
        :raises ValueError: if the file does not exist.
        """
        # Accept the descriptor object produced by the serializer as well as a plain path
        if isinstance(descriptor, DataViewsDescriptor):
            descriptor = descriptor.file_path
        hdf_file_path = Path(descriptor).resolve()
        if not hdf_file_path.is_file():
            raise ValueError(f'Invalid descriptor. File not found: {hdf_file_path}')

        views_data = {}
        partition = None

        # Open the file once for all keys instead of re-opening it for every read. The store is closed when the
        # `with` block exits, i.e. before the optional delete below.
        with pd.HDFStore(hdf_file_path, mode='r') as store:
            # Information about the views stored in the file (header written first by the serializer)
            view_attr_df = store.get('views')

            # Partition data is optional
            if 'partition' in store:
                partition_df = store.get('partition')
                partition = Partition(**partition_df.loc[0].to_dict())

            # Views data stored in the file
            for _, view_attr_data in view_attr_df.iterrows():
                view_name = view_attr_data.view_name
                view_name_id = _get_hdf_view_name_identifier(view_name)
                view_columns_df = store.get(f'{view_name_id}_columns')
                view_data_df = store.get(view_name_id)
                view_attr = ViewAttributes(view_name=view_name,
                                           view_type=ViewType(view_attr_data.view_type),
                                           aggregation_level=ViewAggregationLevel(view_attr_data.aggregation_level),
                                           metric_computer=None,
                                           normalizer=None,
                                           device=Device(view_attr_data.type_name),
                                           show_modules=view_attr_data.show_modules,
                                           required_events=EventInfoDataFrame(view_columns_df))

                # Retire latency and static MSR counts are only written when non-empty; default to empty frames
                rl_key = f'{view_name_id}_rl_counts'
                retire_latency_counts = store.get(rl_key) if rl_key in store else pd.DataFrame()
                msr_key = f'{view_name_id}_s_msr_counts'
                static_msr_counts = store.get(msr_key) if msr_key in store else pd.DataFrame()

                views_data[view_name] = ViewData(view_attr, view_data_df, retire_latency_counts, static_msr_counts)

        if delete_after_read:
            hdf_file_path.unlink()

        return views_data, partition
