Inconsistent OCR results: different predictions during training and testing

I'm running into a problem with a custom OCR (optical character recognition) model: its predictions differ between the training and testing stages.

During training the model performs almost unbelievably well and reaches very high accuracy on the training dataset. However, when I use it at test time to recognize text from images, the results are inconsistent. The confidence scores for each class also fluctuate a lot, which leads to unexpected predictions.

Training started...
Epoch [1/50], Step [100/133], Loss: 1.4840
Epoch [2/50], Step [100/133], Loss: 0.1713
Epoch [3/50], Step [100/133], Loss: 0.1087
Epoch [4/50], Step [100/133], Loss: 0.0793
Epoch [5/50], Step [100/133], Loss: 0.0793
Epoch [6/50], Step [100/133], Loss: 0.0552
Epoch [7/50], Step [100/133], Loss: 0.0501
Epoch [8/50], Step [100/133], Loss: 0.0484
Epoch [9/50], Step [100/133], Loss: 0.0595
Epoch [10/50], Step [100/133], Loss: 0.0437
Epoch [11/50], Step [100/133], Loss: 0.0351
Epoch [12/50], Step [100/133], Loss: 0.0914
Epoch [13/50], Step [100/133], Loss: 0.0304
Epoch [14/50], Step [100/133], Loss: 0.0406
Epoch [15/50], Step [100/133], Loss: 0.0315
Epoch [16/50], Step [100/133], Loss: 0.0331
Epoch [17/50], Step [100/133], Loss: 0.0220
Epoch [18/50], Step [100/133], Loss: 0.0238
Epoch [19/50], Step [100/133], Loss: 0.0272
Epoch [20/50], Step [100/133], Loss: 0.0259
Epoch [21/50], Step [100/133], Loss: 0.0210
Epoch [22/50], Step [100/133], Loss: 0.0826
Epoch [23/50], Step [100/133], Loss: 0.0673
Epoch [24/50], Step [100/133], Loss: 0.0240
Epoch [25/50], Step [100/133], Loss: 0.0198
Epoch [26/50], Step [100/133], Loss: 0.0250
Epoch [27/50], Step [100/133], Loss: 0.0203
Epoch [28/50], Step [100/133], Loss: 0.0170
Epoch [29/50], Step [100/133], Loss: 0.0204
Epoch [30/50], Step [100/133], Loss: 0.0177
Epoch [31/50], Step [100/133], Loss: 0.0208
Epoch [32/50], Step [100/133], Loss: 0.0231
Epoch [33/50], Step [100/133], Loss: 0.0156
Epoch [34/50], Step [100/133], Loss: 0.0117
Epoch [35/50], Step [100/133], Loss: 0.0171
Epoch [36/50], Step [100/133], Loss: 0.0138
Epoch [37/50], Step [100/133], Loss: 0.0196
Epoch [38/50], Step [100/133], Loss: 0.0158
Epoch [39/50], Step [100/133], Loss: 0.0183
Epoch [40/50], Step [100/133], Loss: 0.0163
Epoch [41/50], Step [100/133], Loss: 0.0305
Epoch [42/50], Step [100/133], Loss: 0.0504
Epoch [43/50], Step [100/133], Loss: 0.0404
Epoch [44/50], Step [100/133], Loss: 0.0176
Epoch [45/50], Step [100/133], Loss: 0.0140
Epoch [46/50], Step [100/133], Loss: 0.0099
Epoch [47/50], Step [100/133], Loss: 0.0123
Epoch [48/50], Step [100/133], Loss: 0.0121
Epoch [49/50], Step [100/133], Loss: 0.0118
Epoch [50/50], Step [100/133], Loss: 0.0140
Training finished.
Testing started...
Predicted: A, Actual: A, Confidence: 100.00%
Predicted: B, Actual: B, Confidence: 98.97%
Predicted: C, Actual: C, Confidence: 100.00%
Predicted: D, Actual: D, Confidence: 99.46%
Predicted: E, Actual: E, Confidence: 99.63%
Predicted: F, Actual: F, Confidence: 99.32%
Predicted: G, Actual: G, Confidence: 99.92%
Predicted: H, Actual: H, Confidence: 99.99%
Predicted: I, Actual: I, Confidence: 90.64%
Predicted: J, Actual: J, Confidence: 97.04%
Predicted: K, Actual: K, Confidence: 100.00%
Predicted: L, Actual: L, Confidence: 99.08%
Predicted: M, Actual: M, Confidence: 100.00%
Predicted: O, Actual: A, Confidence: 67.16%
Predicted: P, Actual: P, Confidence: 96.75%
Predicted: Q, Actual: Q, Confidence: 99.71%
Predicted: R, Actual: R, Confidence: 99.98%
Predicted: S, Actual: S, Confidence: 99.25%
Predicted: T, Actual: T, Confidence: 91.48%
Predicted: U, Actual: U, Confidence: 100.00%
Predicted: V, Actual: V, Confidence: 100.00%
Predicted: W, Actual: W, Confidence: 100.00%
Predicted: X, Actual: X, Confidence: 100.00%
Predicted: Y, Actual: Y, Confidence: 100.00%
Predicted: Z, Actual: Z, Confidence: 100.00%
Predicted: 0, Actual: 0, Confidence: 100.00%
Predicted: 1, Actual: 1, Confidence: 100.00%
Predicted: 2, Actual: 2, Confidence: 100.00%
Predicted: 3, Actual: 3, Confidence: 100.00%
Predicted: 4, Actual: 4, Confidence: 100.00%
Predicted: 5, Actual: 5, Confidence: 100.00%
Predicted: 6, Actual: 6, Confidence: 100.00%
Predicted: 7, Actual: 7, Confidence: 100.00%
Predicted: 8, Actual: 8, Confidence: 100.00%
Predicted: 9, Actual: 9, Confidence: 100.00%
Accuracy on test dataset: 99.69%

Process finished with exit code 0

I built a custom OCR model with PyTorch. Here is a brief overview of the relevant components:

  • OCRModel: defines the architecture of the OCR neural network.
  • OCRDataset: a custom dataset class that handles loading and preprocessing of the OCR data.
  • OCRHandler: responsible for training and testing the OCR model.
  • OCR: inherits from OCRHandler and adds image-to-text conversion functionality.

At test time, the model's predictions are not what I expect. For example, it may predict a different character, or assign a low confidence score to a correct prediction. I expected the model to produce consistent, accurate predictions in both training and testing, with high confidence for the correct class. I don't know what to do now, because I know very little about neural networks.

import string
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from pathlib import Path
from PIL import UnidentifiedImageError, Image
from typing import Optional, Union


class OCRModel(nn.Module):
    def __init__(self, num_classes: int):
        """
        Initialize the OCRModel.

        Args:
            num_classes (int): Number of classes for classification.
        """
        super(OCRModel, self).__init__()
        # Define convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Define pooling layer
        self.pool = nn.MaxPool2d(2, 2)
        # Define fully connected layers
        self.fc1 = nn.Linear(64 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class OCRDataset(Dataset):
    def __init__(self, data_path: Union[str, Path], transform: Optional[transforms.Compose] = None):
        """
        Initialize the OCRDataset.

        Args:
            data_path (Union[str, Path]): Path to the dataset.
            transform (Optional[transforms.Compose]): Transformations to apply to the data.
        """
        if not isinstance(data_path, (str, Path)):
            raise TypeError("data_path must be a string or a Path object.")
        self.data_path = Path(data_path)
        self.transform = transform
        self.dataset = ImageFolder(root=str(self.data_path), transform=transform)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> tuple[Optional[torch.Tensor], Optional[int]]:
        while True:
            try:
                image, label = self.dataset[idx]
                # Convert image to PIL Image to handle errors
                image = transforms.ToPILImage()(image)
                if self.transform:
                    image = self.transform(image)
                break
            except (UnidentifiedImageError, OSError) as error:
                print(f"Error opening image at index {idx}: {error}")
                idx += 1
                if idx >= len(self):
                    print("Reached end of dataset.")
                    return None, None
        return image, label


class OCRHandler:
    def __init__(self, model: OCRModel):
        """
        Initialize the OCRHandler.

        Args:
            model (OCRModel): OCR model instance.
        """
        if not isinstance(model, OCRModel):
            raise TypeError("model must be an instance of OCRModel.")
        self.model = model

    def train(self, train_data_path: Union[str, Path], num_epochs: int = 10, batch_size: int = 32,
              learning_rate: float = 0.001) -> None:
        """
        Train the OCR model.

        Args:
            train_data_path (Union[str, Path]): Path to the training dataset.
            num_epochs (int): Number of epochs for training.
            batch_size (int): Batch size for training.
            learning_rate (float): Learning rate for optimization.
        """
        # Print a message to indicate the start of training
        print("Training started...")

        # Prepare the training dataset with transformations
        train_dataset = OCRDataset(train_data_path, transform=transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor()
        ]))

        # Create a data loader for the training dataset
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        # Define loss criterion and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

        # Loop through epochs
        for epoch in range(num_epochs):
            # Initialize running loss for each epoch
            running_loss = 0.0
            # Loop through batches in the training dataloader
            for i, (images, labels) in enumerate(train_dataloader):
                # Clear previous gradients
                optimizer.zero_grad()
                # Forward pass
                outputs = self.model(images)
                # Calculate loss
                loss = criterion(outputs, labels)
                # Backpropagation
                loss.backward()
                # Update weights
                optimizer.step()
                # Accumulate loss
                running_loss += loss.item()
                # Print loss statistics every 100 steps
                if (i + 1) % 100 == 0:
                    print(
                        f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_dataloader)}], Loss: {running_loss / 100:.4f}")
                    running_loss = 0.0
        print("Training finished.")

    def test(self, testing_path: Union[str, Path]) -> None:
        """
        Test the OCR model on the test dataset.

        Args:
            testing_path (Union[str, Path]): Path to the test dataset.
        """
        # Print a message to indicate testing has started
        print("Testing started...")

        # Prepare the test dataset
        test_dataset = ImageFolder(root=str(testing_path), transform=transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor()
        ]))

        # Create a data loader for the test dataset
        test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

        # Initialize counters for correct predictions and total samples
        correct = 0
        total = 0

        # Disable gradient calculation for inference
        with torch.no_grad():
            # Iterate over the test dataset
            for images, labels in test_dataloader:
                # Forward pass through the model
                outputs = self.model(images)
                # Get the predicted classes
                _, predicted = torch.max(outputs, 1)
                # Update correct predictions count
                correct += (predicted == labels).sum().item()
                # Update total count
                total += labels.size(0)
                # Calculate confidence of the prediction
                confidence = torch.softmax(outputs, 1)[0][predicted.item()].item() * 100
                # Convert predicted and actual labels to uppercase letters or digits
                predicted_label = string.ascii_uppercase[predicted.item()] if predicted.item() < 26 else str(
                    predicted.item() - 26)
                actual_label = string.ascii_uppercase[labels.item()] if labels.item() < 26 else str(labels.item() - 26)
                # Print prediction details
                print(f"Predicted: {predicted_label}, Actual: {actual_label}, Confidence: {confidence:.2f}%")

        # Calculate and print accuracy
        print(f"Accuracy on test dataset: {(correct / total) * 100:.2f}%")

    def save(self) -> None:
        """
        Save the trained model.
        """
        save_path = Path(__file__).resolve().parent
        save_path.mkdir(parents=True, exist_ok=True)
        model_path = save_path / "ocr_model.pt"
        torch.save(self.model.state_dict(), model_path)
        print(f"Model saved successfully at: {model_path}")


class OCR(OCRHandler):
    def __init__(self, debug=False):
        """
        Initialize the OCR class.

        Args:
            debug (bool, optional): Whether to enable debug mode. Defaults to False.
        """
        # Load the OCR model
        ocr_model = OCRModel(num_classes=36)
        super().__init__(ocr_model)
        self.debug = debug

    def image_to_text(self, image_path: Union[str, Path]) -> Union[str, tuple[str, dict | None]]:
        """
        Convert an image to text using the OCR model.

        Args:
            image_path (Union[str, Path]): Path to the input image.

        Returns:
            Union[str, tuple[str, dict]]: The predicted text. If debug mode is enabled,
                returns a tuple containing the predicted text and a dictionary with debug information.
        """
        # Prepare the input image
        image = Image.open(image_path).convert("RGB")
        transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor()
        ])
        # Add batch dimension
        image = transform(image).unsqueeze(0)

        # Disable gradient calculation for inference
        with torch.no_grad():
            # Forward pass through the model
            outputs = self.model(image)
            # Get the predicted class
            _, predicted = torch.max(outputs, 1)
            # Convert predicted label to uppercase letter or digit
            predicted_text = string.ascii_uppercase[predicted.item()] if predicted.item() < 26 else str(
                predicted.item() - 26)

            if self.debug:
                # Calculate confidence scores
                confidence_scores = torch.softmax(outputs, 1)[0].tolist()
                # Convert confidence scores to percentages
                confidence_percentages = [score * 100 for score in confidence_scores]
                # Create a dictionary with debug information
                debug_info = {
                    "predicted_text": predicted_text,
                    "confidence_scores": {string.ascii_uppercase[i] if i < 26 else str(i - 26): percentage
                                          for i, percentage in enumerate(confidence_percentages)},
                    "top_predictions": [
                        {
                            "class": string.ascii_uppercase[i] if i < 26 else str(i - 26),
                            "confidence": percentage
                        }
                        for i, percentage in enumerate(confidence_percentages)
                    ],
                }
                return predicted_text, debug_info
            else:
                return predicted_text, None

if __name__ == "__main__":
    ocr_model = OCRModel(num_classes=36)
    trainer = OCRHandler(ocr_model)
    train_data_path = Path("dataset/text_identification/train")
    trainer.train(train_data_path, num_epochs=50, batch_size=32, learning_rate=0.001)
    test_data_path = Path("dataset/text_identification/test")
    trainer.test(train_data_path)

When I test my OCR on the character "0", it predicts G? I just can't understand this: I trained it on several hundred images, so how can it come out as G?

{'predicted_text': 'G', 'confidence_scores': {'A': 2.5511952117085457, 'B': 2.9084114357829094, 'C': 2.9245806857943535, 'D': 3.019385412335396, 'E': 2.867276221513748, 'F': 2.7050500735640526, 'G': 3.0951984226703644, 'H': 3.0260657891631126, 'I': 2.691103331744671, 'J': 2.631979249417782, 'K': 2.5525128468871117, 'L': 2.7190934866666794, 'M': 2.6693686842918396, 'N': 2.9030684381723404, 'O': 2.5663597509264946, 'P': 2.941553108394146, 'Q': 2.8950219973921776, 'R': 2.789396792650223, 'S': 2.6342585682868958, 'T': 2.5583021342754364, 'U': 2.799813263118267, 'V': 2.574686147272587, 'W': 2.713942527770996, 'X': 2.824728935956955, 'Y': 2.85495538264513, 'Z': 2.7191564440727234, '0': 2.6622315868735313, '1': 2.6157714426517487, '2': 2.8603684157133102, '3': 2.5942767038941383, '4': 2.733685076236725, '5': 2.891615778207779, '6': 3.0571456998586655, '7': 2.551230974495411, '8': 2.9105449095368385, '9': 2.986663021147251}, 'top_predictions': [{'class': 'A', 'confidence': 2.5511952117085457}, {'class': 'B', 'confidence': 2.9084114357829094}, {'class': 'C', 'confidence': 2.9245806857943535}, {'class': 'D', 'confidence': 3.019385412335396}, {'class': 'E', 'confidence': 2.867276221513748}, {'class': 'F', 'confidence': 2.7050500735640526}, {'class': 'G', 'confidence': 3.0951984226703644}, {'class': 'H', 'confidence': 3.0260657891631126}, {'class': 'I', 'confidence': 2.691103331744671}, {'class': 'J', 'confidence': 2.631979249417782}, {'class': 'K', 'confidence': 2.5525128468871117}, {'class': 'L', 'confidence': 2.7190934866666794}, {'class': 'M', 'confidence': 2.6693686842918396}, {'class': 'N', 'confidence': 2.9030684381723404}, {'class': 'O', 'confidence': 2.5663597509264946}, {'class': 'P', 'confidence': 2.941553108394146}, {'class': 'Q', 'confidence': 2.8950219973921776}, {'class': 'R', 'confidence': 2.789396792650223}, {'class': 'S', 'confidence': 2.6342585682868958}, {'class': 'T', 'confidence': 2.5583021342754364}, {'class': 'U', 'confidence': 2.799813263118267}, {'class': 'V', 'confidence': 2.574686147272587}, {'class': 'W', 'confidence': 2.713942527770996}, {'class': 'X', 'confidence': 2.824728935956955}, {'class': 'Y', 'confidence': 2.85495538264513}, {'class': 'Z', 'confidence': 2.7191564440727234}, {'class': '0', 'confidence': 2.6622315868735313}, {'class': '1', 'confidence': 2.6157714426517487}, {'class': '2', 'confidence': 2.8603684157133102}, {'class': '3', 'confidence': 2.5942767038941383}, {'class': '4', 'confidence': 2.733685076236725}, {'class': '5', 'confidence': 2.891615778207779}, {'class': '6', 'confidence': 3.0571456998586655}, {'class': '7', 'confidence': 2.551230974495411}, {'class': '8', 'confidence': 2.9105449095368385}, {'class': '9', 'confidence': 2.986663021147251}]}

1 Answer

There is a lot to unpack here.

[The model's] predictions are different in the training and testing stages.

Based on what you've shared, that claim doesn't hold up. Looking at your evaluation, the final loss is Loss: 0.0140 (if you read that as an error-style metric, it would correspond to roughly 98.6% accuracy), while at test time you even reach 99.69%. So, going by these numbers, your model performs better on the test data than on the training data.

During training the model performs almost unbelievably well and reaches very high accuracy on the training dataset. However, when I use the model to predict text in images, the results are inconsistent.

Regarding the inconsistent results, I have a few suggestions.

a) Define a validation set and run it after every training epoch, so you can actually see where the inconsistency shows up.
b) Increase the size of your test set. The smaller the test set, the more the resulting accuracy (or whatever metric you use) fluctuates. Judging from the output above, your test set is very small and may well be misleading you.
c) On top of that, build a validation set for each character, so you can see the loss per class and really understand where the model is weak.

These validation and testing steps will then help you improve the model, for example by adding more data for the classes that perform poorly, or by weighting the loss more heavily for those classes during training. But first you need a clear picture of what is going on; a sketch of such a validation pass follows below.
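
As an illustration of points a) and c), here is a minimal sketch of a per-class validation pass. It assumes a held-out folder such as dataset/text_identification/val (a hypothetical path, laid out the same way as the training folder) and reuses the preprocessing from OCRHandler.train(); it is only a sketch, not part of the original code.

import torch
import torch.nn as nn
from collections import defaultdict
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder


def validate(model: nn.Module, val_data_path: str) -> None:
    # Same preprocessing as in training so the validation images match what the model saw
    val_dataset = ImageFolder(root=val_data_path, transform=transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor()
    ]))
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    model.eval()  # switch off training-specific behaviour during evaluation
    correct, total = defaultdict(int), defaultdict(int)
    with torch.no_grad():
        for images, labels in val_loader:
            preds = model(images).argmax(dim=1)
            for pred, label in zip(preds.tolist(), labels.tolist()):
                total[label] += 1
                correct[label] += int(pred == label)

    # Per-class accuracy makes it obvious which characters the model struggles with
    for idx, class_name in enumerate(val_dataset.classes):
        if total[idx]:
            print(f"{class_name}: {100 * correct[idx] / total[idx]:.1f}% ({total[idx]} samples)")
    model.train()  # restore training mode before the next epoch

Calling something like validate(self.model, ...) at the end of each epoch in OCRHandler.train() would give the per-epoch picture from a). If specific classes turn out to be weak, nn.CrossEntropyLoss accepts a per-class weight tensor, which is one way to implement the loss weighting mentioned above.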

The confidence scores for each class also seem to fluctuate a lot, which leads to unexpected predictions.

Confidence is an easy term to misread. It may reflect the predicted probability of a class, but if two of your classes look similar they influence each other and end up sharing probability mass, which lowers both (for example G and 0, or O and 0). To understand this better, you can plot a heat map of each character's average probability across all classes. So while confidence and probability are decent indicators, they depend heavily on how similar the classes are to one another.
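
Here is a sketch of such a heat map, assuming matplotlib and numpy are installed and reusing an ImageFolder dataset prepared as in the question (Resize + ToTensor); again this is an illustration, not part of the original code.

import numpy as np
import torch
import matplotlib.pyplot as plt


def probability_heatmap(model, dataset) -> None:
    # Rows: actual class, columns: average softmax probability assigned to each class
    num_classes = len(dataset.classes)
    sums = np.zeros((num_classes, num_classes))
    counts = np.zeros(num_classes)

    model.eval()
    with torch.no_grad():
        for image, label in dataset:
            probs = torch.softmax(model(image.unsqueeze(0)), dim=1)[0].numpy()
            sums[label] += probs
            counts[label] += 1

    avg = sums / np.maximum(counts, 1)[:, None]  # avoid division by zero for empty classes
    plt.imshow(avg, cmap="viridis")
    plt.xticks(range(num_classes), dataset.classes, rotation=90)
    plt.yticks(range(num_classes), dataset.classes)
    plt.xlabel("Predicted class (average probability)")
    plt.ylabel("Actual class")
    plt.colorbar()
    plt.show()

Bright off-diagonal cells (for example around 0, O and G) are exactly the pairs of classes that share probability mass.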

When I test my OCR on the character "0", the result is G?

Exactly. As said above, similar classes make the prediction task harder, and depending on how they are written, 0 and G can genuinely be hard to tell apart. If you want a better feel for the problem, plotting a few samples is always a good idea. Especially with handwritten characters there is a semantic gap, and without context you will never reach 100% prediction accuracy.
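
A small sketch for looking at samples of one class, again assuming matplotlib and an ImageFolder-style dataset with ToTensor applied (the function name is made up for illustration):

import matplotlib.pyplot as plt


def show_samples(dataset, class_name: str, n: int = 8) -> None:
    # Display the first n images of one class to judge how ambiguous they really are
    class_idx = dataset.class_to_idx[class_name]
    shown = 0
    for image, label in dataset:
        if label != class_idx:
            continue
        plt.subplot(1, n, shown + 1)
        plt.imshow(image.permute(1, 2, 0).numpy())  # CHW tensor -> HWC layout for matplotlib
        plt.axis("off")
        shown += 1
        if shown == n:
            break
    plt.suptitle(f"Samples of class '{class_name}'")
    plt.show()

Comparing show_samples(dataset, "0") with show_samples(dataset, "G") usually makes it clear how much the two classes overlap visually.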

I don't know what to do now; I know very little about neural networks.

I suggest you follow the steps described in my answer (the ones emphasized above). The key message is "validate and test". You need to understand what actually happens at prediction time. A handful of samples is not enough, so increase the sample count. That will make it much easier to improve your model and keep going from there.
