[Evaluation] LM Harness refactor #1410

Merged · 4 commits · Feb 8, 2025
32 changes: 0 additions & 32 deletions src/oumi/core/configs/params/model_params.py
@@ -20,7 +20,6 @@
from omegaconf import MISSING
from transformers.utils import find_adapter_config_file, is_flash_attn_2_available

from oumi.core.configs.inference_engine_type import InferenceEngineType
from oumi.core.configs.params.base_params import BaseParams
from oumi.core.types.exceptions import HardwareException
from oumi.utils.logging import logger
@@ -187,37 +186,6 @@ class ModelParams(BaseParams):
other parts fixed.
"""

def to_lm_harness(
self, inference_engine_type: InferenceEngineType
) -> dict[str, Any]:
"""Converts Oumi's ModelParams to LM Harness model arguments."""
model_args_dict = {
"pretrained": self.model_name,
"trust_remote_code": self.trust_remote_code,
"dtype": self.torch_dtype,
}
if inference_engine_type == InferenceEngineType.NATIVE:
model_args_dict["parallelize"] = self.shard_for_eval
model_args_dict["device_map"] = self.device_map
if self.adapter_model:
model_args_dict["peft"] = self.adapter_model
if (
self.attn_implementation
and inference_engine_type == InferenceEngineType.NATIVE
):
model_args_dict["attn_implementation"] = self.attn_implementation

# Handle extra model_kwargs (construction arguments).
# Towards OPE-564.
if self.model_kwargs:
relevant_for_lm = ["load_in_4bit", "load_in_8bit", "max_memory_per_gpu"]
for key in relevant_for_lm:
if key in self.model_kwargs:
model_args_dict[key] = self.model_kwargs[key]
# TODO: load_in_8bit, load_in_4bit are deprecated and will be removed in
# future versions of HF. Integrate via PeftConfig.
return model_args_dict

def __post_init__(self):
"""Populate additional params."""
self.torch_dtype = get_torch_dtype(self.torch_dtype_str)
3 changes: 2 additions & 1 deletion src/oumi/evaluate.py
@@ -50,8 +50,9 @@ def evaluate(config: EvaluationConfig) -> list[dict[str, Any]]:
output_dir=config.output_dir,
model_params=config.model,
generation_params=config.generation,
inference_engine_type=config.inference_engine,
enable_wandb=config.enable_wandb,
inference_engine_type=config.inference_engine,
inference_remote_params=config.inference_remote_params,
run_name=config.run_name,
)
results.append(result)
171 changes: 123 additions & 48 deletions src/oumi/evaluation/lm_harness.py
@@ -23,52 +23,125 @@
import torch
from lm_eval.loggers import WandbLogger

from oumi.builders import build_processor, build_tokenizer, is_image_text_llm
from oumi.builders import build_processor, build_tokenizer
from oumi.builders.models import is_image_text_llm_using_model_name
from oumi.core.configs import (
GenerationParams,
InferenceEngineType,
LMHarnessTaskParams,
ModelParams,
RemoteParams,
)
from oumi.core.distributed import is_world_process_zero
from oumi.evaluation.save_utils import save_evaluation_output
from oumi.utils.logging import logger


def _create_extra_lm_harness_model_params_for_vlm(
######################### LM Harness: model and model arguments ########################
# LM Harness `model` types | Class | File location in their repo is under: #
# (= inference engine) | name | lm-evaluation-harness/lm_eval/models/...#
# --------------------------|----------------|---------------------------------------- #
# hf | HFLM | huggingface.py #
# vllm | VLLM | vllm_causallms.py #
# hf-multimodal | HFMultimodalLM | hf_vlms.py #
# vllm-vlm | VLLM_VLM | vllm_vlms.py #
########################################################################################
# LM Harness | Oumi | text engs | multimodal engs | remote #
# `model_args` | `model_params` | hf | vllm | hf-multimodal | vllm-vlm | #
# ------------------- | -------------- | -- | ---- | ------------- | -------- | ------ #
# trust_remote_code | | Υ | Υ | Υ | Υ | TBD #
# pretrained | model_name | Υ | Υ | Υ | Υ | TBD #
# dtype | torch_dtype | Υ | Υ | Υ | Υ | TBD #
# max_length |model_max_length| Υ | Υ | Υ | Υ | TBD #
# tokenizer | tokenizer_name | Υ | Υ | Υ | Υ | TBD #
# peft | adapter_model | Υ | | Υ | | TBD #
# parallelize | shard_for_eval | Υ | | Υ | | TBD #
# device_map | | ?? | | ?? | | TBD #
# attn_implementation | | ?? | | ?? | | TBD #
# ------------------- | -------------- | -- | ---- | ------------- | -------- | ------ #
# max_images | | NA | NA | Υ | Υ | TBD #
# interleave | | NA | NA | Υ | Υ | TBD #
# convert_img_format | | NA | NA | Υ | | TBD #
# image_token_id | | NA | NA | Υ | | TBD #
# image_string | | NA | NA | Υ | | TBD #
########################################################################################
def _generate_lm_harness_model_args(
lm_harness_model: str,
is_multimodal: bool,
model_params: ModelParams,
vllm_engine: bool,
inference_engine_type: InferenceEngineType,
inference_remote_params: Optional[RemoteParams],
) -> dict[str, Any]:
# For details, see:
# https://github.com/EleutherAI/lm-evaluation-harness/releases/tag/v0.4.5
# FIXME OPE-355 To remove `max_images=1` limit
result = {"max_images": 1, "interleave": True}

# Only applicable to hf-multimodal (NOT vllm-vlm).
if not vllm_engine:
result["convert_img_format"] = True

tokenizer = build_tokenizer(model_params)
processor = build_processor(
model_params.model_name,
tokenizer,
trust_remote_code=model_params.trust_remote_code,
"""Converts Oumi's ModelParams to LM Harness model arguments."""
# Arguments used across all engines and modalities.
model_args_dict = {
"trust_remote_code": model_params.trust_remote_code,
"pretrained": model_params.model_name,
"dtype": model_params.torch_dtype,
"max_length": model_params.model_max_length,
}
if model_params.tokenizer_name:
model_args_dict["tokenizer"] = model_params.tokenizer_name

# Add NATIVE inference engine's additional parameters.
if inference_engine_type == InferenceEngineType.NATIVE:
model_args_dict["parallelize"] = model_params.shard_for_eval
model_args_dict["device_map"] = model_params.device_map
if model_params.adapter_model:
model_args_dict["peft"] = model_params.adapter_model
if model_params.attn_implementation:
model_args_dict["attn_implementation"] = model_params.attn_implementation

# Add REMOTE inference engine's additional parameters.
if inference_engine_type == InferenceEngineType.REMOTE:
if not inference_remote_params:
raise ValueError(
"The `REMOTE` inference engine requires `inference_remote_params`."
)
raise NotImplementedError(
"The REMOTE inference engine is not yet supported with LM Harness."
)
if image_token := processor.image_token:
result["image_string"] = image_token
if image_token_id := processor.image_token_id:
result["image_token_id"] = image_token_id

return result
# Add multi-modal related parameters.
# details at https://github.com/EleutherAI/lm-evaluation-harness/releases/tag/v0.4.5
if is_multimodal:
# FIXME OPE-355 To remove `max_images=1` limit
model_args_dict |= {"max_images": 1, "interleave": True}

# Only applicable to hf-multimodal (NOT vllm-vlm).
if lm_harness_model == "hf-multimodal":
model_args_dict["convert_img_format"] = True

tokenizer = build_tokenizer(model_params)
processor = build_processor(
model_params.model_name,
tokenizer,
trust_remote_code=model_params.trust_remote_code,
)
if image_token := processor.image_token:
model_args_dict["image_string"] = image_token
if image_token_id := processor.image_token_id:
model_args_dict["image_token_id"] = image_token_id

# Handle extra model_kwargs (construction arguments).
# Towards OPE-564.
if model_params.model_kwargs:
for key in ["load_in_4bit", "load_in_8bit", "max_memory_per_gpu"]:
if key in model_params.model_kwargs:
model_args_dict[key] = model_params.model_kwargs[key]
# TODO: load_in_8bit, load_in_4bit are deprecated and will be removed in
# future versions of HF. Integrate via PeftConfig.
return model_args_dict
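
# A hedged illustration (the values below are hypothetical, not taken from this
# repo): for a text-only model evaluated via the NATIVE engine, the helper
# returns the shared arguments plus the NATIVE-only ones, roughly:
#
#     {
#         "trust_remote_code": False,
#         "pretrained": "my-org/my-model",   # ModelParams.model_name (assumed)
#         "dtype": torch.bfloat16,           # ModelParams.torch_dtype
#         "max_length": 2048,                # ModelParams.model_max_length
#         "parallelize": False,              # ModelParams.shard_for_eval
#         "device_map": "auto",              # ModelParams.device_map
#     }
#
# Multimodal runs additionally receive "max_images" and "interleave", and
# "hf-multimodal" runs also get "convert_img_format" plus the image-token
# fields read from the processor.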


def evaluate(
task_params: LMHarnessTaskParams,
output_dir: str,
model_params: ModelParams,
generation_params: GenerationParams,
inference_engine_type: InferenceEngineType,
enable_wandb: bool,
inference_engine_type: InferenceEngineType,
inference_remote_params: Optional[RemoteParams] = None,
run_name: Optional[str] = None,
) -> dict[str, Any]:
"""Evaluates a model using the LM Evaluation Harness framework (EleutherAI).
@@ -77,12 +150,13 @@ def evaluate(
https://github.com/EleutherAI/lm-evaluation-harness

Args:
model_params: The parameters of the model to evaluate.
task_params: The LM Harness parameters to use for evaluation.
generation_params: The generation parameters to use for evaluation.
inference_engine_type: The inference engine to use (`VLLM` or `NATIVE`).
output_dir: The directory where the evaluation results will be saved.
model_params: The parameters of the model to evaluate.
generation_params: The generation parameters to use for evaluation.
enable_wandb: Whether to enable Weights & Biases (wandb) logging.
inference_engine_type: The inference engine to use (`VLLM`, `NATIVE`, `REMOTE`).
inference_remote_params: The parameters for remote inference, if applicable.
run_name: Unique identifier for wandb for the current evaluation run.

Returns:
@@ -98,23 +172,32 @@
device = "cpu"
logger.warning("No GPU available.")

# Ensure the requested inference engine type is applicable.
# Identify whether the model is multi-modal.
is_multimodal = is_image_text_llm_using_model_name(
model_name=model_params.model_name,
trust_remote_code=model_params.trust_remote_code,
)

# Identify the proper LM Harness model (`lm_harness_model`) to use.
if inference_engine_type == InferenceEngineType.NATIVE:
vllm_engine = False
lm_harness_model = "hf-multimodal" if is_multimodal else "hf"
if device.startswith("cuda"):
logger.warning(
"Since you have GPU support, it is highly recommended that you set "
"the `inference_engine` to `VLLM`, instead of the `NATIVE`, for faster "
"evaluation."
)
elif inference_engine_type == InferenceEngineType.VLLM:
vllm_engine = True
lm_harness_model = "vllm-vlm" if is_multimodal else "vllm"
if not device.startswith("cuda"):
raise ValueError("The `VLLM` inference_engine requires a CUDA-enabled GPU.")
elif inference_engine_type == InferenceEngineType.REMOTE:
lm_harness_model = "local-completions"
else:
raise ValueError(
"Our integration with the `lm_harness` evaluation platform only supports "
"the `VLLM` and `NATIVE` inference_engine types at the moment."
f"Unsupported inference engine type: {inference_engine_type}. "
"Our integration with the `lm_harness` evaluation platform supports "
"the `NATIVE`, `VLLM` and `REMOTE` inference_engine types."
)

if model_params.adapter_model:
@@ -127,26 +210,18 @@
generation_params.batch_size if generation_params.batch_size else "auto"
)

lm_harness_model_params = _generate_lm_harness_model_args(
lm_harness_model=lm_harness_model,
is_multimodal=is_multimodal,
model_params=model_params,
inference_engine_type=inference_engine_type,
inference_remote_params=inference_remote_params,
)

# Get a timestamp for the current run.
start_time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
start_time = time.time()

lm_harness_model_params = model_params.to_lm_harness(inference_engine_type)

if is_image_text_llm(model_params):
# Multimodal support is currently restricted to
# the ['hf-multimodal', 'vllm-vlm'] model types.
lm_harness_model = "vllm-vlm" if vllm_engine else "hf-multimodal"
apply_chat_template = True
lm_harness_model_params.update(
_create_extra_lm_harness_model_params_for_vlm(model_params, vllm_engine)
)
else:
lm_harness_model = "vllm" if vllm_engine else "hf"
# False is the default value for `simple_evaluate()`
# TODO Should it be set to True?
apply_chat_template = False

logger.info("Starting evaluation...")
logger.info(f"\tLM Harness `model_params`:\n{pformat(lm_harness_model_params)}")
logger.info(f"\tLM Harness `task_params`:\n{pformat(task_params)}")
@@ -159,7 +234,7 @@ def evaluate(
device=device,
limit=task_params.num_samples,
log_samples=False,
apply_chat_template=apply_chat_template,
apply_chat_template=is_multimodal,
**task_params.eval_kwargs, # type: ignore
)
elapsed_time_sec = time.time() - start_time
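
For orientation, below is a minimal standalone sketch of the underlying LM Harness call that this module drives via `lm_eval.simple_evaluate`. It assumes lm-evaluation-harness v0.4.5 (the release referenced in the comments above), and the model name, task, and device are placeholders rather than values from this PR; `model_args` mirrors the dictionary built by `_generate_lm_harness_model_args`.

import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",  # or "vllm", "hf-multimodal", "vllm-vlm", "local-completions"
    model_args={
        # Placeholder values; the real dict is produced from ModelParams.
        "pretrained": "my-org/my-model",
        "trust_remote_code": False,
        "dtype": "bfloat16",
    },
    tasks=["mmlu"],      # placeholder task name
    batch_size="auto",
    device="cuda:0",     # requires a CUDA GPU; use "cpu" otherwise
    limit=100,           # evaluate only the first 100 samples per task
    log_samples=False,
    apply_chat_template=False,
)
print(results["results"])  # aggregated metrics per task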