擅长:python、mysql、java
<p>本的回答很快就解决了你的问题。我只想添加混淆矩阵:</p>
<pre><code>confusion_matrix = (df.groupby('label')['pred']
.value_counts(normalize=True)
.unstack(fill_value=0)
)
</code></pre>
<p>输出:</p>
<pre><code>pred cat dog elephant snake
label
cat 1.000000 0.000000 0.0 0.000000
dog 0.333333 0.333333 0.0 0.333333
elephant 0.000000 0.000000 1.0 0.000000
snake 0.500000 0.000000 0.0 0.500000
</code></pre>