Skip to content

Commit 576c1b1

Browse files
authored
VertexAI: add test cases for countTokens() (#8317)
1 parent a90255a commit 576c1b1

File tree

5 files changed

+150
-0
lines changed

5 files changed

+150
-0
lines changed

Diff for: packages/vertexai/src/methods/count-tokens.test.ts

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/**
2+
* @license
3+
* Copyright 2024 Google LLC
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
import { expect, use } from 'chai';
19+
import { match, restore, stub } from 'sinon';
20+
import sinonChai from 'sinon-chai';
21+
import chaiAsPromised from 'chai-as-promised';
22+
import { getMockResponse } from '../../test-utils/mock-response';
23+
import * as request from '../requests/request';
24+
import { countTokens } from './count-tokens';
25+
import { CountTokensRequest } from '../types';
26+
import { ApiSettings } from '../types/internal';
27+
import { Task } from '../requests/request';
28+
29+
use(sinonChai);
30+
use(chaiAsPromised);
31+
32+
const fakeApiSettings: ApiSettings = {
33+
apiKey: 'key',
34+
project: 'my-project',
35+
location: 'us-central1'
36+
};
37+
38+
const fakeRequestParams: CountTokensRequest = {
39+
contents: [{ parts: [{ text: 'hello' }], role: 'user' }]
40+
};
41+
42+
describe('countTokens()', () => {
43+
afterEach(() => {
44+
restore();
45+
});
46+
it('total tokens', async () => {
47+
const mockResponse = getMockResponse(
48+
'count-tokens-success-total-tokens.json'
49+
);
50+
const makeRequestStub = stub(request, 'makeRequest').resolves(
51+
mockResponse as Response
52+
);
53+
const result = await countTokens(
54+
fakeApiSettings,
55+
'model',
56+
fakeRequestParams
57+
);
58+
expect(result.totalTokens).to.equal(6);
59+
expect(result.totalBillableCharacters).to.equal(16);
60+
expect(makeRequestStub).to.be.calledWith(
61+
'model',
62+
Task.COUNT_TOKENS,
63+
fakeApiSettings,
64+
false,
65+
match((value: string) => {
66+
return value.includes('contents');
67+
}),
68+
undefined
69+
);
70+
});
71+
it('total tokens no billable characters', async () => {
72+
const mockResponse = getMockResponse(
73+
'count-tokens-success-no-billable-characters.json'
74+
);
75+
const makeRequestStub = stub(request, 'makeRequest').resolves(
76+
mockResponse as Response
77+
);
78+
const result = await countTokens(
79+
fakeApiSettings,
80+
'model',
81+
fakeRequestParams
82+
);
83+
expect(result.totalTokens).to.equal(258);
84+
expect(result).to.not.have.property('totalBillableCharacters');
85+
expect(makeRequestStub).to.be.calledWith(
86+
'model',
87+
Task.COUNT_TOKENS,
88+
fakeApiSettings,
89+
false,
90+
match((value: string) => {
91+
return value.includes('contents');
92+
}),
93+
undefined
94+
);
95+
});
96+
it('model not found', async () => {
97+
const mockResponse = getMockResponse(
98+
'count-tokens-failure-model-not-found.json'
99+
);
100+
const mockFetch = stub(globalThis, 'fetch').resolves({
101+
ok: false,
102+
status: 404,
103+
json: mockResponse.json
104+
} as Response);
105+
await expect(
106+
countTokens(fakeApiSettings, 'model', fakeRequestParams)
107+
).to.be.rejectedWith(/404.*not found/);
108+
expect(mockFetch).to.be.called;
109+
});
110+
});

Diff for: packages/vertexai/src/models/generative-model.test.ts

+20
Original file line numberDiff line numberDiff line change
@@ -262,4 +262,24 @@ describe('GenerativeModel', () => {
262262
);
263263
restore();
264264
});
265+
it('calls countTokens', async () => {
266+
const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' });
267+
const mockResponse = getMockResponse(
268+
'count-tokens-success-total-tokens.json'
269+
);
270+
const makeRequestStub = stub(request, 'makeRequest').resolves(
271+
mockResponse as Response
272+
);
273+
await genModel.countTokens('hello');
274+
expect(makeRequestStub).to.be.calledWith(
275+
'publishers/google/models/my-model',
276+
request.Task.COUNT_TOKENS,
277+
match.any,
278+
false,
279+
match((value: string) => {
280+
return value.includes('hello');
281+
})
282+
);
283+
restore();
284+
});
265285
});
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"error": {
3+
"code": 404,
4+
"message": "models/test-model-name is not found for API version v1beta, or is not supported for countTokens. Call ListModels to see the list of available models and their supported methods.",
5+
"status": "NOT_FOUND",
6+
"details": [
7+
{
8+
"@type": "type.googleapis.com/google.rpc.DebugInfo",
9+
"detail": "[ORIGINAL ERROR] generic::not_found: models/test-model-name is not found for API version v1beta, or is not supported for countTokens. Call ListModels to see the list of available models and their supported methods. [google.rpc.error_details_ext] { message: \"models/test-model-name is not found for API version v1beta, or is not supported for countTokens. Call ListModels to see the list of available models and their supported methods.\" }"
10+
}
11+
]
12+
}
13+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"totalTokens": 258
3+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"totalTokens": 6,
3+
"totalBillableCharacters": 16
4+
}

0 commit comments

Comments
 (0)