# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: MIT

import json
import logging

import jsonschema.exceptions
from jsonschema import validate
import re
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Dict

import defusedxml.ElementTree as ET

from mpp.core.types import MetricDefinition, Threshold
from mpp.core.validators import FileValidator


def validate_file(file_path: Path):
    """
    Run a FileValidator over `file_path`.

    The validator is configured with file_must_exist=True and max_file_size of 10 MiB; any error it raises
    propagates to the caller.
    :param file_path: path of the file to check
    """
    max_size_bytes = 10 * 1024 * 1024  # 10 MiB
    FileValidator(file_must_exist=True, max_file_size=max_size_bytes)(str(file_path))


class JsonObjectValidator():
    '''
    Load a JSON file and validate it against a JSON schema.

    All validation happens in the constructor; on success the parsed document is available via `json_object`.
    '''

    def __init__(self, file_spec: Path, schema: dict):
        """
        :param file_spec: path of the JSON file to be validated
        :param schema: JSON schema to validate the parsed file contents against
        :raises SyntaxError: if the file is not well-formed (UTF-8 encoded) JSON
        :raises jsonschema.exceptions.ValidationError: if the parsed JSON does not conform to `schema`
        """
        self._json_file = file_spec
        self._json_object = None
        validate_file(file_spec)
        self.__validate_json_syntax(file_spec)
        self.__validate_json_schema(schema)

    @property
    def json_object(self):
        """The parsed JSON document (None until syntax validation has succeeded)."""
        return self._json_object

    def __validate_json_syntax(self, file_spec):
        """
        Parse `file_spec` as JSON and store the result in `self._json_object`.
        :raises SyntaxError: if the file cannot be decoded or parsed; the original error is chained as __cause__
        """
        # JSON text exchanged between systems is UTF-8 (RFC 8259), so don't depend on the platform's locale
        with open(file_spec, encoding='utf-8') as json_file:
            try:
                self._json_object = json.load(json_file)
            except (json.decoder.JSONDecodeError, UnicodeDecodeError) as ex:
                raise SyntaxError(f'syntax error in {self._json_file}') from ex

    def __validate_json_schema(self, schema):
        """
        Validate the parsed document against `schema`.
        :raises jsonschema.exceptions.ValidationError: if validation fails; the detailed error is chained as
                __cause__ so the offending path/value is not lost
        """
        try:
            validate(self._json_object, schema)
        except jsonschema.exceptions.ValidationError as err:
            raise jsonschema.exceptions.ValidationError(f'{self._json_file} does not conform to schema') from err


class MetricDefinitionParser(ABC):
    """
    Metric definition parser abstract base class (ABC)
    """

    def __init__(self, file_path: Path):
        """
        Constructor
        :param file_path: metric definition file to parse
        """
        self._metric_file = file_path
        self.legacy_constant_map = {
            # key = legacy constant name, value = current constant name
            'system.cha_count/system.socket_count': 'CHAS_PER_SOCKET',
        }

    @abstractmethod
    def parse(self) -> List[MetricDefinition]:
        """
        Parse metric definitions.

        Subclasses must override this; calling super().parse() runs validate_file() on the metric file first.
        :return: list of parsed metrics
        """
        validate_file(self._metric_file)

    @staticmethod
    def _adjust_formula(name: str, constants: Dict, formula: str):
        """
        Adjust the formula of the EMON sampling-time metric so it yields the average sampling time per sample
        in the aggregate. Every other metric is returned unchanged.
        :param name: metric name
        :param constants: constant alias map of the metric (not modified; a new dict is returned when adjusted)
        :param formula: the metric formula
        :return: tuple (constants, formula); for the sampling-time metric a 'samples' constant bound to
                 '$processed_samples' is added and the formula is divided by it
        """
        if name == "metric_EDP EMON Sampling time (seconds)":
            # Build a new dict so the caller's constants are not mutated as a side effect
            constants = {**constants, 'samples': '$processed_samples'}
            # Parenthesize so the division applies to the whole formula, not only to its last operand
            # (e.g. 'a + b' must become '(a + b) / samples', not 'a + b / samples')
            formula = f'({formula}) / samples'

        return constants, formula


class JsonParser(MetricDefinitionParser):
    """
    Parser for Data Lake metric definition files (JSON)
    """

    def parse(self) -> List[MetricDefinition]:
        """
        Parse metric definitions
        :return: list of parsed metrics, one per entry of the top-level 'Metrics' array
        """
        super().parse()
        # JSON text is UTF-8 (RFC 8259); close the file before doing the (unrelated) conversion work
        with open(self._metric_file, encoding='utf-8') as metrics_file:
            metric_defs_json = json.load(metrics_file)
        return [self.__create_from_json(metric_def) for metric_def in metric_defs_json['Metrics']]

    def __create_from_json(self, metric_def):
        """
        Build a MetricDefinition from one entry of the 'Metrics' array.
        :param metric_def: dict for a single metric; 'Constants', 'Events', 'LegacyName', 'Formula' and 'Level'
                           are required, every other key is optional
        :return: the parsed MetricDefinition
        """
        # Legacy constant expressions are mapped to their current constant names
        constant_alias_map = self.__set_alias_map(metric_def['Constants'], self.legacy_constant_map)
        event_alias_map = self.__set_alias_map(metric_def['Events'])
        name = metric_def['LegacyName']
        original_formula = metric_def['Formula']
        constant_alias_map, original_formula = \
            MetricDefinitionParser._adjust_formula(name, constant_alias_map, original_formula)
        python_formula = _ToPythonConverter(original_formula).convert()
        return MetricDefinition(name, self.__get_optional_key(metric_def, 'ThroughputName'),
                                self.__get_optional_key(metric_def, 'BriefDescription'), python_formula,
                                event_aliases=event_alias_map, constants=constant_alias_map,
                                level=metric_def['Level'],
                                unit_of_measure=self.__get_optional_key(metric_def, 'UnitOfMeasure'),
                                category=self.__get_optional_key(metric_def, 'Category'),
                                threshold=self.__get_optional_key(metric_def, 'Threshold'),
                                resolution_levels=self.__parse_resolution_levels(
                                    self.__get_optional_key(metric_def, 'ResolutionLevels')),
                                metric_group=self.__get_optional_key(metric_def, 'MetricGroup'))

    @staticmethod
    def __get_optional_key(data, key, default=''):
        """Return data[key], or `default` when the key is absent."""
        return data.get(key, default)

    @staticmethod
    def __parse_resolution_levels(raw_levels):
        """
        Split a comma-separated ResolutionLevels value into a list of level names.
        :param raw_levels: levels separated by commas (optionally followed by whitespace); may be '' when the key
                           is absent or None when it is null in the JSON
        :return: list of stripped, non-empty level names ([] when none are given)
        """
        if not raw_levels:
            return []
        return [level.strip() for level in raw_levels.split(',') if level.strip()]

    @staticmethod
    def __set_alias_map(definitions, alternatives=None):
        """
        Build an {Alias: Name} map from a list of {'Alias': ..., 'Name': ...} definitions.
        :param definitions: iterable of dicts with 'Alias' and 'Name' keys
        :param alternatives: optional map used to replace a (whitespace-stripped) Name with another value
        :return: dict mapping each alias to its (possibly replaced) name
        """
        if alternatives is None:
            alternatives = {}
        alias_map = {}
        for definition in definitions:
            alias_map.update({definition['Alias']: alternatives.get(definition['Name'].strip(), definition['Name'])})
        return alias_map


class XmlParser(MetricDefinitionParser):
    """
    Parser for XML metric definition files
    """

    # <event> elements whose text contains this marker are collected as retire latencies rather than events
    retire_latency_str = 'retire_latency'

    def parse(self) -> List[MetricDefinition]:
        """
        Parse metric definitions
        :return: list of parsed metrics, one per <metric> child of the document root
        """
        super().parse()
        # DTDs, entity declarations and external references are all rejected by defusedxml
        root = ET.parse(self._metric_file, forbid_dtd=True, forbid_entities=True, forbid_external=True).getroot()
        return [self.__create_from_xml(metric) for metric in root.findall('metric')]

    def __create_from_xml(self, metric_def):
        """
        Build a MetricDefinition from a single <metric> element.
        :param metric_def: the <metric> XML element
        :return: the parsed MetricDefinition
        """
        name = metric_def.get('name')
        throughput_name = self._get_throughput_name(metric_def)
        description = self._get_description(metric_def)
        constants = self._get_constants(metric_def)
        events, latencies = self._get_dicts_for_events(metric_def)
        # <time> aliases are merged into the event aliases
        events.update(self._get_dict_for_tag('time', metric_def))
        threshold = self._get_threshold(metric_def)

        original_formula = metric_def.find('formula').text
        self._verify_constants_in_formula(name, constants, original_formula)
        constants, original_formula = MetricDefinitionParser._adjust_formula(name, constants, original_formula)
        python_formula = _ToPythonConverter(original_formula).convert()
        return MetricDefinition(name, throughput_name, description, python_formula, events, constants, latencies, '',
                                threshold=threshold)

    @staticmethod
    def _verify_constants_in_formula(name: str, constants: Dict[str, str], formula: str):
        """
        Log (at debug level) each constant alias that does not appear in the metric formula.

        Aliases are matched as whole identifiers, so alias 'a' is not reported as used merely because the
        formula contains the letter 'a' inside some other name.
        """
        for constant in constants:
            if not re.search(rf'(?<!\w){re.escape(constant)}(?!\w)', formula):
                logging.debug(f'Constant \'{constant}\', defined for \'{name}\', is not used in the metric formula.')

    @staticmethod
    def _get_dicts_for_events(metric_def):
        """
        Split the <event> children of a metric into regular events and retire latencies.
        :return: tuple (events, latencies), each a dict mapping alias -> element text
        """
        events = {}
        latencies = {}
        for md in metric_def.findall('event'):
            if XmlParser.retire_latency_str not in md.text:
                events[md.get('alias')] = md.text
            else:
                latencies[md.get('alias')] = md.text
        return events, latencies

    @staticmethod
    def _get_dict_for_tag(tag: str, metric_def):
        """Return a dict mapping the 'alias' attribute to the text of every `tag` child of `metric_def`."""
        return {md.get('alias'): md.text for md in metric_def.findall(tag)}

    @staticmethod
    def _get_throughput_name(metric_def) -> str:
        """Return the text of the first <throughput-metric-name> child, or '' if absent or empty."""
        element = metric_def.find('throughput-metric-name')
        if element is None or element.text is None:
            return ''
        return element.text

    @staticmethod
    def _get_description(metric_def) -> str:
        """Return the text of the first <description> child, or '' if absent or empty."""
        element = metric_def.find('description')
        if element is None or element.text is None:
            return ''
        return element.text

    def _get_constants(self, metric_def) -> Dict[str, str]:
        """
        Return a dict mapping each <constant> alias to its value, with legacy constant expressions replaced by
        their current names (see legacy_constant_map).
        """
        constants = {}
        for const in metric_def.findall('constant'):
            # Hack - replace the "system.cha_count/system.socket_count" expression with a "chas_per_socket" constant
            # TODO: cleanup/refactor
            constants[const.get('alias')] = self.legacy_constant_map.get(const.text.strip(), const.text)
        return constants

    @staticmethod
    def _get_threshold(metric_def):
        """
        Build the threshold of a metric from its <threshold_metric>, <threshold_formula> and
        <threshold_formula_raw> children.
        :param metric_def: the <metric> XML element
        :return: a Threshold, or '' when there are no <threshold_metric> children or no non-empty threshold formula
        """
        threshold_metric_tags = metric_def.findall('threshold_metric')
        if not threshold_metric_tags:
            return ''
        threshold_metric_aliases = {tag.get('alias'): tag.text for tag in threshold_metric_tags}

        threshold_formula_tag = metric_def.find('threshold_formula')
        if threshold_formula_tag is None:
            logging.debug('Threshold is missing a formula')
            return ''
        threshold_formula = threshold_formula_tag.text
        if not threshold_formula:
            return ''

        threshold_formula_raw_tag = metric_def.find('threshold_formula_raw')
        threshold_formula_raw = threshold_formula_raw_tag.text if threshold_formula_raw_tag is not None else None

        return Threshold(metric_aliases=threshold_metric_aliases, formula=threshold_formula,
                         formula_raw=threshold_formula_raw)


class _ToPythonConverter:
    """
    Converts non-Python expressions into Python equivalents.

    The only conversion currently supported turns C-style ternaries (``cond ? a : b``) into Python conditional
    expressions (``a if cond else b``); any other expression is returned unchanged.
    """

    # DOTALL lets '.' match newlines, so formulas spanning several lines (e.g. XML element text) are still
    # recognized and converted
    _REGEX_FLAGS = re.DOTALL

    def __init__(self, expression):
        """
        :param expression: expression string to convert
        """
        self.expression = expression
        self.regex_patterns = {'ternary': r'(.*)\?(.*):(.*)'}

    def convert(self):
        """
        Convert the expression to Python syntax.
        :return: the converted expression (also stored back into `self.expression`)
        """
        expression_type = self.__determine_expression_type()
        if expression_type == 'ternary':
            self.expression = self.__convert_ternary_expression()
        return self.expression

    def __determine_expression_type(self):
        """
        :return: key of the first entry in `regex_patterns` that matches the expression, or None if none match
        """
        for key, pattern in self.regex_patterns.items():
            if re.match(pattern, self.expression, self._REGEX_FLAGS):
                return key
        return None

    def __convert_ternary_expression(self):
        """
        Rewrite every C-style ternary in the expression, innermost parenthesized one first.
        :return: the rewritten expression
        """
        pattern = self.regex_patterns['ternary']
        out_formula = self.expression
        # Each pass rewrites exactly one ternary (removing one '?'), so the loop terminates
        while re.match(pattern, out_formula, self._REGEX_FLAGS):
            # Inclusive (start, end) spans of the contents of every parenthesized subexpression, ordered by
            # closing parenthesis so inner spans precede the spans enclosing them; the whole formula comes last
            open_positions = []
            spans = []
            for index, char in enumerate(out_formula):
                if char == '(':
                    open_positions.append(index + 1)
                elif char == ')':
                    spans.append((open_positions.pop(), index - 1))
            spans.append((0, len(out_formula) - 1))
            # Rewrite only the located span (not every textual occurrence of it)
            for start, end in spans:
                ternary_match = re.match(pattern, out_formula[start:end + 1], self._REGEX_FLAGS)
                if ternary_match:
                    cond, val1, val2 = ternary_match.groups()
                    out_formula = f'{out_formula[:start]}{val1} if {cond} else {val2}{out_formula[end + 1:]}'
                    break

        return out_formula


class MetricDefinitionParserFactory:
    """
    Create a metric definition parser based on file type. Use `create(file_path)` method to create the appropriate
    metric definition parser.
    """
    # Maps a lower-case file extension to the parser class that handles it
    parser_for_file_type = {
        '.xml': XmlParser,
        '.json': JsonParser
    }

    @classmethod
    def create(cls, file_path: Path):
        """
        Creates a parser object based on the given file type
        :param file_path: metric file to parse (a Path, or anything Path() accepts); the extension is matched
                          case-insensitively
        :return: an implementation of MetricDefinitionParser suitable for parsing the specified file
        :raises ValueError: if no parser is registered for the file's extension
        """
        file_path = Path(file_path)
        parser_class = cls.parser_for_file_type.get(file_path.suffix.lower())
        if parser_class is None:
            raise ValueError(f'No metric definition parser defined for files of type {file_path.suffix}')
        return parser_class(file_path)


class JsonConstantParser:
    """
    Parser for json retire latency constant definition files
    """

    # retire latency schema for validating -l/--retire-latency json
    # command line input see SDL: T1148/T1149
    # https://sdp-prod.intel.com/bunits/intel/thor-next-gen-edp/thor/tasks/phase/development/11130-T1148/
    # https://sdp-prod.intel.com/bunits/intel/thor-next-gen-edp/thor/tasks/phase/development/11130-T1149/
    # NOTE(review): JSON Schema "properties" matches the key "metric:retire_latency" literally, so entries stored
    # under any other top-level key are not checked by this schema - confirm whether "additionalProperties" or
    # "patternProperties" was intended.
    schema = {
        "type": "object",
        "properties": {
            "metric:retire_latency": {
                "type": "object",
                "properties": {
                    "MEAN": {
                        "type": "number",
                        "minimum": 0.0,
                        "maximum": 1000000.0
                    },
                    "other-stat": {
                        "type": "number",
                        "minimum": 0.0,
                        "maximum": 1000000.0
                    },
                },
                "required": ["MEAN"]
            }
        },
    }

    def __init__(self, file_path: Path, constant_descriptor: str):
        """
        :param file_path: file to parse metric constants from
        :param constant_descriptor: key to extract from each top-level entry of the file (for example "MEAN")
        """
        self.__constant_descriptor = constant_descriptor
        self.__json_file = file_path

    def parse(self) -> Dict[str, float]:
        """
        Parse constants from json file
        :return: dictionary of parsed constants
        """
        json_object_validator = JsonObjectValidator(self.__json_file, JsonConstantParser.schema)
        json_constants = json_object_validator.json_object
        return self._get_constants_for_descriptor(json_constants)

    def _get_constants_for_descriptor(self, json_constants) -> Dict[str, float]:
        """
        Extract constants that match the constant descriptor from a json dictionary
        :param json_constants: dictionary of constants parsed from json
        :return: dictionary mapping each top-level key to the value stored under the constant descriptor; entries
                 that are not objects, or that lack the descriptor, are skipped (so an empty dictionary is
                 returned if the descriptor isn't found anywhere)
        """
        descriptor = self.__constant_descriptor
        return {
            name: entry[descriptor]
            for name, entry in json_constants.items()
            # Non-object entries (numbers, strings, lists) cannot carry a descriptor key
            if isinstance(entry, dict) and descriptor in entry
        }

