Skip to content

Commit 400a5a1

Browse files
committed
server : do not normalize embeddings when there is no pooling
ggml-ci
1 parent c63d869 commit 400a5a1

File tree

6 files changed

+20
-6
lines changed

6 files changed

+20
-6
lines changed

common/common.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1780,7 +1780,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm)
17801780
break;
17811781
case 0: // max absolute
17821782
for (int i = 0; i < n; i++) {
1783-
if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
1783+
if (sum < std::abs(inp[i])) {
1784+
sum = std::abs(inp[i]);
1785+
}
17841786
}
17851787
sum /= 32760.0; // make an int16 range
17861788
break;

common/common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
596596
// Embedding utils
597597
//
598598

599-
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
599+
// TODO: replace embd_norm with an enum
600+
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
600601

601602
float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
602603

examples/gritlm/gritlm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
7575
}
7676

7777
std::vector<float> emb_norm(emb_unorm.size());
78-
common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
78+
common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2);
7979
result.push_back(emb_norm);
8080

8181
#ifdef GRIT_DEBUG

examples/retrieval/retrieval.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
107107
}
108108

109109
float * out = output + batch.seq_id[i][0] * n_embd;
110-
common_embd_normalize(embd, out, n_embd);
110+
common_embd_normalize(embd, out, n_embd, 2);
111111
}
112112
}
113113

examples/server/server.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2049,8 +2049,14 @@ struct server_context {
20492049
continue;
20502050
}
20512051

2052-
common_embd_normalize(embd, embd_res.data(), n_embd);
2053-
res->embedding.push_back(embd_res);
2052+
// normalize only when there is pooling
2053+
// TODO: configurable
2054+
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
2055+
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
2056+
res->embedding.push_back(embd_res);
2057+
} else {
2058+
res->embedding.push_back({ embd, embd + n_embd });
2059+
}
20542060
}
20552061

20562062
SLT_DBG(slot, "%s", "sending embeddings\n");

examples/server/tests/unit/test_embedding.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ def test_embedding_pooling_none():
5858
assert 'embedding' in res.body[0]
5959
assert len(res.body[0]['embedding']) == 3
6060

61+
# make sure embedding vector is not normalized
62+
for x in res.body[0]['embedding']:
63+
assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
64+
6165

6266
def test_embedding_pooling_none_oai():
6367
global server
@@ -66,6 +70,7 @@ def test_embedding_pooling_none_oai():
6670
res = server.make_request("POST", "/v1/embeddings", data={
6771
"input": "hello hello hello",
6872
})
73+
6974
# /v1/embeddings does not support pooling type 'none'
7075
assert res.status_code == 400
7176

0 commit comments

Comments
 (0)