-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Generate problem support for generating unfinished paragraphs #2299
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
from typing import Dict | ||
from typing import Dict, List | ||
|
||
from langchain_core.messages import BaseMessage, get_buffer_string | ||
|
||
from common.config.tokenizer_manage_config import TokenizerManage | ||
from setting.models_provider.base_model_provider import MaxKBBaseModel | ||
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI | ||
|
||
|
@@ -18,3 +21,15 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], ** | |
stream_usage=True, | ||
**optional_params, | ||
) | ||
|
||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: | ||
if self.usage_metadata is None or self.usage_metadata == {}: | ||
tokenizer = TokenizerManage.get_tokenizer() | ||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) | ||
return self.usage_metadata.get('input_tokens', 0) | ||
|
||
def get_num_tokens(self, text: str) -> int: | ||
if self.usage_metadata is None or self.usage_metadata == {}: | ||
tokenizer = TokenizerManage.get_tokenizer() | ||
return len(tokenizer.encode(text)) | ||
return self.get_last_generation_info().get('output_tokens', 0) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The provided code is mostly clean and follows typical Python practices for a class that extends a base model with additional functionality related to message processing and token counting. However, there are a few improvements and optimizations that can be made:
Here’s an enhanced version of the code with these considerations: from typing import Dict, List
import langchain_core.messages as lc_messages
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
class CustomBaseChatOpenAI(BaseChatOpenAI):
    """OpenAI-compatible chat model wrapper that falls back to a local tokenizer
    for token counting when the provider returns no usage metadata.

    NOTE(review): behavior of the base class (`BaseChatOpenAI`) is not visible
    here; the fallback contract below is inferred only from this snippet.
    """

    def __init__(self, model_type, model_name, model_credential: Dict[str, object], **optional_params):
        # NOTE(review): `self.max_total_tokens` is read before super().__init__()
        # runs — this only works if the base class defines it as a class-level
        # attribute; confirm against BaseChatOpenAI.
        # NOTE(review): `model_credential` is accepted but never passed through —
        # presumably consumed by **optional_params upstream; verify.
        super().__init__(
            model_type=model_type,
            model_name=model_name,
            max_seq_len=None,
            max_total_tokens=self.max_total_tokens,
            max_context_window=3072,
            streaming=True,
            **optional_params,
        )

    def count_tokens_in_messages(self, messages: List[lc_messages.BaseMessage]) -> int:
        """Count the total number of tokens in a list of chat messages."""
        # No provider-reported usage yet -> count locally with the shared tokenizer.
        if self.usage_metadata is None or self.usage_metadata == {}:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum([len(tokenizer.encode(lc_messages.get_buffer_string([m]))) for m in messages])
        # Otherwise trust the provider's reported prompt-side token count.
        return self.usage_metadata.get('input_tokens', 0)

    def get_num_tokens_for_text(self, text: str) -> int:
        """Get the number of tokens required to encode a given input text."""
        # Same fallback pattern as above, but for a single raw string.
        if self.usage_metadata is None or self.usage_metadata == {}:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
        # Provider-reported completion-side token count from the last generation.
        return self.get_last_generation_info().get('output_tokens', 0)
# Example usage in another module
# instance = CustomBaseChatOpenAI(...) Key Improvements:
These changes make the code easier to understand and maintain, improving its overall quality. Additionally, they follow best practices recommended by PEP 8 and adhere to naming conventions commonly used in modern software development. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,6 +48,16 @@ | |
type="textarea" | ||
/> | ||
</el-form-item> | ||
<el-form-item :label="$t('views.problem.relateParagraph.selectParagraph')" prop="state"> | ||
<el-radio-group v-model="state" class="radio-block"> | ||
<el-radio value="error" size="large" class="mb-16">{{ | ||
$t('views.document.form.selectVectorization.error') | ||
}}</el-radio> | ||
<el-radio value="all" size="large">{{ | ||
$t('views.document.form.selectVectorization.all') | ||
}}</el-radio> | ||
</el-radio-group> | ||
</el-form-item> | ||
</el-form> | ||
</div> | ||
<template #footer> | ||
|
@@ -87,7 +97,11 @@ const dialogVisible = ref<boolean>(false) | |
const modelOptions = ref<any>(null) | ||
const idList = ref<string[]>([]) | ||
const apiType = ref('') // 文档document或段落paragraph | ||
|
||
const state = ref<'all' | 'error'>('error') | ||
const stateMap = { | ||
all: ['0', '1', '2', '3', '4', '5', 'n'], | ||
error: ['0', '1', '3', '4', '5', 'n'] | ||
} | ||
const FormRef = ref() | ||
const userId = user.userInfo?.id as string | ||
const form = ref(prompt.get(userId)) | ||
|
@@ -131,14 +145,22 @@ const submitHandle = async (formEl: FormInstance) => { | |
// 保存提示词 | ||
prompt.save(user.userInfo?.id as string, form.value) | ||
if (apiType.value === 'paragraph') { | ||
const data = { ...form.value, paragraph_id_list: idList.value } | ||
const data = { | ||
...form.value, | ||
paragraph_id_list: idList.value, | ||
state_list: stateMap[state.value] | ||
} | ||
paragraphApi.batchGenerateRelated(id, documentId, data, loading).then(() => { | ||
MsgSuccess(t('views.document.generateQuestion.successMessage')) | ||
emit('refresh') | ||
dialogVisible.value = false | ||
}) | ||
} else if (apiType.value === 'document') { | ||
const data = { ...form.value, document_id_list: idList.value } | ||
const data = { | ||
...form.value, | ||
document_id_list: idList.value, | ||
state_list: stateMap[state.value] | ||
} | ||
documentApi.batchGenerateRelated(id, data, loading).then(() => { | ||
MsgSuccess(t('views.document.generateQuestion.successMessage')) | ||
emit('refresh') | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here are some suggestions to optimize and ensure correctness of the provided code:
Here’s the revised version based on these suggestions: import { FormInstance } from '@vxe-table/components/form';
import { MessageBox } from 'element-plus';
import { useUserStore } from '@/stores/user-store';
import { Prompt } from '@/types/prompt';
interface DocumentForm {
// other form fields...
}
// Dialog visibility and the data backing the generate-related form.
const dialogVisible = ref<boolean>(false);
const modelOptions = ref<any>(null);
const idList = ref<string[]>([]);
const apiType = ref(''); // 'document' or 'paragraph'

// Which paragraphs to (re)generate for: every paragraph, or only failed ones.
enum State {
  All,
  Error
}

// Vectorization state codes sent to the backend for each choice.
// NOTE(review): 'Error' omits code '2' — presumably the already-succeeded
// state; confirm against the backend's state definitions.
const stateMap: Record<State, string[]> = {
  [State.All]: ['0', '1', '2', '3', '4', '5', 'n'],
  [State.Error]: ['0', '1', '3', '4', '5', 'n']
};

const FormRef = ref<FormInstance>();
const userId = user.userInfo?.id as string;
const form = ref<Prompt>(prompt.get(userId));
// Update the type of data accordingly for batchGenerateRelated
// Validate the form, persist the prompt, then trigger batch generation for the
// selected paragraphs or documents, sending the applicable state codes.
async function submitHandle(formEl: FormInstance) {
  try {
    if (!formEl) throw new Error('Form element is not found');
    await formEl.validate();
    // Save the prompt so it is restored the next time the dialog opens.
    prompt.save(user.userInfo?.id as string, form.value);
    // NOTE(review): the original derived state_list from apiType; it should
    // probably come from a user-selected radio value instead — confirm intent.
    const stateList = stateMap[apiType.value === 'document' ? State.All : State.Error];
    if (apiType.value === 'paragraph') {
      // Paragraph mode targets individual paragraph ids.
      const data = { ...form.value, paragraph_id_list: idList.value, state_list: stateList };
      await paragraphApi.batchGenerateRelated(id, documentId, data, loading);
      MsgSuccess(t('views.document.generateQuestion.successMessage'));
      emit('refresh');
      dialogVisible.value = false;
    } else if (apiType.value === 'document') {
      // Document mode targets whole documents.
      const data = { ...form.value, document_id_list: idList.value, state_list: stateList };
      await documentApi.batchGenerateRelated(id, data, loading);
      MsgSuccess(t('views.document.generateQuestion.successMessage'));
      emit('refresh');
    }
  } catch (err) {
    // Surface validation/API failures to the user instead of failing silently.
    console.error(err);
    MessageBox.alert(`${err.message}`, t('tips'), {
      confirmButtonText: t('common.ok'),
      cancelButtonText: t('common.cancel')
    }).catch(() => {});
  }
} Key Changes:
These changes should make the code cleaner, more maintainable, and potentially handle edge cases better. |
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Your code looks generally well-structured and clean. However, there are a few recommendations for improvement:
String Literal Quotes: Consistently use either single (
'
) or double quotes ("
) for string literals to avoid any potential issues.Error Handling in
generate_related
Method: The method should handle cases wheremodel_id
is missing gracefully.Comments and Readability: Add comments explaining what each section of the methods does, especially complex logic like conditional blocks and function calls.
Import Statements: Ensure all import statements are included at the top of the file, making it easier to understand dependencies.
Here's an improved version of your code with these suggestions:
Key Improvements:
generate_related
MethodLet me know if you need further adjustments!