Keeping DataFrame modifications consistent across options in a Streamlit app

Asked 2025-04-11 22:08

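I have a Streamlit app called PredictaApp in which users upload a CSV file and then run various data analysis and preprocessing tasks on the dataset. One of these tasks is imputing missing values, which is handled by the DataImputer class in the missing_data module.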

predicta.py->

import streamlit as st
import pandas as pd
from DataExplore import explore
from FeatureCleaning import missing_data, outlier
from chat import ChatPredicta
from MLModel import predictmlalgo
from codeditor import PredictaCodeEditor
import theme


class PredictaApp:
    def __init__(self):
        self.df = None
        self.anthropi_api_key = None

    def show_hero_image(self):
        st.image("Hero.png")

    def show_footer(self):
        st.markdown("---")
        footer = "*copyright@infinitequants*"
        st.markdown(footer)

        footer_content = """
        <div class="footer">
            Follow us: &nbsp;&nbsp;&nbsp;
            <a href="https://github.com/ahammadnafiz" target="_blank">GitHub</a>  |
            <a href="https://twitter.com/ahammadnafi_z" target="_blank">Twitter</a> 
        </div>
        """
        st.markdown(footer_content, unsafe_allow_html=True)
    
    def file_upload(self):
        uploaded_file = st.file_uploader("Upload CSV", type=["csv"])
        if uploaded_file is not None:
            self.data = pd.read_csv(uploaded_file)
            self.df = self.data.copy(deep=True)
            
    def handle_sidebar(self):
        st.sidebar.title("Predicta")
        st.sidebar.markdown("---")

        self.file_upload()

        with st.sidebar:
            self.anthropi_api_key = st.text_input(
                "Anthropic API Key", key="file_qa_api_key", type="password"
            )
            "[Get an Anthropic API key](https://console.anthropic.com/)"

        st.sidebar.title("Tools")
        selected_option = st.sidebar.radio(
            "Select Option",
            [
                "Data Explore",
                "Impute Missing Values",
                "Detect Outlier",
                "Chat With Predicta",
                "PredictaCodeEditor",
                "Select ML Models",
            ],
        )
        if selected_option == "Data Explore":
            self.handle_data_explore()
        elif selected_option == "Impute Missing Values":
            self.handle_impute_missing_values()
        elif selected_option == "Detect Outlier":
            self.handle_detect_outlier()
        elif selected_option == "Chat With Predicta":
            self.handle_chat_with_predicta()
        elif selected_option == "PredictaCodeEditor":
            self.code_editor()
        elif selected_option == "Select ML Models":
            self.handle_select_ml_models()
            
        st.sidebar.markdown("---")
        self.handle_about()
        self.handle_help()

    def handle_about(self):
        st.sidebar.markdown("#### About")
        st.sidebar.info("Predicta is a powerful data analysis and machine learning tool designed to streamline your workflow and provide accurate predictions.")

    def handle_help(self):
        st.sidebar.markdown("#### Help")
        st.sidebar.info("For any assistance or inquiries, please contact us at support@predicta.com.")

    def handle_data_explore(self):
        if self.df is not None:
            analysis = explore.DataAnalyzer(self.df)
            analysis.analyzer()
        else:
            st.markdown(
                "<div style='text-align: center; margin-top: 20px; margin-bottom: 20px; font-size: 15px;'>Please upload a dataset to Explore.</div>",
                unsafe_allow_html=True,
            )
            st.image("uploadfile.png", use_column_width=True)

    def handle_impute_missing_values(self):
        if self.df is not None:
            impute = missing_data.DataImputer(self.df)
            impute.imputer()
        else:
            st.markdown(
                "<div style='text-align: center; margin-top: 20px; margin-bottom: 20px; font-size: 15px;'>Please upload a dataset to perform feature cleaning.</div>",
                unsafe_allow_html=True,
            )
            st.image("uploadfile.png", use_column_width=True)

    def handle_detect_outlier(self):
        if self.df is not None:
            out = outlier.OutlierDetector(self.df)
            out.outlier_detect()
        else:
            st.markdown(
                "<div style='text-align: center; margin-top: 20px; margin-bottom: 20px; font-size: 15px;'>Please upload a dataset to detect outlier.</div>",
                unsafe_allow_html=True,
            )
            st.image("uploadfile.png", use_column_width=True)

    def handle_chat_with_predicta(self):
        chat_page = ChatPredicta(self.df, self.anthropi_api_key)
        chat_page.chat_with_predicta()

    def code_editor(self):
        editor = PredictaCodeEditor()
        editor.run_code_editor(self.df)
        
    def handle_select_ml_models(self):
        if self.df is not None:
            model = predictmlalgo.PredictAlgo(self.df)
            model.algo()
        else:
            st.markdown(
                "<div style='text-align: center; margin-top: 20px; margin-bottom: 20px; font-size: 15px;'>Please upload a dataset to Perform Prediction.</div>",
                unsafe_allow_html=True,
            )
            st.image("uploadfile.png", use_column_width=True)

    def run(self):
        self.show_hero_image()
        self.handle_sidebar()
        self.show_footer()


if __name__ == "__main__":
    st.set_page_config(
        page_title="Predicta",
        page_icon="",
        initial_sidebar_state="expanded"
    )
    theme.footer()
    
    app = PredictaApp()
    app.run()

missing_data.py->

import pandas as pd
import numpy as np
import logging
import streamlit as st
from predicta import PredictaApp

class DataImputer(PredictaApp):
    def __init__(self, data):
        super().__init__()
        if not isinstance(data, pd.DataFrame):
            raise ValueError("Input data must be a pandas DataFrame.")
        self.data = data
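        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        # A new DataImputer is created on every Streamlit rerun; attach the
        # handler only once so log lines are not duplicated
        if not self.logger.handlers:
            self.logger.addHandler(logging.StreamHandler())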

    def check_missing(self, output_path=None):
        try:
            result = pd.concat([self.data.isnull().sum(), self.data.isnull().mean()], axis=1)
            result = result.rename(index=str, columns={0: 'total missing', 1: 'proportion'})
            
            if output_path is not None:
                result.to_csv(output_path + 'missing.csv')
                self.logger.info('Result saved at %smissing.csv', output_path)
            return result
        except Exception as e:
            self.logger.error("An error occurred while checking missing values: %s", str(e))
            raise

    def drop_missing(self, axis=0):
        try:
            original_shape = self.data.shape
            self.data = self.data.dropna(axis=axis)
            if self.data.shape == original_shape:
                return None  
            else:
                return self.data
        except Exception as e:
            self.logger.error("An error occurred while dropping missing values: %s", str(e))
            raise

    def add_var_denote_NA(self, NA_col=[]):
        try:
            for i in NA_col:
                if self.data[i].isnull().sum() > 0:
                    # Add a separate indicator column instead of overwriting the original values
                    self.data[f"{i}_is_NA"] = np.where(self.data[i].isnull(), 1, 0)
                else:
                    self.logger.warning("Column %s has no missing cases", i)
            return self.data
        except Exception as e:
            self.logger.error("An error occurred while adding variable to denote NA: %s", str(e))
            raise

    def impute_NA_with_arbitrary(self, impute_value, NA_col=[]):
        try:
            for i in NA_col:
                if self.data[i].isnull().sum() > 0:
                    # Assign back instead of fillna(inplace=True) on a column (chained assignment)
                    self.data[i] = self.data[i].fillna(impute_value)
                else:
                    self.logger.warning("Column %s has no missing cases", i)
            return self.data
        except Exception as e:
            self.logger.error("An error occurred while imputing NA with arbitrary value: %s", str(e))
            raise

    def impute_NA_with_avg(self, strategy='mean', NA_col=[]):
        try:
            for i in NA_col:
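                if self.data[i].isnull().sum() > 0:
                    if strategy == 'mean':
                        fill_value = self.data[i].mean()
                    elif strategy == 'median':
                        fill_value = self.data[i].median()
                    elif strategy == 'mode':
                        fill_value = self.data[i].mode()[0]
                    else:
                        raise ValueError(f"Unknown strategy: {strategy}")
                    # Assign the filled column back rather than calling fillna(inplace=True)
                    # on self.data[i], which pandas treats as chained assignment
                    self.data[i] = self.data[i].fillna(fill_value)
                else:
                    self.logger.warning("Column %s has no missing", i)
            # Return after the loop so every selected column is processed
            return self.data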
        except Exception as e:
            self.logger.error("An error occurred while imputing NA with average: %s", str(e))
            raise

    def impute_NA_with_end_of_distribution(self, NA_col=[]):
        try:
            for i in NA_col:
                if self.data[i].isnull().sum() > 0:
                    self.data[i] = self.data[i].fillna(self.data[i].mean() + 3 * self.data[i].std())
                else:
                    self.logger.warning("Column %s has no missing", i)
            return self.data
        except Exception as e:
            self.logger.error("An error occurred while imputing NA with end of distribution: %s", str(e))
            raise

    def impute_NA_with_random(self, NA_col=[], random_state=0):
        try:
            for i in NA_col:
                if self.data[i].isnull().sum() > 0:
                    random_sample = self.data[i].dropna().sample(self.data[i].isnull().sum(), random_state=random_state)
                    random_sample.index = self.data[self.data[i].isnull()].index
                    self.data.loc[self.data[i].isnull(), i] = random_sample
                else:
                    self.logger.warning("Column %s has no missing", i)
            # Return after the loop so every selected column is processed
            return self.data
        except Exception as e:
            self.logger.error("An error occurred while imputing NA with random sampling: %s", str(e))
            raise

    def impute_NA_with_interpolation(self, method='linear', limit=None, limit_direction='forward', NA_col=[]):
        try:
            for i in NA_col:
                if self.data[i].isnull().sum() > 0:
                    self.data[i] = self.data[i].interpolate(method=method, limit=limit, limit_direction=limit_direction)
                else:
                    self.logger.warning("Column %s has no missing cases", i)
            return self.data
        except Exception as e:
            self.logger.error("An error occurred while imputing NA with interpolation: %s", str(e))
            raise

    def impute_NA_with_knn(self, NA_col=[], n_neighbors=5):
        try:
            from sklearn.impute import KNNImputer
            knn_imputer = KNNImputer(n_neighbors=n_neighbors)
            for i in NA_col:
                if self.data[i].isnull().sum() > 0:
                    imputed_values = knn_imputer.fit_transform(self.data[i].values.reshape(-1, 1))
                    self.data[i] = imputed_values.ravel()
                else:
                    self.logger.warning("Column %s has no missing cases", i)
            return self.data
        except Exception as e:
            self.logger.error("An error occurred while imputing NA with KNN: %s", str(e))
            raise

    def impute_NA_with_mice(self, NA_col=[], n_iterations=10):
        try:
            from impyute.imputation.cs import mice
            for i in NA_col:
                if self.data[i].isnull().sum() > 0:
                    imputed_data = mice(self.data[i].values.reshape(1, -1), n_iterations=n_iterations)
                    self.data[i] = imputed_data.T.ravel()
                else:
                    self.logger.warning("Column %s has no missing cases", i)
            return self.data
        except Exception as e:
            self.logger.error("An error occurred while imputing NA with MICE: %s", str(e))
            raise

    def impute_NA_with_missforest(self, NA_col=[], n_estimators=100, max_depth=None):
        try:
            from missingpy import MissForest
            imputer = MissForest(n_estimators=n_estimators, max_depth=max_depth)
            for i in NA_col:
                if self.data[i].isnull().sum() > 0:
                    imputed_data = imputer.fit_transform(self.data[i].values.reshape(-1, 1))
                    self.data[i] = imputed_data.ravel()
                else:
                    self.logger.warning("Column %s has no missing cases", i)
            return self.data
        except Exception as e:
            self.logger.error("An error occurred while imputing NA with MissForest: %s", str(e))
            raise

    def imputer(self):
        
        st.markdown(
            "<h1 style='text-align: center; font-size: 30px;'>Impute Missing Values</h1>",
            unsafe_allow_html=True,
        )
        st.markdown("---")
        st.markdown("<h2 style='text-align: center; font-size: 20px;'>Original Dataset</h2>", unsafe_allow_html=True)
        st.dataframe(self.data, width=800)

        
        option = st.selectbox("Select an Imputation Method", [
            "Check Missing Values",
            "Drop Missing Values",
            "Add Variable to Denote NA",
            "Impute NA with Arbitrary Value",
            "Impute NA with Interpolation",
            "Impute NA with KNN",
            "Impute NA with MICE",
            "Impute NA with MissForest",
            "Impute NA with Average",
            "Impute NA with End of Distribution",
            "Impute NA with Random Sampling"
        ])

        if option == "Check Missing Values":
            if st.button("Check"):
                st.write(self.check_missing())

        elif option == "Drop Missing Values":
            axis = st.radio("Drop rows or columns?", ["Rows", "Columns"])
            axis = 0 if axis == "Rows" else 1
            if st.button("Drop"):
                # drop_missing returns None when nothing was dropped
                dropped = self.drop_missing(axis=axis)
                if dropped is not None:
                    st.dataframe(dropped)
                else:
                    st.warning("No missing values found in the data.")

        elif option == "Add Variable to Denote NA":
            selected_columns = st.multiselect("Select columns to impute", options=self.data.columns)
            if st.button("Add"):
                if selected_columns:
                    data_add_var = self.add_var_denote_NA(NA_col=selected_columns)
                    st.write(data_add_var)
                else:
                    st.warning("Please select at least one column to impute")

        elif option == "Impute NA with Arbitrary Value":
            impute_value = st.text_input("Enter Arbitrary Value")
            na_cols = st.multiselect("Select Columns", self.data.columns)
            if st.button("Impute Arbitrary Value"):
                data_impute_arb = self.impute_NA_with_arbitrary(impute_value=float(impute_value), NA_col=na_cols)
                st.write(data_impute_arb)

        elif option == "Impute NA with Interpolation":
            na_cols = st.multiselect("Select Columns", self.data.columns)
            interp_method = st.selectbox("Interpolation Method", ['linear', 'quadratic', 'cubic'])
            interp_limit = st.text_input("Limit", None)
            interp_limit_direction = st.selectbox("Limit Direction", ['forward', 'backward', 'both'])
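            if st.button("Impute Interpolation"):
                # text_input returns a string (or None); interpolate() expects an int or None
                limit = int(interp_limit) if interp_limit else None
                data_interp = self.impute_NA_with_interpolation(method=interp_method, limit=limit, limit_direction=interp_limit_direction, NA_col=na_cols)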
                st.write(data_interp)

        elif option == "Impute NA with KNN":
            n_neighbors = st.number_input("Number of Neighbors", min_value=1, value=5)
            selected_columns = st.multiselect("Select columns to impute", options=self.data.columns)
            if st.button("Impute KNN"):
                if selected_columns:  # Check if at least one column is selected
                    data_knn = self.impute_NA_with_knn(NA_col=selected_columns, n_neighbors=n_neighbors)
                    st.write(data_knn)
                else:
                    st.warning("Please select at least one column to impute")

        elif option == "Impute NA with MICE":
            na_cols = st.multiselect("Select Columns", self.data.columns)
            n_iterations = st.number_input("Number of Iterations", min_value=1, value=10)
            if st.button("Impute MICE"):
                data_mice = self.impute_NA_with_mice(NA_col=na_cols, n_iterations=n_iterations)
                st.write(data_mice)

        elif option == "Impute NA with MissForest":
            na_cols = st.multiselect("Select Columns", self.data.columns)
            n_estimators = st.number_input("Number of Estimators", min_value=1, value=100)
            max_depth = st.text_input("Max Depth", None)
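            if st.button("Impute MissForest"):
                # text_input returns a string (or None); max_depth must be an int or None
                depth = int(max_depth) if max_depth else None
                data_missforest = self.impute_NA_with_missforest(NA_col=na_cols, n_estimators=n_estimators, max_depth=depth)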
                st.write(data_missforest)

        elif option == "Impute NA with Average":
            na_cols = st.multiselect("Select Columns", self.data.columns)
            strategy = st.selectbox("Imputation Strategy", ['mean', 'median', 'mode'])
            if st.button("Impute Average"):
                data_avg = self.impute_NA_with_avg(strategy=strategy, NA_col=na_cols)
                st.write(data_avg)

        elif option == "Impute NA with End of Distribution":
            na_cols = st.multiselect("Select Columns", self.data.columns)
            if st.button("Impute End of Distribution"):
                data_end_dist = self.impute_NA_with_end_of_distribution(NA_col=na_cols)
                st.write(data_end_dist)

        elif option == "Impute NA with Random Sampling":
            na_cols = st.multiselect("Select Columns", self.data.columns)
            random_state = st.number_input("Random State", min_value=0, value=0)
            if st.button("Impute Random"):
                data_random = self.impute_NA_with_random(NA_col=na_cols, random_state=random_state)
                st.write(data_random)
        
        return self.data

I want the changes made by missing_data.py to be applied to the original DataFrame (df) as well.
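Within a single run, whether DataImputer's changes reach the caller's df depends on how `self.data` is modified. `DataImputer.__init__` stores the DataFrame it receives, not a copy, so assigning a column (`self.data[col] = ...`) mutates the same object that `PredictaApp.df` points to, while a method that rebinds `self.data` to a new DataFrame (such as `dropna` in `drop_missing`) leaves the caller's df unchanged. The sketch below illustrates this with a hypothetical, stripped-down `Imputer` class standing in for DataImputer; it assumes only pandas and NumPy.

```python
import numpy as np
import pandas as pd


class Imputer:
    """Hypothetical, stripped-down stand-in for DataImputer."""

    def __init__(self, data):
        self.data = data  # same object as the caller's DataFrame, not a copy

    def fill_column(self, col):
        # Column assignment mutates the shared DataFrame in place
        self.data[col] = self.data[col].fillna(0)

    def drop_rows(self):
        # dropna() returns a NEW DataFrame; rebinding self.data leaves the
        # caller's DataFrame untouched
        self.data = self.data.dropna()


df = pd.DataFrame({"a": [1.0, np.nan], "b": [np.nan, 2.0]})
imp = Imputer(df)

imp.fill_column("a")
print(df["a"].isna().sum())    # 0: the caller's df sees the change

imp.drop_rows()
print(len(df), len(imp.data))  # 2 1: the caller's df does not
```

So returning the result and reassigning it is needed for rebinding methods like `drop_missing`, but that alone still does not survive a Streamlit rerun.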

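The problem is that when I switch between options in the Streamlit app (e.g., Data Explore, Impute Missing Values, Detect Outlier), the changes made to the DataFrame under the previous option are not kept. Specifically, after I impute missing values with the DataImputer class and then switch to another option, the original, unmodified DataFrame is loaded again, overwriting the changes made during imputation.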
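A likely cause is Streamlit's execution model: every widget interaction, including changing the sidebar radio, re-executes the script from top to bottom. In this app that means `app = PredictaApp()` builds a fresh object and `file_upload` calls `pd.read_csv` again, so whatever the previous run stored in `self.df` is replaced. The sketch below simulates that in plain Python; `App` and `rerun` are hypothetical names, and the simulation is a simplification, not Streamlit itself.

```python
import numpy as np
import pandas as pd

# The uploaded CSV, as pd.read_csv would return it on every run
RAW = pd.DataFrame({"age": [25.0, np.nan, 40.0]})


class App:
    """Hypothetical, simplified stand-in for PredictaApp."""

    def __init__(self):
        self.df = None

    def file_upload(self):
        # Like PredictaApp.file_upload: a fresh copy of the raw data each run
        self.df = RAW.copy(deep=True)


def rerun(option):
    # Simulates one top-to-bottom execution of predicta.py
    app = App()
    app.file_upload()
    if option == "Impute Missing Values":
        app.df = app.df.fillna(app.df["age"].mean())
    return app.df


after_impute = rerun("Impute Missing Values")
after_switch = rerun("Detect Outlier")  # user switches the sidebar radio
print(after_impute["age"].isna().sum())  # 0
print(after_switch["age"].isna().sum())  # 1: the imputation is gone
```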

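I want changes made to the DataFrame in one option (e.g., Impute Missing Values) to carry over to the other options (e.g., Detect Outlier, Select ML Models), so that subsequent operations run on the modified DataFrame rather than on the original, unmodified one.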
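A common Streamlit pattern for this is to keep the working DataFrame in `st.session_state`, which persists across reruns within the same browser session: load it once, have every option read it from there, and write each modification back. The sketch below models that flow with a plain dict named `session_state` standing in for `st.session_state` (an assumption so the example runs without Streamlit; the real object supports the same `in` checks and item access). `rerun` is a hypothetical helper.

```python
import numpy as np
import pandas as pd

RAW = pd.DataFrame({"age": [25.0, np.nan, 40.0]})

# Plain dict standing in for st.session_state; it outlives each "rerun"
session_state = {}


def rerun(option, state):
    # Load the data only once per session, not on every rerun
    if "df" not in state:
        state["df"] = RAW.copy(deep=True)

    if option == "Impute Missing Values":
        df = state["df"]
        # Write the modified DataFrame back so other options see it
        state["df"] = df.fillna({"age": df["age"].mean()})

    # Every option (e.g. "Detect Outlier") reads the shared, possibly modified, data
    return state["df"]


rerun("Impute Missing Values", session_state)
outlier_input = rerun("Detect Outlier", session_state)
print(outlier_input["age"].tolist())  # [25.0, 32.5, 40.0]
```

In the app itself this would mean, for example, the impute handler storing the result of `impute.imputer()` in `st.session_state["df"]`, and the other handlers passing `st.session_state["df"]` to their classes instead of `self.df`.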

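In short, the question is: "How do I persist modifications to a DataFrame in a Streamlit app, particularly after imputing missing values with the DataImputer class?"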

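I tried changing the imputer method of the DataImputer class so that it returns the modified DataFrame, and then assigning that DataFrame to self.df in the PredictaApp class. I expected this to make the changes made in DataImputer show up in PredictaApp's self.df, so that later operations under other options would use the modified data.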
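Assigning the result to `self.df` only lasts for the rest of the current run, because the `PredictaApp` instance itself is recreated by `app = PredictaApp()` on every rerun. If you want to keep the class-based design, one option is to create the app object once and keep it in session state. The sketch below shows the idea with a dict standing in for `st.session_state` and a hypothetical `get_app` helper; `file_upload` would still need to stop re-reading the CSV on every run for `self.df` to survive.

```python
class App:
    """Hypothetical, simplified stand-in for PredictaApp."""

    def __init__(self):
        self.df = None


def get_app(state):
    # Create the app once per session instead of once per rerun,
    # so attributes such as self.df survive between reruns
    if "app" not in state:
        state["app"] = App()
    return state["app"]


session_state = {}  # stand-in for st.session_state

first_run = get_app(session_state)
first_run.df = "imputed data"  # e.g. the value returned by imputer()
second_run = get_app(session_state)
print(second_run is first_run, second_run.df)  # True imputed data
```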

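However, even after these changes, when I switch to a new option in Streamlit, the original unmodified DataFrame is reloaded, overwriting the changes made under the previous option.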
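One thing that would still reload the original data is `file_upload`: `st.file_uploader` keeps returning the same uploaded file on every rerun, so `pd.read_csv` runs again and overwrites `self.df`. A guard that reads the CSV only when a different file is uploaded avoids this. The sketch below keys the cache on the file's name and size (an assumption; Streamlit's `UploadedFile` exposes `name` and `size`, but any stable identifier would do) and uses a hypothetical `FakeUpload` class and `load_once` helper so it runs without Streamlit.

```python
import io
from dataclasses import dataclass

import pandas as pd


@dataclass
class FakeUpload:
    """Hypothetical stand-in for Streamlit's UploadedFile."""
    name: str
    data: bytes

    @property
    def size(self):
        return len(self.data)


def load_once(uploaded_file, state):
    """Read the CSV only when a different file is uploaded."""
    if uploaded_file is None:
        return state.get("df")
    file_key = (uploaded_file.name, uploaded_file.size)
    if state.get("file_key") != file_key:
        # New file: read it and reset the working DataFrame
        state["file_key"] = file_key
        state["df"] = pd.read_csv(io.BytesIO(uploaded_file.data))
    # Same file as before: keep whatever modifications were stored
    return state["df"]


state = {}  # stand-in for st.session_state
upload = FakeUpload("data.csv", b"x,y\n1,\n2,3\n")
df = load_once(upload, state)
state["df"] = df.fillna(0)        # e.g. the result of an imputation step
again = load_once(upload, state)  # simulated rerun with the same file
print(again["y"].isna().sum())    # 0
```

With a real upload, `pd.read_csv(uploaded_file)` can be called directly instead of wrapping bytes in `io.BytesIO`.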
