# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from collections.abc import Iterable, Mapping, Sequence
from functools import reduce
from typing import Literal, Optional, TypeVar, Union

import numpy as np
import pandas as pd

from nsys_recipe.lib import exceptions
from nsys_recipe.log import logger


def filter_by_column_value(
    df: pd.DataFrame, column_name: str, values_to_keep: list[str]
) -> None:
    """Filter a dataframe in place, retaining only the rows we want to plot.

    Rows whose `column_name` value is not listed in `values_to_keep` are
    removed from `df` itself; nothing is returned. The index labels of the
    remaining rows are left untouched.

    Parameters
    ----------
    df : dataframe
        The dataframe to filter. It is modified in place.
    column_name : str
        Name of the column whose values are checked.
    values_to_keep : list of str
        A list of strings that correspond to values we want to keep in the dataframe.

    Raises
    ------
    KeyError
        If `df` is not empty and `column_name` is not one of its columns.
    """
    if df.empty:
        return

    # Positional boolean mask: True for the rows to retain.
    keep = df[column_name].isin(values_to_keep).to_numpy()
    if keep.all():
        # Nothing to discard.
        return

    if df.index.is_unique:
        # Labels identify rows unambiguously, so a label-based drop is safe.
        df.drop(df.index[~keep], inplace=True)
        return

    # `drop` removes rows by label. With duplicated index labels it would
    # also remove rows we want to keep whenever they share a label with a
    # discarded row. Temporarily switch to a unique positional index, drop
    # the unwanted rows, then restore the original labels of the survivors.
    surviving_index = df.index[keep]
    df.reset_index(drop=True, inplace=True)
    df.drop(df.index[~keep], inplace=True)
    df.index = surviving_index


# Generic element type used by the filter helpers below: they accept an
# iterable of Optional[_T] and return a list of _T.
_T = TypeVar("_T")


def filter_none(items: Iterable[Optional[_T]]) -> list[_T]:
    """Deprecated wrapper that strips Nones from `items`.

    This function is deprecated, please use `filter_none_or_empty` instead.
    Every call logs a deprecation warning and then delegates to
    `filter_none_or_empty(items, mode="all")`, returning its result.

    Parameters
    ----------
    items : Iterable
        Items to filter.

    Returns
    -------
    filtered_items : list
        Whatever `filter_none_or_empty` returns for `mode="all"`.
    """
    deprecation_notice = (
        "filter_none is deprecated, please update your recipe to use filter_none_or_empty instead."
    )
    # Unfortunately warnings.warn is not visible with `nsys recipe`,
    # so the notice goes through the recipe logger instead.
    logger.warning(deprecation_notice)

    remaining = filter_none_or_empty(items, mode="all")
    return remaining


def filter_none_or_empty(
    items: Iterable[Optional[_T]], mode: Literal["any", "all"] = "any"
) -> list[_T]:
    """Drop Nones and items made of empty dataframes from `items`.

    Parameters
    ----------
    items : Iterable
        Items to filter. The items themselves can be anything.
        A None item is always dropped.
        A DataFrame item is dropped when it is empty.
        A Sequence item is judged on the DataFrames it directly contains
        (other members are ignored), according to `mode`.
        Any item that holds no DataFrame at all is kept.
    mode : Literal["any", "all"]
        - any: Filter out items if ANY contained DataFrame is empty.
        - all: Filter out items if ALL contained DataFrames are empty.
        Any value other than "any" is handled like "all".

    Returns
    -------
    filtered_items : list
        The surviving items, in their original order.

    Raises
    ------
    exceptions.NoDataError
        If no items remain after filtering.
    """
    # Predicate applied to the "is this frame empty?" flags of one item;
    # a True result means the item is discarded.
    rejects = any if mode == "any" else all

    filtered_items = []
    for candidate in items:
        if candidate is None:
            continue

        if isinstance(candidate, pd.DataFrame):
            frames = [candidate]
        elif isinstance(candidate, Sequence):
            frames = [
                member for member in candidate if isinstance(member, pd.DataFrame)
            ]
        else:
            frames = []

        # Items without any DataFrame have nothing to check and are kept.
        if frames and rejects(frame.empty for frame in frames):
            continue

        filtered_items.append(candidate)

    if filtered_items:
        return filtered_items

    raise exceptions.NoDataError("No valid data remaining after filtering")


def stddev(
    group_df: pd.DataFrame,
    series_dict: Mapping[str, pd.Series],
    n_col_name: str = "Instances",
) -> float:
    """Combine per-row standard deviations into one for the whole group.

    Each row of `group_df` contributes its own count, average and standard
    deviation. The combined variance is computed as::

        (sum((n_i - 1) * std_i**2) + sum(n_i * (avg_i - avg)**2)) / (N - 1)

    where `N` and `avg` are the already aggregated count and average for
    the group, and the square root of it is returned rounded to one
    decimal. When `N` is 1 or less, the "StdDev" value of the first row is
    returned unchanged (not rounded).

    Parameters
    ----------
    group_df : dataframe
        Subset of data sharing a common grouping key. It contains values before
        the overall aggregation. Must provide the "StdDev" and "Avg" columns,
        the `n_col_name` column, and a `name` attribute holding the grouping
        key (presumably set by pandas when used with `groupby(...).apply`).
    series_dict : dict
        Dictionary mapping aggregators to their corresponding values. The
        `n_col_name` and "Avg" entries are looked up by the grouping key.
    n_col_name : str
        Name of the column representing population size.

    Returns
    -------
    float
        The standard deviation of the group.
    """
    total_count = series_dict[n_col_name].loc[group_df.name]
    if total_count <= 1:
        # Not enough samples to combine anything: reuse the existing value.
        return group_df["StdDev"].iloc[0]

    counts = group_df[n_col_name]

    # Spread inside each row's own population.
    within_sum = np.dot(counts - 1, group_df["StdDev"] ** 2)

    # Spread of each row's average around the overall group average.
    avg_offsets = group_df["Avg"] - series_dict["Avg"].loc[group_df.name]
    between_sum = np.dot(counts, avg_offsets**2)

    combined_variance = (within_sum + between_sum) / (total_count - 1)
    return (combined_variance**0.5).round(1)


# Join strategies accepted by `merge`; forwarded as the `how` argument of
# `DataFrame.merge`.
MergeHow = Literal["left", "right", "outer", "inner", "cross"]


def merge(
    dfs: Iterable[Optional[pd.DataFrame]],
    merge_by_columns: Union[list[str], tuple[str, ...]],
    how: MergeHow = "inner",
) -> Optional[pd.DataFrame]:
    """
    Merge multiple dataframes on specified columns.

    Every column that is not one of `merge_by_columns` is renamed to
    "<column>#<position>", where position is the dataframe's place in `dfs`
    (None entries are skipped but still count towards the position). This
    keeps columns of different dataframes from colliding. The input
    dataframes themselves are not modified.

    Parameters
    ----------
    dfs : Iterable
        Dataframes to merge, from left to right. None entries are ignored.
    merge_by_columns : list or tuple of str
        Columns used as join keys.
    how : MergeHow
        Join strategy passed to `DataFrame.merge`.
        NOTE(review): pandas rejects how="cross" when `on` is given, which
        is always the case here - confirm whether "cross" is meant to work.

    Returns
    -------
    merged : dataframe or None
        The merged dataframe, or None if `dfs` holds no dataframe.
    """

    def with_position_suffix(position: int, frame: pd.DataFrame) -> pd.DataFrame:
        # Tag every non-key column with the dataframe's position.
        new_names = {
            label: f"{label}#{position}"
            for label in frame.columns
            if label not in merge_by_columns
        }
        return frame.rename(columns=new_names)

    tagged_frames = [
        with_position_suffix(position, frame)
        for position, frame in enumerate(dfs)
        if frame is not None
    ]

    if not tagged_frames:
        return None

    # Fold the frames together from left to right.
    merged = tagged_frames[0]
    for following in tagged_frames[1:]:
        merged = merged.merge(following, on=merge_by_columns, how=how)
    return merged
