读入时一个热编码tf数据

2024-04-18 12:31:53 发布

您现在位置:Python中文网/ 问答频道 /正文

Tensai-Flow平台上运行的Tensai-Flow。数据集很大,并不是所有的数据都可以同时保存在内存中,因此我使用以下代码将数据读入tf.dataset

def read_dataset(filepattern):
    def decode_csv(value_column):
        cols = tf.io.decode_csv(value_column, record_defaults=[[0.0],[0],[0.0])
        features=[cols[1],cols[2]]
        label = cols[0]
        return features, label
    # Create list of files that match pattern
    file_list = tf.io.gfile.glob(filepattern)
    # Create dataset from file list
    dataset = tf.data.TextLineDataset(file_list).map(decode_csv)
    return dataset

training_data=read_dataset(<filepattern>)

问题是数据中的第二列是分类的,我需要使用一种热编码。如何做到这一点,无论是在函数decode_csv中,还是在以后操作tf.dataset。在


Tags: csv数据readvaluetfdefcolumnflow
1条回答
网友
1楼 · 发布于 2024-04-18 12:31:53

您可以使用tf.one_hot。假设第二列是cols[1],并且类别值已转换为整数,则可以执行以下操作:

def decode_csv(value_column):
    cols = tf.io.decode_csv(value_column, record_defaults=[[0.0],[0],[0.0]])
    features=[cols[1], tf.one_hot(cols[2], nb_classes)]
    label = cols[0]
    return features, label

注意:未测试。在

相关问题 更多 >