我一直在尝试用tf.feature\u列对我的一些功能进行多重热编码。根据我的理解,给定一个特性,所有输入都必须填充到相同的长度,并让填充的值被排除在词汇表之外(参见here和here)。下面是一个代码示例:
import numpy as np
import itertools
import tensorflow as tf
from tensorflow import feature_column
from tensorflow.keras import layers
def pad_length(feat):
return np.array(list(itertools.zip_longest(*feat, fillvalue="UNK"))).T
COUNTRY_VOC = ["UK", "FR", "DE", "CH", "US", "ES"]
THEME_VOC = ["A", "B", "C", "D", "E", "F", "G"]
# input
country_input = [["FR", "UK"], ["FR", "DE"], ["CH"]]
theme_input = [["A"], ["B", "C"], ["C"]]
# build dataset
dict_input = {"country": [pad_length(country_input)], "theme": [pad_length(theme_input)]}
dataset = tf.data.Dataset.from_tensor_slices((dict_input))
# build feature column
country_feat = feature_column.categorical_column_with_vocabulary_list(
"country", vocabulary_list=COUNTRY_VOC
)
country_one_hot = feature_column.indicator_column(country_feat)
theme_feat = feature_column.categorical_column_with_vocabulary_list(
"theme", vocabulary_list=THEME_VOC
)
theme_one_hot = feature_column.indicator_column(theme_feat)
feature = [country_one_hot, theme_one_hot]
# output example
example_batch = next(iter(dataset))
feature_layer = layers.DenseFeatures(feature)
print(feature_layer(example_batch))
具有以下输出:
tf.Tensor(
[[1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[0. 1. 1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0.]
[0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0.]], shape=(3, 13), dtype=float32)
我对这种方法的担心是,我已经期望在我的推理管道中有词汇表外的输入,并希望避免这种额外的噪音。 因此,我尝试使用生成器加载可变长度输入的数据集,如here。但我无法获得预期的结果:
def dict_gen():
for i in range(num_samples):
ls = {}
for key, val in dict_input.items():
ls[key] = val[i]
yield ls
# input
dict_input = {
"country": [["FR", "UK"], ["FR", "DE"], ["CH"]],
"theme": [["A"], ["B", "C"], ["C"]],
}
# build dataset
num_samples = 3
dataset = tf.data.Dataset.from_generator(
dict_gen,
output_types={k: tf.string for k in dict_input},
output_shapes={k: tf.TensorShape([None]) for k in dict_input},
)
具有以下输出:
ValueError: Batch size (first dimension) of each feature must be same. Batch size of columns (country_indicator, theme_indicator): (2, 1)
有没有一种简单的方法可以在没有填充的情况下对多长度输入进行编码?它似乎是唯一一个正在进行字节编码和tfrecord(see here)的程序。 非常感谢您的指导
目前没有回答
相关问题 更多 >
编程相关推荐