Skip to content

Commit

Permalink
Fix prompt usage when sharing the same session id (by stripping out the preamble)
Browse files Browse the repository at this point in the history
  • Loading branch information
tjake committed Dec 27, 2024
1 parent 190464c commit 7eda3aa
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ public void run() {

UUID session = UUID.randomUUID();
PromptSupport promptSupport = m.promptSupport().get();
PromptSupport.Builder builder = promptSupport.builder();
PrintWriter out = System.console().writer();

out.println("\nChatting with " + modelName + "...\n");
Expand All @@ -82,7 +83,6 @@ public void run() {
break;
}

PromptSupport.Builder builder = promptSupport.builder();
if (first && systemPrompt != null) {
builder.addSystemMessage(systemPrompt);
}
Expand All @@ -97,6 +97,9 @@ public void run() {
makeOutHandler()
);

// New prompt builder and strip out the preamble since we're continuing the conversation
builder = promptSupport.builder().stripPreamble();

out.println(
"\n\n"
+ statsColor.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,16 @@ public AbstractTensor batchForward(
KvBufferCache.KvBuffer kvbuf,
Optional<Consumer<List<AbstractTensor>>> tensorReducer
) {
AbstractTensor embedding = embedInput.batchInputsToEmbeddings(token_ids, startPos);
return forward(embedding, startPos, kvbuf, tensorReducer);
AbstractTensor embedding = null;

//Batch prompt into groups of 1024
for (int i = 0; i < token_ids.length; i += 1024) {
int[] batch = Arrays.copyOfRange(token_ids, i, Math.min(token_ids.length, i + 1024));
embedding = embedInput.batchInputsToEmbeddings(batch, startPos + i);
embedding = forward(embedding, startPos + i, kvbuf, tensorReducer);
}

return embedding;
}

public AbstractTensor forward(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ public static class Builder {
private boolean addGenerationPrompt = true;

private List<Message> messages = new ArrayList<>(2);
private boolean stripPreamble = false;

private Builder(TokenizerModel m) {
this.m = m;
Expand Down Expand Up @@ -230,6 +231,11 @@ public Builder addAssistantMessage(String content) {
return this;
}

/**
 * Marks this builder to strip the chat template's static preamble from the
 * rendered prompt (used when continuing an existing session so the preamble
 * is not re-sent).
 *
 * @return this builder, for chaining
 */
public Builder stripPreamble() {
    this.stripPreamble = true;
    return this;
}

/**
 * Renders the prompt with no tool definitions supplied.
 *
 * @return the rendered prompt context
 */
public PromptContext build() {
    return this.build(Optional.empty());
}
Expand Down Expand Up @@ -259,8 +265,29 @@ private PromptContext build(Optional<List<Tool>> optionalTools) {
"This model does not support tools, but tools are specified"
);

Map<String, Object> args = new HashMap<>();

String preamble = "";
if (stripPreamble) {
Map<String, Object> args = new HashMap<>();
args.putAll(
Map.of(
"messages",
Map.of(),
"add_generation_prompt",
false,
"eos_token",
m.eosToken(),
"bos_token",
""
)
); // We add the BOS ourselves
optionalTools.ifPresent(tools -> args.put("tools", tools));

RenderResult r = jinjava.renderForResult(template, args);
preamble = r.getOutput();
}

Map<String, Object> args = new HashMap<>();
args.putAll(
Map.of(
"messages",
Expand All @@ -280,7 +307,8 @@ private PromptContext build(Optional<List<Tool>> optionalTools) {

if (r.hasErrors()) logger.debug("Prompt template errors: " + r.getErrors());

return new PromptContext(r.getOutput(), optionalTools);
String output = r.getOutput();
return new PromptContext(output.substring(preamble.length()), optionalTools);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NativeTensorOperations implements TensorOperations {
private static final Logger logger = LoggerFactory.getLogger(NativeTensorOperations.class);
public class NativeSimdTensorOperations implements TensorOperations {
private static final Logger logger = LoggerFactory.getLogger(NativeSimdTensorOperations.class);

static {
if (!JarSupport.maybeLoadLibrary()) System.loadLibrary("jlama");
if (!JarSupport.maybeLoadLibrary("jlama")) System.loadLibrary("jlama");
}

public static final int HAS_F16C = NativeSimd.HAS_F16C();
Expand All @@ -52,7 +52,7 @@ public class NativeTensorOperations implements TensorOperations {

final int flags;

public NativeTensorOperations() {
public NativeSimdTensorOperations() {
int f = 0;

if (RuntimeSupport.isLinux()) f |= HAS_F16C;
Expand All @@ -63,7 +63,7 @@ public NativeTensorOperations() {
checkLib();
}

NativeTensorOperations(int flags) {
// Package-private constructor: callers (e.g. tests) supply an explicit
// capability flag set instead of probing the host CPU.
NativeSimdTensorOperations(int capabilityFlags) {
    this.flags = capabilityFlags;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@
public class JarSupport {
private static final Logger logger = LoggerFactory.getLogger(JarSupport.class);

public static boolean maybeLoadLibrary() {
public static boolean maybeLoadLibrary(String libname) {
String ext = RuntimeSupport.isMac() ? ".dylib" : RuntimeSupport.isWin() ? ".dll" : ".so";
URL lib = JarSupport.class.getClassLoader().getResource("META-INF/native/lib/libjlama" + ext);
URL lib = JarSupport.class.getClassLoader().getResource("META-INF/native/lib/lib" + libname + ext);

if (lib != null) {
try {
final File libpath = Files.createTempDirectory("jlama").toFile();
libpath.deleteOnExit(); // just in case

File libfile = Paths.get(libpath.getAbsolutePath(), "libjlama" + ext).toFile();
File libfile = Paths.get(libpath.getAbsolutePath(), "lib" + libname + ext).toFile();
libfile.deleteOnExit(); // just in case

final InputStream in = lib.openStream();
Expand All @@ -53,10 +53,10 @@ public static boolean maybeLoadLibrary() {
out.close();
in.close();
System.load(libfile.getAbsolutePath());
logger.debug("Loaded jlama-native library: {}", libfile.getAbsolutePath());
logger.debug("Loaded {}-native library: {}", libname, libfile.getAbsolutePath());
return true;
} catch (IOException e) {
logger.warn("Error loading jlama-native library");
logger.warn("Error loading {}-native library", libname);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ Object createChatCompletion(@RequestHeader Map<String, String> headers, @Valid @
if (request.getStream() != null && request.getStream()) {
SseEmitter emitter = new SseEmitter(-1L);
CompletableFuture.supplyAsync(
() -> model.generate(sessionId, builder.build(), temperature, maxTokens, (t, f) -> CompletableFuture.supplyAsync(() -> {
() -> model. generate(sessionId, builder.build(), temperature, maxTokens, (t, f) -> CompletableFuture.supplyAsync(() -> {
try {
emitter.send(
new CreateChatCompletionStreamResponse().id(sessionId.toString())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ public void testPromptSupportWithTools() {
@Test
public void testMistralTools() {

String modelPrefix = "../models/Mistral-7B-Instruct-v0.3";
String modelPrefix = "../models/tjake_Mistral-7B-Instruct-v0.3-JQ4";
Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));

Tokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
Expand Down Expand Up @@ -420,4 +420,51 @@ public void testToolParse() throws JsonProcessingException {

Assert.assertEquals(2, toolCalls.size());
}

// Verifies stripPreamble(): a follow-up prompt built with stripPreamble()
// should contain only the new conversation turns, without the system/tools
// preamble the template normally prepends.
@Test
public void testPromptBuilderSession() {
String modelPrefix = "../models/Qwen_Qwen2.5-0.5B-Instruct-JQ4";
// Skip (not fail) when the local model checkout is absent.
Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));

Tokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
PromptSupport.Builder builder = tokenizer.promptSupport().get().builder();
builder.addSystemMessage("You always respond as a pirate");
builder.addUserMessage("What is the weather in paris right now?");
builder.addGenerationPrompt(true);

// One tool definition so the template renders its "# Tools" preamble section.
Tool t = Tool.from(
Function.builder()
.name("get_current_temperature")
.description("Simulates getting the current temperature at a location.")
.addParameter("location", "string", "The location to get the temperature for, in the format \"City, Country\".", true)
.addParameter("unit", "string", "The unit to return the temperature in (e.g., \"celsius\", \"fahrenheit\").", true)
.build()
);

// First build: full prompt, including the system message and tool preamble.
PromptContext prompt = builder.build(t);
Assert.assertEquals(
"<|im_start|>system\n" + "You always respond as a pirate\n"
+ "\n"
+ "# Tools\n"
+ "\n"
+ "You may call one or more functions to assist with the user query.\n"
+ "\n"
+ "You are provided with function signatures within <tools></tools> XML tags:\n"
+ "<tools>\n"
+ "{\"type\": \"function\", \"function\": {\"name\": \"get_current_temperature\", \"description\": \"Simulates getting the current temperature at a location.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"location\": {\"type\": \"string\", \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\".\"}, \"unit\": {\"type\": \"string\", \"description\": \"The unit to return the temperature in (e.g., \\\"celsius\\\", \\\"fahrenheit\\\").\"}}, \"required\": [\"location\", \"unit\"]}}}\n"
+ "</tools>\n"
+ "\n"
+ "For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n"
+ "<tool_call>\n"
+ "{\"name\": <function-name>, \"arguments\": <args-json-object>}\n"
+ "</tool_call><|im_end|>\n"
+ "<|im_start|>user\n"
+ "What is the weather in paris right now?<|im_end|>\n"
+ "<|im_start|>assistant\n",
prompt.getPrompt());

// Second build with stripPreamble(): only the user turn and generation
// prompt remain — no system message, no tools section.
prompt = tokenizer.promptSupport().get().builder().addUserMessage("This is a test").stripPreamble().build();
Assert.assertEquals(
"<|im_start|>user\n" + "This is a test<|im_end|>\n" + "<|im_start|>assistant\n", prompt.getPrompt());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
package com.github.tjake.jlama.tensor.operations;

import static com.github.tjake.jlama.tensor.operations.NativeTensorOperations.*;
import static com.github.tjake.jlama.tensor.operations.NativeSimdTensorOperations.*;

import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.safetensors.DType;
Expand Down Expand Up @@ -59,19 +59,19 @@ public static void init() {
opTypes.add(new PanamaTensorOperations(MachineSpec.Type.AVX_256));
opTypes.add(new PanamaTensorOperations(MachineSpec.Type.ARM_128));

if (globalOps instanceof NativeTensorOperations) {
opTypes.add(new NativeTensorOperations());
opTypes.add(new NativeTensorOperations(0));
if (globalOps instanceof NativeSimdTensorOperations) {
opTypes.add(new NativeSimdTensorOperations());
opTypes.add(new NativeSimdTensorOperations(0));

if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) opTypes.add(new NativeTensorOperations(HAS_AVX2));
if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) opTypes.add(new NativeSimdTensorOperations(HAS_AVX2));

if (RuntimeSupport.isLinux() || RuntimeSupport.isWin()) {
opTypes.add(new NativeTensorOperations(HAS_F16C));
if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) opTypes.add(new NativeTensorOperations(HAS_F16C | HAS_AVX2));
opTypes.add(new NativeSimdTensorOperations(HAS_F16C));
if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) opTypes.add(new NativeSimdTensorOperations(HAS_F16C | HAS_AVX2));
}

if (RuntimeSupport.isArm()) {
opTypes.add(new NativeTensorOperations(MachineSpec.Type.ARM_128.ctag));
opTypes.add(new NativeSimdTensorOperations(MachineSpec.Type.ARM_128.ctag));
}
}

Expand Down Expand Up @@ -198,7 +198,7 @@ public void testSplitDotProduct() {

@Test
public void testNativeDotProduct() {
Assume.assumeTrue(globalOps instanceof NativeTensorOperations);
Assume.assumeTrue(globalOps instanceof NativeSimdTensorOperations);
AbstractTensor a = makeTensor(SIZE);
AbstractTensor b = makeTensor(SIZE);

Expand Down Expand Up @@ -476,7 +476,7 @@ public void testBatchDotProductWithResultOffset() {
@Test
public void testNativeBatchDotProduct() {
// M == BATCH, N == ROWS, K == SIZE
Assume.assumeTrue(globalOps instanceof NativeTensorOperations);
Assume.assumeTrue(globalOps instanceof NativeSimdTensorOperations);

FloatBufferTensor c = new FloatBufferTensor(BATCH, ROWS);
FloatBufferTensor c1 = new FloatBufferTensor(BATCH, ROWS);
Expand Down Expand Up @@ -512,7 +512,7 @@ public void testNativeBatchDotProduct() {
@Test
public void testNativeBatchDotProductWithOffsets() {
// M == BATCH, N == ROWS, K == SIZE
Assume.assumeTrue(globalOps instanceof NativeTensorOperations);
Assume.assumeTrue(globalOps instanceof NativeSimdTensorOperations);

FloatBufferTensor c = new FloatBufferTensor(BATCH, ROWS);
FloatBufferTensor c1 = new FloatBufferTensor(BATCH, ROWS);
Expand Down Expand Up @@ -548,7 +548,7 @@ public void testNativeBatchDotProductWithOffsets() {
@Test
public void testNativeDotProductFast() {
// M == BATCH, N == ROWS, K == SIZE
Assume.assumeTrue(globalOps instanceof NativeTensorOperations);
Assume.assumeTrue(globalOps instanceof NativeSimdTensorOperations);

FloatBufferTensor c = new FloatBufferTensor(1, SIZE);
FloatBufferTensor c1 = new FloatBufferTensor(1, SIZE);
Expand Down

0 comments on commit 7eda3aa

Please sign in to comment.