
Commit 449c7a2

vertex-sdk-bot authored and copybara-github committed
fix: safety return types
PiperOrigin-RevId: 591915953
1 parent 04a32e6 commit 449c7a2

File tree: 2 files changed (+87 -35 lines)

src/index_test.ts (+62 -31)

@@ -20,7 +20,19 @@ import 'jasmine';
 
 import {ChatSession, GenerativeModel, StartChatParams, VertexAI} from './index';
 import * as StreamFunctions from './process_stream';
-import {CountTokensRequest, FinishReason, GenerateContentRequest, GenerateContentResponse, GenerateContentResult, HarmBlockThreshold, HarmCategory, StreamGenerateContentResult,} from './types/content';
+import {
+  CountTokensRequest,
+  FinishReason,
+  GenerateContentRequest,
+  GenerateContentResponse,
+  GenerateContentResult,
+  HarmBlockThreshold,
+  HarmCategory,
+  HarmProbability,
+  SafetyRating,
+  SafetySetting,
+  StreamGenerateContentResult,
+} from './types/content';
 import {constants} from './util';
 
 const PROJECT = 'test_project';
@@ -38,8 +50,8 @@ const TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE = [
       {
         file_data: {
           file_uri: 'gs://test_bucket/test_image.jpeg',
-          mime_type: 'image/jpeg'
-        }
+          mime_type: 'image/jpeg',
+        },
       },
     ],
   },
@@ -55,10 +67,17 @@ const TEST_USER_CHAT_MESSAGE_WITH_INVALID_GCS_FILE = [
   },
 ];
 
-const TEST_SAFETY_RATINGS = [
+const TEST_SAFETY_SETTINGS: SafetySetting[] = [
   {
     category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
-    threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+    threshold: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+  },
+];
+
+const TEST_SAFETY_RATINGS: SafetyRating[] = [
+  {
+    category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+    probability: HarmProbability.NEGLIGIBLE,
   },
 ];
 const TEST_GENERATION_CONFIG = {
@@ -76,13 +95,14 @@ const TEST_CANDIDATES = [
     finishMessage: '',
     safetyRatings: TEST_SAFETY_RATINGS,
     citationMetadata: {
-      citationSources: [{
-        startIndex: 367,
-        endIndex: 491,
-        uri:
-            'https://www.numerade.com/ask/question/why-does-the-uncertainty-principle-make-it-impossible-to-predict-a-trajectory-for-the-clectron-95172/'
-      }]
-    }
+      citationSources: [
+        {
+          startIndex: 367,
+          endIndex: 491,
+          uri: 'https://www.numerade.com/ask/question/why-does-the-uncertainty-principle-make-it-impossible-to-predict-a-trajectory-for-the-clectron-95172/',
+        },
+      ],
+    },
   },
 ];
 const TEST_MODEL_RESPONSE = {
@@ -178,8 +198,9 @@ describe('VertexAI', () => {
      response: Promise.resolve(TEST_MODEL_RESPONSE),
      stream: testGenerator(),
    };
-    spyOn(StreamFunctions, 'processStream')
-        .and.returnValue(expectedStreamResult);
+    spyOn(StreamFunctions, 'processStream').and.returnValue(
+      expectedStreamResult
+    );
    const resp = await model.generateContent(req);
    expect(resp).toEqual(expectedResult);
  });
@@ -190,15 +211,17 @@ describe('VertexAI', () => {
    const req: GenerateContentRequest = {
      contents: TEST_USER_CHAT_MESSAGE_WITH_INVALID_GCS_FILE,
    };
-    await expectAsync(model.generateContent(req)).toBeRejectedWithError(URIError);
+    await expectAsync(model.generateContent(req)).toBeRejectedWithError(
+      URIError
+    );
  });
 });
 
 describe('generateContent', () => {
  it('returns a GenerateContentResponse when passed safety_settings and generation_config', async () => {
    const req: GenerateContentRequest = {
      contents: TEST_USER_CHAT_MESSAGE,
-      safety_settings: TEST_SAFETY_RATINGS,
+      safety_settings: TEST_SAFETY_SETTINGS,
      generation_config: TEST_GENERATION_CONFIG,
    };
    const expectedResult: GenerateContentResult = {
@@ -223,7 +246,8 @@ describe('VertexAI', () => {
      location: LOCATION,
      apiEndpoint: TEST_ENDPOINT_BASE_PATH,
    });
-    vertexaiWithBasePath.preview['tokenInternalPromise'] = Promise.resolve(TEST_TOKEN);
+    vertexaiWithBasePath.preview['tokenInternalPromise'] =
+      Promise.resolve(TEST_TOKEN);
    model = vertexaiWithBasePath.preview.getGenerativeModel({
      model: 'gemini-pro',
    });
@@ -255,7 +279,8 @@ describe('VertexAI', () => {
      project: PROJECT,
      location: LOCATION,
    });
-    vertexaiWithoutBasePath.preview['tokenInternalPromise'] = Promise.resolve(TEST_TOKEN);
+    vertexaiWithoutBasePath.preview['tokenInternalPromise'] =
+      Promise.resolve(TEST_TOKEN);
    model = vertexaiWithoutBasePath.preview.getGenerativeModel({
      model: 'gemini-pro',
    });
@@ -275,8 +300,9 @@ describe('VertexAI', () => {
      expectedStreamResult
    );
    await model.generateContent(req);
-    expect(requestSpy.calls.allArgs()[0][0].toString())
-        .toContain(`${LOCATION}-aiplatform.googleapis.com`);
+    expect(requestSpy.calls.allArgs()[0][0].toString()).toContain(
+      `${LOCATION}-aiplatform.googleapis.com`
+    );
  });
 });
 
@@ -295,11 +321,12 @@ describe('VertexAI', () => {
      stream: testGenerator(),
    };
    const requestSpy = spyOn(global, 'fetch');
-    spyOn(StreamFunctions, 'processStream')
-        .and.returnValue(expectedStreamResult);
+    spyOn(StreamFunctions, 'processStream').and.returnValue(
+      expectedStreamResult
+    );
    await model.generateContent(reqWithEmptyConfigs);
    const requestArgs = requestSpy.calls.allArgs()[0][1];
-    if (typeof requestArgs == 'object' && requestArgs) {
+    if (typeof requestArgs === 'object' && requestArgs) {
      expect(JSON.stringify(requestArgs['body'])).not.toContain('top_k');
    }
  });
@@ -320,11 +347,12 @@ describe('VertexAI', () => {
      stream: testGenerator(),
    };
    const requestSpy = spyOn(global, 'fetch');
-    spyOn(StreamFunctions, 'processStream')
-        .and.returnValue(expectedStreamResult);
+    spyOn(StreamFunctions, 'processStream').and.returnValue(
+      expectedStreamResult
+    );
    await model.generateContent(reqWithEmptyConfigs);
    const requestArgs = requestSpy.calls.allArgs()[0][1];
-    if (typeof requestArgs == 'object' && requestArgs) {
+    if (typeof requestArgs === 'object' && requestArgs) {
      expect(JSON.stringify(requestArgs['body'])).toContain('top_k');
    }
  });
@@ -342,14 +370,17 @@ describe('VertexAI', () => {
      response: Promise.resolve(TEST_MODEL_RESPONSE),
      stream: testGenerator(),
    };
-    spyOn(StreamFunctions, 'processStream')
-        .and.returnValue(expectedStreamResult);
+    spyOn(StreamFunctions, 'processStream').and.returnValue(
+      expectedStreamResult
+    );
    const resp = await model.generateContent(req);
    console.log(resp.response.candidates[0].citationMetadata, 'yoyoyo');
    expect(
-        resp.response.candidates[0].citationMetadata?.citationSources.length)
-        .toEqual(TEST_MODEL_RESPONSE.candidates[0]
-            .citationMetadata.citationSources.length);
+      resp.response.candidates[0].citationMetadata?.citationSources.length
+    ).toEqual(
+      TEST_MODEL_RESPONSE.candidates[0].citationMetadata.citationSources
+        .length
+    );
  });
 });
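The updated fixtures above show the shape of the type change: request-side safety configuration is now typed as SafetySetting[] (a HarmCategory paired with a HarmBlockThreshold), while the ratings the service returns are SafetyRating[]. The following is a minimal usage sketch that mirrors those fixtures; the project, location, and prompt values are placeholders, not part of the commit, and the relative import paths simply follow the test file above.

```ts
// Minimal sketch mirroring the updated test fixtures; the project, location,
// and prompt values are placeholders, not part of this commit.
import {VertexAI} from './index';
import {
  GenerateContentRequest,
  HarmBlockThreshold,
  HarmCategory,
  SafetySetting,
} from './types/content';

const safetySettings: SafetySetting[] = [
  {
    category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
    threshold: HarmBlockThreshold.BLOCK_ONLY_HIGH,
  },
];

async function generateWithSafetySettings(): Promise<void> {
  const vertexai = new VertexAI({
    project: 'test_project',
    location: 'us-central1',
  });
  const model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'});
  const req: GenerateContentRequest = {
    contents: [{role: 'user', parts: [{text: 'Why is the sky blue?'}]}],
    safety_settings: safetySettings,
  };
  const resp = await model.generateContent(req);
  // Each returned candidate carries safetyRatings typed as SafetyRating[].
  console.log(resp.response.candidates[0].safetyRatings);
}
```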

src/types/content.ts (+25 -4)

@@ -18,7 +18,7 @@
 /**
  * Params used to initialize the Vertex SDK
  * @param{string} project - the project name of your Google Cloud project. It is not the numeric project ID.
- * @param{string} location - the location of your project. 
+ * @param{string} location - the location of your project.
  * @param{string} apiEndpoint - Optional. If not specified, a default value will be resolved by SDK.
  */
 export declare interface VertexInit {
@@ -51,7 +51,7 @@ export declare interface CountTokensResponse {
 
 /**
  * Configuration for initializing a model, for example via getGenerativeModel
- * @param {string} model - model name. 
+ * @param {string} model - model name.
  * @example "gemini-pro"
  */
 export declare interface ModelParams extends BaseModelParams {
@@ -85,7 +85,9 @@ export declare interface GenerationConfig {
   top_p?: number;
   top_k?: number;
 }
-
+/**
+ * Harm categories that would cause prompts or candidates to be blocked.
+ */
 export enum HarmCategory {
   HARM_CATEGORY_UNSPECIFIED = 'HARM_CATEGORY_UNSPECIFIED',
   HARM_CATEGORY_HATE_SPEECH = 'HARM_CATEGORY_HATE_SPEECH',
@@ -94,6 +96,9 @@ export enum HarmCategory {
   HARM_CATEGORY_SEXUALLY_EXPLICIT = 'HARM_CATEGORY_SEXUALLY_EXPLICIT',
 }
 
+/**
+ * Threshhold above which a prompt or candidate will be blocked.
+ */
 export enum HarmBlockThreshold {
   // Unspecified harm block threshold.
   HARM_BLOCK_THRESHOLD_UNSPECIFIED = 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
@@ -107,12 +112,28 @@ export enum HarmBlockThreshold {
   BLOCK_NONE = 'BLOCK_NONE',
 }
 
+/**
+ * Probability that a prompt or candidate matches a harm category.
+ */
+export enum HarmProbability {
+  // Probability is unspecified.
+  HARM_PROBABILITY_UNSPECIFIED = 'HARM_PROBABILITY_UNSPECIFIED',
+  // Content has a negligible chance of being unsafe.
+  NEGLIGIBLE = 'NEGLIGIBLE',
+  // Content has a low chance of being unsafe.
+  LOW = 'LOW',
+  // Content has a medium chance of being unsafe.
+  MEDIUM = 'MEDIUM',
+  // Content has a high chance of being unsafe.
+  HIGH = 'HIGH',
+}
+
 /**
  * Safety rating for a piece of content
  */
 export declare interface SafetyRating {
   category: HarmCategory;
-  threshold: HarmBlockThreshold;
+  probability: HarmProbability;
 }
 
 /**
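With these types, the request and response sides of safety handling are distinct: a SafetySetting pairs a HarmCategory with a HarmBlockThreshold on the request, while a SafetyRating returned on a candidate pairs the category with a HarmProbability. As a small illustration, here is a hypothetical helper (not part of the SDK or this commit) that flags candidates whose returned ratings report an elevated probability.

```ts
import {HarmProbability, SafetyRating} from './types/content';

// Hypothetical convenience helper, not part of this commit or the SDK:
// returns true when any returned rating reports at least a MEDIUM probability.
function hasElevatedRisk(ratings: SafetyRating[] = []): boolean {
  const elevated: HarmProbability[] = [
    HarmProbability.MEDIUM,
    HarmProbability.HIGH,
  ];
  return ratings.some(rating => elevated.includes(rating.probability));
}

// Example usage against a generateContent response (shape as in the tests):
//   const ratings = resp.response.candidates[0].safetyRatings;
//   if (hasElevatedRisk(ratings)) { /* treat the candidate as risky */ }
```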
