从数组索引得到的矩阵的理解

2024-03-28 14:25:21 发布

您现在位置:Python中文网/ 问答频道 /正文

logistic regression code中列出的代码中,我看到了以下代码片段。让我不舒服的是: probs[range(num_examples),y]。 有人能告诉我这个矩阵有什么维数吗?我猜这是一个N*KN*K矩阵,但我不确定。谢谢

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)
N = 100 # number of points per class
D = 2 # dimensionality
K = 3 # number of classes
X = np.zeros((N*K,D))
y = np.zeros(N*K, dtype='uint8')
for j in xrange(K):
  ix = range(N*j,N*(j+1))
  r = np.linspace(0.0,1,N) # radius
  t = np.linspace(j*4,(j+1)*4,N) + np.random.randn(N)*0.2 # theta
  X[ix] = np.c_[r*np.sin(t), r*np.cos(t)]
  y[ix] = j

#Train a Linear Classifier

# initialize parameters randomly
W = 0.01 * np.random.randn(D,K)
b = np.zeros((1,K))

# some hyperparameters
step_size = 1e-0
reg = 1e-3 # regularization strength

# gradient descent loop
num_examples = X.shape[0]
for i in xrange(200):

  # evaluate class scores, [N x K]
  scores = np.dot(X, W) + b 

  # compute the class probabilities
  exp_scores = np.exp(scores)
  probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True) # [N x K]

  # compute the loss: average cross-entropy loss and regularization
  corect_logprobs = -np.log(probs[range(num_examples),y])
  data_loss = np.sum(corect_logprobs)/num_examples
  reg_loss = 0.5*reg*np.sum(W*W)
  loss = data_loss + reg_loss
  if i % 10 == 0:

Tags: 代码npzerosrange矩阵randomregexamples
1条回答
网友
1楼 · 发布于 2024-03-28 14:25:21

probs[range(num_examples), y]似乎是1D切片,其中:

  • range(num_examples)是跨越样本长度的向量
  • y是一维向量,长度N*K

相关问题 更多 >