Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Got wrong inference result #316

Open
TinhAnhGitHub opened this issue Oct 5, 2024 · 1 comment
Open

Got wrong inference result #316

TinhAnhGitHub opened this issue Oct 5, 2024 · 1 comment

Comments

@TinhAnhGitHub
Copy link

TinhAnhGitHub commented Oct 5, 2024

I have attempted to fine-tune the model for end-to-end chart extraction, based on the Dataset ChartQA. During training, everything is ok, the loss was going down as expected.
However, during inference, the result is just a long sequence of a single repeated random token:

Prediction: {'text_sequence': ' 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 
35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35'}

Here is the code we have used for loading the model and inference:

import json
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm
from donut import DonutModel, JSONParseEvaluator, save_json
from torchvision import transforms

import matplotlib.pyplot as plt
def preprocess_image(image):
    """Resize a PIL image to 1280x960 and convert it to a float tensor."""
    pipeline = transforms.Compose(
        [
            transforms.Resize((1280, 960)),
            transforms.ToTensor(),
        ]
    )
    return pipeline(image)
def test(pretrained_model_name, dataset_name, sample_index=0, task_name=None, save_path=None):
    """Run single-sample inference with a fine-tuned Donut model and score it.

    Args:
        pretrained_model_name: Path or hub id of the fine-tuned DonutModel.
        dataset_name: Hugging Face dataset name/path with a "validation" split
            whose samples carry "image" and a JSON-encoded "ground_truth".
        sample_index: Index of the validation sample to evaluate.
        task_name: "docvqa" or "rvlcdip" select task-specific scoring; any
            other value falls through to JSON-parse accuracy (e.g. charts).
        save_path: Optional path of a JSON file to store the result in.
    """
    pretrained_model = DonutModel.from_pretrained(pretrained_model_name, ignore_mismatched_sizes=True)

    if torch.cuda.is_available():
        # fp16 only on GPU; half precision on CPU is unsupported/slow.
        pretrained_model.half()
        pretrained_model.to("cuda")

    pretrained_model.eval()

    if save_path:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

    dataset = load_dataset(dataset_name)

    sample = dataset["validation"][sample_index]
    ground_truth = json.loads(sample["ground_truth"])

    # Show the sample being evaluated for visual sanity-checking.
    plt.imshow(sample['image'])
    plt.show()

    if task_name == "docvqa":
        # BUG FIX: the original referenced an undefined `input_tensor` here
        # (NameError at runtime); pass the raw sample image directly, the
        # same way the generic branch below does.
        output = pretrained_model.inference(
            image=sample['image'],
            prompt=f"<s_{task_name}><s_question>{ground_truth['gt_parses'][0]['question'].lower()}</s_question><s_answer>"
        )["predictions"][0]
    else:
        output = pretrained_model.inference(image=sample['image'], prompt='<s_general_figure_info>')['predictions'][0]

    print(f"Sample Index: {sample_index}")

    print("*" * 50)
    print(f"Ground truth: {ground_truth}")
    print("*" * 50)

    print("*" * 50)
    print(f"Prediction: {output}")
    print('*' * 50)

    # Task-specific scoring: exact class match, answer-set membership, or
    # structured JSON-parse accuracy for everything else.
    if task_name == "rvlcdip":
        gt = ground_truth["gt_parse"]
        score = float(output["class"] == gt["class"])
    elif task_name == "docvqa":
        gt = ground_truth["gt_parses"]
        answers = set([qa_parse["answer"] for qa_parse in gt])
        score = float(output["answer"] in answers)
    else:
        gt = ground_truth["gt_parse"]
        evaluator = JSONParseEvaluator()
        score = evaluator.cal_acc(output, gt)

    print(f"Sample Index: {sample_index}, Score: {score}")

    if save_path:
        results = {
            "sample_index": sample_index,
            "predictions": output,
            "ground_truth": gt,
            "score": score
        }
        save_json(save_path, results)

# --- Script configuration -------------------------------------------------
pretrained_model_name = "/kaggle/working/orkspace/result/train_chartQA/test_experiment"  # Path to your model
dataset_name = "DanTheGuy/ChartQA_small_preprocessed_chunked_box"  # Path to your dataset
task_name = "Chart"  # Not "docvqa"/"rvlcdip", so the generic JSON-parse scoring branch is used
save_path = "/kaggle/working/result/output_single_sample.json"  # Where to save the output

# Evaluate a single validation sample (index 40) and write the result JSON.
test(pretrained_model_name, dataset_name, sample_index=40, task_name=task_name, save_path=save_path)

Here is the sample image that we have tested on:
image
And here is the ground-truth:


Ground truth: {'gt_parse': {'general_figure_info': {'figure_info': {'bbox': {'x': 103, 'y': 45, 'w': 622, 'h': 271}}, 'y_axis': {'label': {'bbox': [{'x': 40, 'y': 113, 'w': 11, 'h': 135.640625}], 'text': 'Share of total U.S. households'}, 'major_labels': {'bboxes': {'x': 60, 'y': 24, 'w': 33, 'h': 302}}}, 'x_axis': {'major_labels': {'bboxes': {'x': 77, 'y': 324, 'w': 638, 'h': 123}, 'values': ['With an Internet', 'Dial-up only', 'Broadband of any type', 'Cellular data plan', 'Cellular data plan with no', 'Broadband such as cable,', 'Satellite', 'No subscription (internet']}, 'label': {}}}, 'models': [{'name': {}, 'bboxes': [{'x': 115, 'y': 86, 'w': 54, 'h': 231}, {'x': 193, 'y': 316, 'w': 54, 'h': 1}, {'x': 271, 'y': 86, 'w': 54, 'h': 231}, {'x': 349, 'y': 112, 'w': 54, 'h': 205}, {'x': 426, 'y': 286, 'w': 54, 'h': 31}, {'x': 504, 'y': 128, 'w': 54, 'h': 189}, {'x': 582, 'y': 298, 'w': 54, 'h': 19}, {'x': 660, 'y': 277, 'w': 54, 'h': 40}], 'x': ['With an Internet', 'Dial-up only', 'Broadband of any type', 'Cellular data plan', 'Cellular data plan with no', 'Broadband such as cable,', 'Satellite', 'No subscription (internet'], 'y': ['85.3%', '0.3%', '85.1%', '75.7%', '11.6%', '69.6%', '6.9%', '14.7%'], 'colors': ['#2876dd', '#2876dd', '#2876dd', '#2876dd', '#2876dd', '#2876dd', '#2876dd', '#2876dd'], 'labels': ['With an Internet', 'Dial-up only', 'Broadband of any type', 'Cellular data plan', 'Cellular data plan with no', 'Broadband such as cable,', 'Satellite', 'No subscription (internet']}], 'type': 'v_bar'}, 'metadata': {'imgname': 'two_col_2998.png'}}


Our version:
!pip install transformers==4.25.1
!pip install timm==0.5.4

And here is our loss graph:
image

Any comments will be a great support to help us fix this bug :D

@Maryam483
Copy link

same issue observed with my dataset. I have followed this tutorial: https://towardsdatascience.com/ocr-free-document-understanding-with-donut-1acfbdf099be

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants