import argparse
import minimal_opt
import torch
from tqdm import auto as tqdm_lib
# Turn autograd off for the whole process: the code in this file only calls
# the model's forward pass and never calls backward().
torch.set_grad_enabled(False)


def run_step(model, batch_size, input_len, output_len, *, device="cuda"):
    """Run one greedy-decoding pass over a dummy all-zeros prompt.

    The first forward call receives the whole ``(batch_size, input_len)``
    prompt.  Every later call receives only the token picked in the previous
    step, together with the ``layer_past`` value the model handed back, so
    the model is expected to carry its own state between calls.

    Args:
        model: Callable invoked as ``model(input_ids, layer_past=...)`` that
            returns a ``(model_out, layer_past)`` pair.  ``model_out`` is
            sliced with ``[:, -1]`` and arg-maxed over its last axis
            (presumably logits shaped (batch, seq, vocab) -- TODO confirm
            against minimal_opt).
        batch_size: Number of rows in the dummy prompt.
        input_len: Length of the dummy prompt in tokens.
        output_len: Number of decoding steps (forward calls) to run.  Values
            ``<= 0`` run no steps.
        device: Device on which the prompt is allocated.  Defaults to
            ``"cuda"``, matching the previous hard-coded behavior.

    Returns:
        An integer tensor of shape ``(batch_size, output_len)`` holding the
        greedily selected token id of every step (empty along dim 1 when no
        step was run).  It lives on whatever device the model's output is on.
    """
    # Allocate directly as int64 on the target device instead of creating a
    # float32 CPU tensor, casting it, and then copying it over.
    input_ids = torch.zeros(
        (batch_size, input_len), dtype=torch.long, device=device
    )
    current_input_ids = input_ids
    layer_past = None
    generated = []

    with torch.inference_mode():
        for _ in range(output_len):
            model_out, layer_past = model(
                current_input_ids,
                layer_past=layer_past,
            )
            # Greedy choice: highest-scoring id at the last position.
            next_token_ids = model_out[:, -1].argmax(-1)
            generated.append(next_token_ids)
            # Only the new token is fed back in on the next step.
            current_input_ids = next_token_ids[:, None]

        if not generated:
            # torch.stack rejects an empty list, so build the empty result
            # explicitly.
            return input_ids.new_empty((batch_size, 0))
        return torch.stack(generated, dim=1)


def create_model(model_name):
    """Build a cached ``minimal_opt.PPOPTModel`` for the given size key.

    Args:
        model_name: One of ``"125m"``, ``"1.3b"``, ``"2.7b"``, ``"6.7b"``,
            ``"13b"``, ``"30b"``, ``"66b"`` or ``"175b"``.

    Returns:
        A ``minimal_opt.PPOPTModel`` constructed from the matching config
        with ``use_cache=True``.

    Raises:
        KeyError: If ``model_name`` is not one of the keys listed above.
    """
    configs_by_size = {
        "125m": minimal_opt.OPT_125M_CONFIG,
        "1.3b": minimal_opt.OPT_1_3B_CONFIG,
        "2.7b": minimal_opt.OPT_2_7B_CONFIG,
        "6.7b": minimal_opt.OPT_6_7B_CONFIG,
        "13b": minimal_opt.OPT_13B_CONFIG,
        "30b": minimal_opt.OPT_30B_CONFIG,
        "66b": minimal_opt.OPT_66B_CONFIG,
        "175b": minimal_opt.OPT_175B_CONFIG,
    }
    selected_config = configs_by_size[model_name]
    return minimal_opt.PPOPTModel(selected_config, use_cache=True)


def main(argv=None):
    """Parse command-line options and run the decoding benchmark.

    Builds the model named by ``--model_name``, calls ``run_step`` once
    outside the progress bar, then calls it ``--num_steps`` more times under
    a tqdm progress bar and prints a completion line.

    Args:
        argv: Optional list of argument strings.  ``None`` (the default)
            makes argparse read ``sys.argv[1:]`` as before.
    """
    parser = argparse.ArgumentParser(
        description='Benchmark greedy decoding with a minimal_opt OPT model.'
    )
    # Required: without it create_model would be called with None and fail
    # with an opaque KeyError instead of a usage message.
    parser.add_argument('--model_name', type=str, required=True,
                        help='model size key passed to create_model, '
                             'e.g. 125m')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='rows in the dummy prompt')
    parser.add_argument('--input_len', type=int, default=1024,
                        help='prompt length in tokens')
    parser.add_argument('--output_len', type=int, default=128,
                        help='decoding steps per run_step call')
    parser.add_argument('--num_steps', type=int, default=10,
                        help='run_step calls made under the progress bar')
    args = parser.parse_args(argv)

    model = create_model(args.model_name)
    step_kwargs = {
        "model": model,
        "batch_size": args.batch_size,
        "input_len": args.input_len,
        "output_len": args.output_len,
    }

    # One call outside the progress bar -- presumably a warm-up so one-time
    # setup cost does not show up in the tqdm timing.
    run_step(**step_kwargs)
    for _ in tqdm_lib.trange(args.num_steps):
        run_step(**step_kwargs)
    print(f"{args.batch_size} Done.")


# Run the benchmark only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()