# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import numpy as np
import pandas as pd

from nsys_recipe.lib import exceptions


def filter_by_column_value(df, column_name, values_to_keep):
    """Filter a dataframe in place, keeping only rows with selected values.

    Rows whose ``column_name`` value is not in ``values_to_keep`` are removed
    from ``df`` itself; the function returns ``None``. The surviving rows keep
    their original order and index labels.

    Parameters
    ----------
    df : dataframe
        The dataframe to filter. It is modified in place.
    column_name : str
        Name of the column whose values are tested.
    values_to_keep : list of str
        Values to keep; rows holding any other value in ``column_name`` are
        dropped.
    """
    if df.empty:
        return

    keep = df[column_name].isin(values_to_keep).to_numpy()
    if keep.all():
        # Nothing to discard.
        return

    # Dropping by index label would also remove rows that should be kept
    # whenever the index contains duplicate labels (e.g. after ``pd.concat``
    # without ``ignore_index``). Temporarily switch to a positional index so
    # every row is addressed uniquely, then restore the surviving labels.
    original_index = df.index
    df.index = pd.RangeIndex(len(df))
    df.drop(df.index[~keep], inplace=True)
    df.index = original_index[keep]


def filter_none(dfs):
    """Drop missing and empty dataframes from a collection.

    Parameters
    ----------
    dfs : iterable of dataframe or None
        Candidate dataframes. ``None`` entries and dataframes with zero rows
        are discarded.

    Returns
    -------
    list of dataframe
        The remaining dataframes, in their original order.

    Raises
    ------
    exceptions.NoDataError
        If nothing remains after filtering.
    """

    def _has_rows(candidate):
        return candidate is not None and len(candidate) > 0

    remaining = list(filter(_has_rows, dfs))
    if len(remaining) == 0:
        raise exceptions.NoDataError
    return remaining


def stddev(group_df, series_dict, n_col_name="Instances"):
    """Pool per-row standard deviations into one standard deviation.

    Each row of ``group_df`` summarizes a sub-population via its size
    (``n_col_name``), mean (``"Avg"``) and standard deviation (``"StdDev"``).
    The result combines the within-row spread,
    ``sum((n_i - 1) * StdDev_i**2)``, with the between-row spread,
    ``sum(n_i * (Avg_i - overall_avg)**2)``, and divides by
    ``total_n - 1``. The formula assumes ``overall_avg`` is the
    count-weighted mean of the rows.

    Parameters
    ----------
    group_df : dataframe
        Subset of data sharing a common grouping key; it must expose that key
        as ``group_df.name`` (presumably set by a groupby ``apply``) and have
        ``n_col_name``, ``"Avg"`` and ``"StdDev"`` columns.
    series_dict : dict
        Mapping of aggregator names to Series indexed by group key. Entries
        ``n_col_name`` (total population size) and ``"Avg"`` (overall mean)
        are read.
    n_col_name : str
        Name of the column/entry representing population size.

    Returns
    -------
    float
        The pooled standard deviation rounded to one decimal place. When the
        total population size is at most 1, the first row's ``"StdDev"`` is
        returned unchanged instead.
    """
    key = group_df.name
    total_count = series_dict[n_col_name].loc[key]
    # With at most one instance overall there is nothing to pool.
    if total_count <= 1:
        return group_df["StdDev"].iloc[0]

    counts = group_df[n_col_name]
    within_spread = np.dot(counts - 1, group_df["StdDev"] ** 2)
    grand_mean = series_dict["Avg"].loc[key]
    between_spread = np.dot(counts, (group_df["Avg"] - grand_mean) ** 2)
    pooled_variance = (within_spread + between_spread) / (total_count - 1)
    return (pooled_variance**0.5).round(1)
