Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

refactor: 查询模型优化 #1358

Merged
merged 2 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion apps/dataset/views/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,6 @@ class Model(APIView):
dynamic_tag=keywords.get('dataset_id'))],
compare=CompareConstants.AND))
def get(self, request: Request, dataset_id: str):
print(dataset_id)
return result.success(
ModelSerializer.Query(
data={'user_id': request.user.id, 'model_type': 'LLM'}).list(
Expand Down
21 changes: 19 additions & 2 deletions apps/setting/serializers/provider_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,26 @@ class Query(serializers.Serializer):

provider = serializers.CharField(required=False, error_messages=ErrMessage.char("供应商"))

permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char("权限"))

create_user = serializers.CharField(required=False, error_messages=ErrMessage.char("创建者"))


def list(self, with_valid):
if with_valid:
self.is_valid(raise_exception=True)
user_id = self.data.get('user_id')
name = self.data.get('name')
model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC')))
create_user = self.data.get('create_user')
if create_user is not None:
# 当前用户能查看自己的模型,包括公开和私有的
if create_user == user_id:
model_query_set = QuerySet(Model).filter(Q(user_id=create_user))
# 当前用户能查看其他人的模型,只能查看公开的
else:
model_query_set = QuerySet(Model).filter((Q(user_id=self.data.get('create_user')) & Q(permission_type='PUBLIC')))
else:
model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC')))
query_params = {}
if name is not None:
query_params['name__contains'] = name
Expand All @@ -90,11 +104,14 @@ def list(self, with_valid):
query_params['model_name'] = self.data.get('model_name')
if self.data.get('provider') is not None:
query_params['provider'] = self.data.get('provider')
if self.data.get('permission_type') is not None:
query_params['permission_type'] = self.data.get('permission_type')


return [
{'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name, 'status': model.status, 'meta': model.meta,
'permission_type': model.permission_type, 'user_id': model.user_id} for model in
'permission_type': model.permission_type, 'user_id': model.user_id, 'username': model.user.username} for model in
model_query_set.filter(**query_params).order_by("-create_time")]

class Edit(serializers.Serializer):
Expand Down
1 change: 1 addition & 0 deletions ui/src/api/type/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ interface Model {
*/
model_type: string
user_id: string
username: string
permission_type: 'PUBLIC' | 'PRIVATE'
/**
* 基础模型
Expand Down
6 changes: 6 additions & 0 deletions ui/src/views/template/component/ModelCard.vue
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@
{{ model.model_name }}</span
>
</li>
<li class="flex mt-12">
<el-text type="info">创建者</el-text>
<span class="ellipsis-1 ml-16" style="height: 20px; width: 70%">
{{ model.username }}</span
>
</li>
</ul>
</div>
<!-- progress -->
Expand Down
65 changes: 56 additions & 9 deletions ui/src/views/template/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,42 @@
<h4>{{ active_provider?.name }}</h4>
<div class="flex-between mt-16 mb-16">
<el-button type="primary" @click="openCreateModel(active_provider)">添加模型</el-button>
<el-input
v-model="model_search_form.name"
@change="list_model"
placeholder="按名称搜索"
prefix-icon="Search"
style="max-width: 240px"
clearable
/>
<div class="flex-between">
<el-select v-model="search_type" style="width: 200px" @change="search_type_change">
<el-option label="创建者" value="create_user" />
<el-option label="权限" value="permission_type" />
<el-option label="模型类型" value="model_type" />
<el-option label="模型名称" value="name" />
</el-select>
<el-input
v-if="search_type === 'name'"
v-model="model_search_form.name"
@change="list_model"
placeholder="按名称搜索"
prefix-icon="Search"
style="max-width: 240px"
clearable
/>
<el-select v-else-if="search_type === 'create_user'" v-model="model_search_form.create_user" @change="list_model"
clearable>
<el-option v-for="u in user_options" :key="u.id" :value="u.id" :label="u.username" />
</el-select>
<el-select v-else-if="search_type === 'permission_type'" v-model="model_search_form.permission_type"
clearable
@change="list_model">
<el-option label="公有" value="PUBLIC" />
<el-option label="私有" value="PRIVATE" />
</el-select>
<el-select v-else-if="search_type === 'model_type'" v-model="model_search_form.model_type"
clearable
@change="list_model">
<el-option label="大语言模型" value="LLM" />
<el-option label="向量模型" value="EMBEDDING" />
<el-option label="重排模型" value="RERANKER" />
<el-option label="语音识别" value="STT" />
<el-option label="语音合成" value="TTS" />
</el-select>
</div>
</div>
</div>
<div class="model-list-height">
Expand Down Expand Up @@ -114,7 +142,14 @@ const allObj = {
const loading = ref<boolean>(false)

const active_provider = ref<Provider>()
const model_search_form = ref<{ name: string }>({ name: '' })
const search_type = ref('name')
const model_search_form = ref<{ name: string, create_user: string, permission_type: string, model_type: string }>({
name: '',
create_user: '',
permission_type: '',
model_type: ''
})
const user_options = ref<any[]>([])
const list_model_loading = ref<boolean>(false)
const provider_list = ref<Array<Provider>>([])

Expand Down Expand Up @@ -150,9 +185,20 @@ const list_model = () => {
const params = active_provider.value?.provider ? { provider: active_provider.value.provider } : {}
ModelApi.getModel({ ...model_search_form.value, ...params }, list_model_loading).then((ok) => {
model_list.value = ok.data
const v = model_list.value.map((m) => ({ id: m.user_id, username: m.username }))
if (user_options.value.length === 0){
user_options.value = Array.from(
new Map(v.map(item => [item.id, item])).values()
)
}
})
}

const search_type_change = () => {
model_search_form.value = { name: '', create_user: '', permission_type: '', model_type: '' }
}


onMounted(() => {
ModelApi.getProvider(loading).then((ok) => {
active_provider.value = allObj
Expand All @@ -173,6 +219,7 @@ onMounted(() => {
.model-list-height {
height: calc(var(--create-dataset-height) - 70px);
}

.model-list-height-left {
height: calc(var(--create-dataset-height));
}
Expand Down
Loading