refactor: image model get_num_tokens override #1994
Conversation
Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected; please follow our release note process to remove it. Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.
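For reference, a release-note block is a fenced section in the PR description; the snippet below assumes the Kubernetes-style convention, where `NONE` marks a change that needs no user-facing release note:

```release-note
NONE
```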
[APPROVALNOTIFIER] This PR is NOT APPROVED. This pull-request has been approved by: The full list of commands accepted by this bot can be found here.
Needs approval from an approver in each of these files:
Approvers can indicate their approval by writing
```python
            return super().get_num_tokens(text)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
```
The provided code has several improvements and corrections to enhance functionality:

1. **Removed duplicate `typing` import**: the extra `typing` import for `Dict` is unnecessary since it is covered by the more general import at the beginning.
2. **Imports and definitions cleanup**: simplified the imports and removed formatting-only changes (line numbers and spacing around operators) inside `get_num_tokens` that do not affect the logic.
3. **Custom token-ID encoding function**: added a utility function `custom_get_token_ids` to streamline token encoding using the global tokenizer from `TokenizerManage`.
4. **Error handling for the token getters**: introduced exception handling in the getter methods so the global tokenizer is used when the superclass implementation fails due to model-specific overrides or missing dependencies.
5. **Optimized token-count calculation**: the optimized `get_num_tokens_from_messages` uses a comprehension to total the tokens across all message buffers, ensuring efficiency.
6. **Minor typographical corrections**: fixed typos such as "chag_openai" to "ChatOpenAI".
Here's the cleaned-up version:
```python
# coding=utf-8
from typing import Dict, List

from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel


def custom_get_token_ids(text: str):
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI):
    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **optional_params):
        """Initialize a new model instance."""
        chat_tong_yi = MaxKBBaseModel.new_instance(
            model_type=model_type,
            model_name=model_name,
            model_credential=model_credential,
            verbose=True,
            use_cache=False,
            **optional_params,
        )
        return chat_tong_yi

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

    def get_num_tokens(self, text: str) -> int:
        try:
            return super().get_num_tokens(text)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
```
These adjustments improve code clarity, maintainability, and robustness while providing efficient token counting utilities.
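For illustration, a minimal usage sketch of the fallback path this override provides; the model name and credential key below are placeholders, not values from this PR:

```python
# Hypothetical usage: if the underlying tiktoken-based counter is unavailable
# for this model, get_num_tokens() falls back to the shared tokenizer.
model = QwenVLChatModel.new_instance(
    model_type='LLM',
    model_name='qwen-vl-max',                 # placeholder model name
    model_credential={'api_key': 'sk-...'},   # placeholder credential
)
print(model.get_num_tokens('How many tokens is this sentence?'))
```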
```python
            return super().get_num_tokens(text)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
```
The provided code has some errors and lacks necessary error handling. Here's an updated version of the code with corrections, optimizations, and explanations:
```python
# coding=utf-8
from typing import Dict, List, Optional

from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel

# Default sampling temperature used when none is supplied.
default_temp = 0.75


class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI):
    @staticmethod
    def new_instance(
            model_type: str,
            model_name: str,
            model_credential: Optional[Dict[str, object]] = None,
            max_seq_length: int = 1024 * 4,
            top_p: float = 1.0,
            temperature: float = default_temp,
            verbose: bool = True,
            streaming: bool = False,
            optional_params: Optional[dict] = None,
    ) -> 'QwenVLChatModel':
        model_credential = model_credential or {}
        optional_params = optional_params or {}
        chat_tong_yi = QwenVLChatModel(
            model=model_name,
            max_tokens=max_seq_length,
            top_p=top_p,
            temperature=temperature,
            verbose=verbose,
            streaming=streaming,
            # Credential key name is provider-specific.
            openai_api_key=model_credential.get('api_key'),
            **optional_params,
        )
        return chat_tong_yi

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

    def get_num_tokens(self, text: str) -> int:
        try:
            return super().get_num_tokens(text)
        except Exception:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
```
Changes Made:

1. **Imports**: added `Optional` to the `typing` imports and imported `MaxKBBaseModel`, `get_buffer_string`, and `TokenizerManage` explicitly so the module is self-contained.
2. **Static-method parameter handling**: `new_instance` now takes explicit, typed parameters (`max_seq_length`, `top_p`, `temperature`, `streaming`) with sensible defaults and passes them straight to the constructor instead of spreading them as dictionaries.
3. **Error handling**: replaced narrow, hardcoded exception cases with a general `Exception` handler around the superclass token-counting calls, falling back to the shared tokenizer from `TokenizerManage`.
4. **Token counting**: `get_num_tokens_from_messages` flattens each message with `get_buffer_string` and sums the encoded lengths; `get_num_tokens` encodes the raw text directly.
5. **Return types**: updated the function signatures to clearly indicate the expected input and output types.

These changes make the code more robust, maintainable, and flexible while addressing potential issues.
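To make the fallback concrete, here is a small self-contained illustration of how `get_buffer_string` flattens a message into plain text before it is handed to the tokenizer (output shown under langchain-core's default role prefixes):

```python
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string

messages = [HumanMessage(content="What is in this image?"), AIMessage(content="A cat.")]

# Each message is rendered as "<role>: <content>" and then encoded; the token
# counts of the flattened strings are summed in the fallback path.
for m in messages:
    print(get_buffer_string([m]))
# Human: What is in this image?
# AI: A cat.
```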
```python
            return super().get_num_tokens(text)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
```
Here is a summary of the changes and recommendations made to the `ChatWithAzure` class:

1. **Import changes**: removed the unnecessary import of `ChatOpenAI`.
2. **Method adjustments**: added `get_num_tokens_from_messages` to estimate token usage from message contents and `get_num_tokens` to calculate the token count for plain text input.
3. **Error handling**: wrapped the token-counting logic in exception handlers that fall back to `TokenizerManage.get_tokenizer()` when the superclass implementation fails.
4. **Code simplification**: removed custom exception-handling logic from specific methods and used the base-class methods directly where applicable.

These changes improve the robustness and functionality of the `ChatWithAzure` class by handling cases where tokenization fails due to missing dependencies or unsupported configurations. Using `BaseMessage` and the shared utility functions keeps the implementation aligned with language-model capabilities across different platforms.
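A minimal sketch of the kind of override this summary describes, assuming the project's `MaxKBBaseModel` mixin and `TokenizerManage` helper behave as in the Qwen example above; the class name and bases here are illustrative rather than the PR's actual Azure implementation:

```python
from typing import List

from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_openai import AzureChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel


class ChatWithAzure(MaxKBBaseModel, AzureChatOpenAI):
    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        try:
            # Prefer the upstream tiktoken-based counter when it is available.
            return super().get_num_tokens_from_messages(messages)
        except Exception:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

    def get_num_tokens(self, text: str) -> int:
        try:
            return super().get_num_tokens(text)
        except Exception:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
```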
`1310c8a` to `eab0284` (Compare)
```python
            return super().get_num_tokens(text)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
```
The code looks generally correct, but there are a few suggestions for improvement:
1. **Consistent Error Handling**: The `custom_get_token_ids` function has its own exception handling to fall back on encoding the message using `TokenizerManage`. It would be more consistent to use the original error handler provided by `ChatOpenAI`.
2. **Use of Class Attributes/Methods**: If you want to leverage class-level resources like a tokenizer, it's better to keep such references at class level rather than re-initializing them multiple times in different methods.
3. **Remove Redundant Code**: You're currently calling `super()` twice in `get_num_tokens` and `get_num_tokens_from_messages`, which might not be necessary since they inherit this functionality.
4. **Docstring Updates**: Consider updating comments (docstrings) to clearly explain what each part of the method does.
Here’s an optimized version of the code with some adjustments:
```python
# coding=utf-8
from typing import Any, Dict, List

from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel


class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI):
    @classmethod
    def new_instance(
            cls, model_type, model_name, model_credential: Dict[str, object],
            api_key=None, openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
            # stream_options={"include_usage": True},
            streaming=True,
            extra_kwargs=None,
    ) -> 'QwenVLChatModel':
        if extra_kwargs is None:
            extra_kwargs = {}
        return super().new_instance(model_type=model_type, model_name=model_name,
                                    api_key=api_key, openai_api_base=openai_api_base,
                                    streaming=streaming,
                                    additional_kwargs={'custom_get_token_ids': cls.custom_get_token_ids},
                                    **extra_kwargs)

    @property
    def _tokenizer(self) -> Any:
        """Lazily load the shared tokenizer."""
        return TokenizerManage.get_tokenizer()

    def get_num_message_tokens(self, messages: List[BaseMessage]) -> int:
        """Calculate the number of tokens based on message content."""
        return sum(len(self._tokenizer.encode(get_buffer_string([m]))) for m in messages)

    def get_text_tokens(self, text: str) -> int:
        """Calculate the number of tokens for a given string."""
        return len(self._tokenizer.encode(text))

    def custom_get_token_ids(self, text: str):
        """Custom tokenization: if the default implementation fails, use the shared tokenizer."""
        try:
            return super().get_token_ids(text)
        except Exception as e:
            print(f"Failed to encode using the default implementation: {e}")
            return self._tokenizer.encode(text)
```
Key Changes:

- Added a `_tokenizer` property to lazily load the shared tokenizer.
- Merged the instance-creation logic in `new_instance` into a single delegating call.
- Improved docstrings for clarity.
- Used dedicated helper names (`get_num_message_tokens`, `get_text_tokens`) to separate the specific token-count functions from general token counting.
- Provided a fallback to the shared tokenizer for errors encountered during tokenization.
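As a related alternative, langchain chat models expose a `custom_get_token_ids` field (defined on `BaseLanguageModel` in langchain-core), so the shared tokenizer can also be wired in at construction time instead of through subclass overrides; a sketch, with the model name as a placeholder:

```python
from langchain_community.chat_models import ChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage


def custom_get_token_ids(text: str):
    # All token-id lookups go through the shared tokenizer.
    return TokenizerManage.get_tokenizer().encode(text)


# get_num_tokens() counts via the custom callable instead of loading a
# tiktoken encoding for the model.
llm = ChatOpenAI(model="qwen-vl-max", custom_get_token_ids=custom_get_token_ids)
```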
```python
            return super().get_num_tokens(text)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
```
There are a few improvements and potential optimizations that can be made to the `QwenVLChatModel` class:

1. **Exception handling**: the exception handling in `get_num_tokens_from_messages`, `get_num_tokens`, and the other methods is unnecessary because the superclass method already handles exceptions when fetching token counts.
2. **Token-counting function**: the `custom_get_token_ids` function encapsulates how text is converted into tokens using the tokenizer, which keeps it reusable and maintainable.
3. **Optimized `get_num_tokens` methods**: use a comprehension with `sum()` for better performance when calculating the total token count across multiple messages.
Here’s the optimized version of your code:
```python
# coding=utf-8
from typing import Dict, List

from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import BaseMessage

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel


def custom_get_token_ids(text: str):
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class QwenVLChatModel(MaxKBBaseModel, ChatOpenAI):
    @staticmethod
    def new_instance(
            model_type,
            model_name,
            model_credential: Dict[str, object],
            **kwargs,
    ):
        openai_api_base = kwargs.pop('openai_api_base', 'https://dashscope.aliyuncs.com/compatible-mode/v1')
        if 'stream_options' not in kwargs:
            kwargs['stream_options'] = {"include_usage": False}
        if 'streaming' not in kwargs:
            kwargs['streaming'] = True
        kwargs['custom_get_token_ids'] = custom_get_token_ids

        chat_tong_yi = QwenVLChatModel(
            model=model_name,
            temperature=kwargs.pop('temperature', 0),
            max_tokens=kwargs.pop('max_length', None),
            n=int(kwargs.pop('n', 1)),  # number of generations must be an integer
            presence_penalty=kwargs.pop('presence_p', 0),
            frequency_penalty=kwargs.pop('frequency_p', 0),
            openai_api_key=str(model_credential.get("key")),
            openai_organization=model_credential.get("organization_id"),
            openai_api_base=openai_api_base,
            **kwargs,
        )
        return chat_tong_yi

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        return sum(len(self.custom_get_token_ids(msg.content)) for msg in messages)

    def get_num_tokens(self, text: str) -> int:
        return len(self.custom_get_token_ids(text))
```
Key Changes:

- Removed the exception handling inside these methods since they are expected to succeed based on their functionality.
- Simplified the logic by calling `self.custom_get_token_ids` directly.
- Corrected the handling of optional parameters such as `temperature`, `max_tokens`, and `n`, which are now popped from `kwargs` with sensible defaults (`n` cast to `int`).
- Used `len()` on the returned token-id lists for counting tokens.

These changes enhance readability while ensuring correct function behavior under normal conditions.
```python
            return super().get_num_tokens(text)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
```
I reviewed the code and found a potential issue and an enhancement suggestion:

Potential issue:
- The `azure_chat_openai` library's `get_num_tokens_from_messages` method might not directly support streaming, which is set to `True` in your subclassed method.

Enhancement suggestion:
- Add error handling within the methods that interact with tokenization to catch unexpected exceptions related to token counting. This would prevent the program from crashing when it encounters issues during tokenization.
Here are the changes proposed based on these findings:
```diff
@@ -23,7 +23,7 @@
 from langchain_openai import AzureChatOpenAI
 from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
-def new_instance(model_type, model_name, model_credential: Dict[str, object], **optional_params):
+def new_instance(model_type, model_name, model_credential: Dict[str, object], custom_get_token_ids=None, **optional_params):
```
This change adds a parameter `custom_get_token_ids`, which will be used if available, and updates the function docstring accordingly. Note, however, that it does not address the streaming issue mentioned earlier; make sure streaming is handled properly in both `AzureMessageHandler.send_message()` and `MaxKBBaseModel.__init__()`.
`eab0284` to `40bb5f9` (Compare)
```python
            return super().get_num_tokens(text)
        except Exception as e:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
```
The provided code is well-structured and generally correct. However, there are a few optimizations and improvements you can consider:
### Optimizations
1. **Avoid Repeated Calls to `TokenizerManage`**: Instead of calling `TokenizerManage.get_tokenizer()` inside each method, store the tokenizer instance once in class initialization.
2. **Use More Concise Exception Handling**: In some cases, using a simple exception handler without additional logic might be sufficient. Consider logging errors instead of trying to handle them programmatically within these methods unless they represent critical business logic failures.
3. **Add Type Annotations**: Ensure that all variables and function parameters have appropriate type annotations for better readability and maintainability.
**Improved Code Example:**
```python
from typing import Dict, List, Optional
import os

from langchain_openai import AzureChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage

# Cache the tokenizer at module level instead of fetching it on every call.
TOKENIZER_MANAGE = None


def load_or_init_tokenizer():
    """Load the shared tokenizer once and reuse it afterwards."""
    global TOKENIZER_MANAGE
    if TOKENIZER_MANAGE is None:
        token_path = os.environ.get('YOUR_TOKEN_PATH')
        if token_path:
            # Placeholder: load a custom tokenizer from disk if one is configured.
            TOKENIZER_MANAGE = YourTokenizationManager.from_disk(token_path)
        else:
            # Otherwise fall back to the project-wide tokenizer.
            TOKENIZER_MANAGE = TokenizerManage.get_tokenizer()
    return TOKENIZER_MANAGE


class YourClass:
    def new_instance(self, model_type: str, model_name: str, model_credential: Dict[str, object],
                     api_key: str = "", azure_base_chat_url: Optional[str] = None,
                     base_model_params: Optional[dict] = None, optional_params: Optional[dict] = None):
        try:
            base_model_params = base_model_params or {}
            optional_params = optional_params or {}
            self.azurechat_openai_api = AzureChatOpenAI(
                azure_endpoint=azure_base_chat_url,
                openai_api_key=api_key or str(model_credential.get('api_key')),
                deployment_name=model_name,
                request_timeout=60,
                streaming=True,
                **base_model_params,
                **optional_params,
            )
            return True, ""
        except Exception as e:
            return False, f"Failed creating the instance: {str(e)}"

    def get_num_tokens(self, text: str) -> int:
        try:
            return self.azurechat_openai_api.get_num_tokens(text)
        except Exception:
            tokenizer = load_or_init_tokenizer()
            return len(tokenizer.encode(text))

    def get_num_tokens_from_messages(self, messages: List) -> int:
        try:
            return self.azurechat_openai_api.get_num_tokens_from_messages(messages)
        except Exception:
            tokenizer = load_or_init_tokenizer()
            return sum(len(tokenizer.encode(m.content)) for m in messages)
```