# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: MIT

import logging
import types
from typing import List, Dict

import numpy as np
import pandas as pd

from mpp.core.metric_computer import _MetricCompiler, CompiledMetric
from mpp.core.types import ThresholdDefinition


class CompiledThresholdMetric(CompiledMetric):
    """Compiled threshold expression whose evaluation errors degrade to NaN."""

    def _handle_exception(self, e):
        """Turn an evaluation error into a NaN result, warning at most once.

        While ``self._log_error`` is truthy, a warning naming the threshold and
        the error is emitted and the flag is cleared, so subsequent failures of
        this same instance are silent.

        :param e: the exception raised while evaluating the threshold.
        :return: ``np.nan`` in every case.
        """
        if not self._log_error:
            # A warning was already emitted for this threshold; stay quiet.
            return np.nan
        threshold_name = self.definition.name
        logging.warning(f'Cannot calculate threshold for \'{threshold_name}\'. {e}')
        self._log_error = False
        return np.nan


class _ThresholdMetricCompiler(_MetricCompiler):
    """Metric compiler specialised for ``ThresholdDefinition`` objects."""

    # Attribute of each definition used as its key when matching definitions
    # against the generated metric mapping (see ThresholdMetricComputer).
    name_attr = "name"

    def _create_metric_namespace(self, metric_definition: ThresholdDefinition) -> dict:
        """Return the name namespace for one threshold definition.

        The namespace always holds ``'metric_name'`` (the threshold's name) and
        every entry of ``metric_definition.metric_aliases``; an alias named
        ``'metric_name'`` would take precedence over the threshold name.
        How the base compiler consumes this namespace is defined in
        ``_MetricCompiler`` (not visible here).
        """
        return {'metric_name': metric_definition.name, **metric_definition.metric_aliases}




class ThresholdMetricComputer:
    """Compile threshold definitions into Python code and evaluate them on DataFrames.

    Each instance asks ``metric_compiler`` to generate source code for its
    threshold definitions, executes that code in a fresh in-memory module, and
    keeps one ``CompiledThresholdMetric`` per definition present in the
    generated module's metric mapping.
    """

    # Class-wide counter giving each instance's generated module a distinct name.
    __module_index = 0

    # Compiler shared by all instances; turns threshold definitions into source code.
    metric_compiler = _ThresholdMetricCompiler()

    def __init__(self, threshold_definition_list):
        """
        :param threshold_definition_list: threshold definitions to compile.
        """
        self._metric_definitions = threshold_definition_list
        self._compiled_metrics = self._generate_compiled_metrics(self._metric_definitions)
        ThresholdMetricComputer.__module_index += 1

    def compute_metric(self, df: pd.DataFrame,
                       metrics_to_compute: List[CompiledMetric] = None) -> pd.DataFrame:
        """Evaluate compiled thresholds against ``df``.

        :param df: input frame whose columns are referenced by the thresholds.
        :param metrics_to_compute: compiled metrics to evaluate; when None or
            empty, all metrics compiled by this instance are evaluated.
        :return: a DataFrame sharing ``df``'s index, with one column per
            threshold whose referenced columns are all present in ``df``.
            Thresholds with a missing reference are skipped.
        """
        # Anchor the result to the input rows. Starting from an index-less frame
        # would give a zero-row column when a metric yields a scalar (e.g. the
        # NaN fallback of CompiledThresholdMetric), and would take its index from
        # whichever result happened to be assigned first.
        result_df = pd.DataFrame(index=df.index)
        if not metrics_to_compute:
            metrics_to_compute = self._compiled_metrics
        for compiled_metric in metrics_to_compute:
            if self.__all_metric_references_are_available(compiled_metric.definition, df):
                result_df[compiled_metric.definition.name] = compiled_metric(df)
        return result_df

    def _generate_compiled_metrics(self, metric_definition_list):
        """Compile ``metric_definition_list`` and wrap the generated functions.

        :param metric_definition_list: threshold definitions to compile.
        :return: list of ``CompiledThresholdMetric`` for the definitions found in
            the generated module's ``get_metrics_mapping()``.
        """
        vectorized_compute_metric_code = self.metric_compiler.compile(metric_definition_list)
        generated_module_name = f'generated_code_{self.__module_index}'
        generated_module = self._import_code(vectorized_compute_metric_code, generated_module_name)
        metric_mapping = generated_module.get_metrics_mapping()
        # Match against the same list that was compiled, not whatever is stored
        # on the instance, so the mapping and the definitions always agree.
        return self._get_compiled_metrics(metric_mapping, metric_definition_list)

    def _get_compiled_metrics(self, metric_mapping, metric_definition_list=None):
        """Wrap generated functions into ``CompiledThresholdMetric`` objects.

        :param metric_mapping: mapping from definition name to a
            ``(func, source)`` pair, as returned by the generated module.
        :param metric_definition_list: definitions to look up; defaults to the
            definitions this instance was constructed with.
        :return: compiled metrics in definition order; definitions absent from
            ``metric_mapping`` are skipped.
        """
        if metric_definition_list is None:
            metric_definition_list = self._metric_definitions
        name_attr = self.metric_compiler.name_attr
        compiled_metrics = []
        for metric in metric_definition_list:
            metric_name = getattr(metric, name_attr)
            if metric_name in metric_mapping:
                func, source = metric_mapping[metric_name]
                compiled_metrics.append(CompiledThresholdMetric(metric, func, source))
        return compiled_metrics

    @staticmethod
    def _import_code(code, module_name):
        """Execute ``code`` in a new, unregistered module named ``module_name``.

        NOTE(review): this ``exec``s the compiler's output. It is only safe as
        long as the threshold definitions fed to the compiler are trusted.
        """
        module = types.ModuleType(module_name)
        exec(code, module.__dict__)
        return module

    @staticmethod
    def __all_metric_references_are_available(metric_def: ThresholdDefinition,
                                              df: pd.DataFrame) -> bool:
        """Return True if every alias target of ``metric_def`` is a column of ``df``.

        Logs (at debug level) the first missing column and returns False.
        """
        for symbol_name in metric_def.metric_aliases.values():
            if symbol_name not in df.columns:
                logging.debug("Excluding threshold for '%s' from reports because %s is unavailable.",
                              metric_def.name, symbol_name)
                return False
        return True
