Skip to content

Commit

Permalink
Update llama API & parameters
Browse files · Browse the repository at this point in the history
  • Loading branch information
eoctet committed Aug 17, 2024
1 parent b18f8bc commit 84142e2
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 21 deletions.
26 changes: 14 additions & 12 deletions llama-java-core/llamajava/llamajava.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ static jfieldID FIELD_THREADS;
static jfieldID FIELD_THREADS_BATCH;
static jfieldID FIELD_ROPE_SCALING_TYPE;
static jfieldID FIELD_POOLING_TYPE;
static jfieldID FIELD_ATTENTION_TYPE;
static jfieldID FIELD_YARN_EXT_FACTOR;
static jfieldID FIELD_YARN_ATTN_FACTOR;
static jfieldID FIELD_YARN_BETA_FAST;
Expand Down Expand Up @@ -219,6 +220,7 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
FIELD_THREADS_BATCH = env->GetFieldID(LLAMA_CONTEXT_PARAMS_CLASS, "threadsBatch", "I");
FIELD_ROPE_SCALING_TYPE = env->GetFieldID(LLAMA_CONTEXT_PARAMS_CLASS, "ropeScalingType", "I");
FIELD_POOLING_TYPE = env->GetFieldID(LLAMA_CONTEXT_PARAMS_CLASS, "poolingType", "I");
FIELD_ATTENTION_TYPE = env->GetFieldID(LLAMA_CONTEXT_PARAMS_CLASS, "attentionType", "I");
FIELD_YARN_EXT_FACTOR = env->GetFieldID(LLAMA_CONTEXT_PARAMS_CLASS, "yarnExtFactor", "F");
FIELD_YARN_ATTN_FACTOR = env->GetFieldID(LLAMA_CONTEXT_PARAMS_CLASS, "yarnAttnFactor", "F");
FIELD_YARN_BETA_FAST = env->GetFieldID(LLAMA_CONTEXT_PARAMS_CLASS, "yarnBetaFast", "F");
Expand Down Expand Up @@ -351,6 +353,7 @@ JNIEXPORT jobject JNICALL Java_chat_octet_model_LlamaService_getLlamaContextDefa
env->SetIntField(llama_context_params, FIELD_THREADS_BATCH, defaults.n_threads_batch);
env->SetIntField(llama_context_params, FIELD_ROPE_SCALING_TYPE, defaults.rope_scaling_type);
env->SetIntField(llama_context_params, FIELD_POOLING_TYPE, defaults.pooling_type);
env->SetIntField(llama_context_params, FIELD_ATTENTION_TYPE, defaults.attention_type);
env->SetFloatField(llama_context_params, FIELD_YARN_EXT_FACTOR, defaults.yarn_ext_factor);
env->SetFloatField(llama_context_params, FIELD_YARN_ATTN_FACTOR, defaults.yarn_attn_factor);
env->SetFloatField(llama_context_params, FIELD_YARN_BETA_FAST, defaults.yarn_beta_fast);
Expand Down Expand Up @@ -468,6 +471,7 @@ JNIEXPORT void JNICALL Java_chat_octet_model_LlamaService_loadLlamaModelFromFile
/*.n_threads_batch =*/ (uint32_t) env->GetIntField(jllama_context_params, FIELD_THREADS_BATCH),
/*.rope_scaling_type =*/ static_cast<enum llama_rope_scaling_type>(env->GetIntField(jllama_context_params, FIELD_ROPE_SCALING_TYPE)),
/*.pooling_type =*/ static_cast<enum llama_pooling_type>(env->GetIntField(jllama_context_params, FIELD_POOLING_TYPE)),
/*.attention_type =*/ static_cast<enum llama_attention_type>(env->GetIntField(jllama_context_params, FIELD_ATTENTION_TYPE)),
/*.rope_freq_base =*/ env->GetFloatField(jllama_context_params, FIELD_ROPE_FREQ_BASE),
/*.rope_freq_scale =*/ env->GetFloatField(jllama_context_params, FIELD_ROPE_FREQ_SCALE),
/*.yarn_ext_factor =*/ env->GetFloatField(jllama_context_params, FIELD_YARN_EXT_FACTOR),
Expand Down Expand Up @@ -591,14 +595,12 @@ JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_loadLoraModelFromFile
jint threads) {
UNUSED(thisClass);
if (Check_Context_Is_Null(env)) return -1;
int status = llama_model_apply_lora_from_file(main_ctx->model,
env->GetStringUTFChars(lora_path, JNI_FALSE),
scale,
env->GetStringUTFChars(base_model_path, JNI_FALSE),
threads
);
JLOG_DEBUG("Successfully loaded lora model, status: %d.", status);
return status;
llama_lora_adapter* adapter = llama_lora_adapter_init(main_ctx->model, env->GetStringUTFChars(lora_path, JNI_FALSE));
if (adapter != nullptr) {
llama_lora_adapter_set(main_ctx->llama_ctx, adapter, scale);
return 0;
}
return -1;
}

/*
Expand Down Expand Up @@ -672,12 +674,12 @@ JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_tokenize
* Method: tokenToPiece
*/
JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_tokenToPiece
        (JNIEnv *env, jclass thisClass, jint token, jbyteArray buf, jint buffer_length, jint lstrip_length, jboolean special) {
    UNUSED(thisClass);
    if (Check_Context_Is_Null(env)) return -1;

    // Scratch buffer owned by this function. It must NOT be passed to
    // ReleaseByteArrayElements: that call is only valid for pointers obtained
    // from GetByteArrayElements. The original code did exactly that, which is
    // undefined behavior, leaked the heap allocation, and never copied the
    // rendered piece back into the Java array.
    jbyte *buffer = new jbyte[buffer_length];
    int size = llama_token_to_piece(main_ctx->model, token, (char *) buffer, buffer_length, lstrip_length, To_CBool(special));
    if (size > 0) {
        // Copy the rendered bytes into the caller-supplied Java byte array.
        env->SetByteArrayRegion(buf, 0, size, buffer);
    }
    delete[] buffer;
    // size: number of bytes written; negative on failure (per llama.cpp API),
    // -1 when the native context is not initialized.
    return size;
}
Expand Down Expand Up @@ -785,7 +787,7 @@ JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_sampling
}

if (main_ctx->grammar != nullptr) {
llama_sample_grammar(main_ctx->llama_ctx, &candidates_p, main_ctx->grammar);
llama_grammar_sample(main_ctx->grammar, main_ctx->llama_ctx, &candidates_p);
}

llama_token token;
Expand Down Expand Up @@ -819,7 +821,7 @@ JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_sampling
}

if (main_ctx->grammar != nullptr) {
llama_grammar_accept_token(main_ctx->llama_ctx, main_ctx->grammar, token);
llama_grammar_accept_token(main_ctx->grammar, main_ctx->llama_ctx, token);
}

//decode the next new token
Expand Down
2 changes: 1 addition & 1 deletion llama-java-core/llamajava/llamajava.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_tokenize
* Method: tokenToPiece
*/
JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_tokenToPiece
(JNIEnv *, jclass, jint, jbyteArray, jint, jboolean);
(JNIEnv *, jclass, jint, jbyteArray, jint, jint, jboolean);

/*
* Class: chat_octet_model_LlamaService
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ private String appendMultiByteTokenBuffer(byte[] buffer, int length) {
*/
private String tokenToText(int token) {
byte[] buffer = new byte[64];
int length = LlamaService.tokenToPiece(token, buffer, buffer.length, generateParams.isSpecial());
int length = LlamaService.tokenToPiece(token, buffer, buffer.length, 0, generateParams.isSpecial());
if (length == 0) {
return StringUtils.EMPTY;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,11 @@ public static LlamaTokenAttr getLlamaTokenAttr(int token) {
* @param token Token id.
* @param buf Input byte buffer.
* @param bufferLength Input byte buffer length.
* @param lstripLength User can skip up to 'lstrip' leading spaces before copying.
* @param special If true, special tokens are rendered in the output.
* @return int, Returns byte buffer length of the piece.
*/
public static native int tokenToPiece(int token, byte[] buf, int bufferLength, boolean special);
public static native int tokenToPiece(int token, byte[] buf, int bufferLength, int lstripLength, boolean special);

/**
* Get sampling metrics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public static String decodeToken(boolean special, int... tokens) {
int length = 0;
for (int token : tokens) {
byte[] bytes = new byte[64];
int size = LlamaService.tokenToPiece(token, bytes, bytes.length, special);
int size = LlamaService.tokenToPiece(token, bytes, bytes.length, 0, special);
System.arraycopy(bytes, 0, buffer, length, size);
length += size;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package chat.octet.model.beans;

import chat.octet.model.enums.LlamaAttentionType;
import chat.octet.model.enums.LlamaPoolingType;
import chat.octet.model.enums.LlamaRoPEScalingType;
import lombok.ToString;
Expand Down Expand Up @@ -51,6 +52,12 @@ public class LlamaContextParams {
* @see LlamaPoolingType
*/
public int poolingType;
/**
* attention type.
*
* @see LlamaAttentionType
*/
public int attentionType;
/**
* YaRN extrapolation mix factor, NaN = from model.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package chat.octet.model.enums;


import com.google.common.collect.Maps;
import lombok.Getter;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * Llama attention type, mirroring the {@code llama_attention_type} enum of the
 * native llama.cpp API. The integer code is passed through JNI unchanged.
 *
 * @author <a href="https://github.com/eoctet">William</a>
 */
public enum LlamaAttentionType {
    /**
     * unspecified type (use the model default).
     */
    LLAMA_ATTENTION_TYPE_UNSPECIFIED(-1),
    /**
     * causal type.
     */
    LLAMA_ATTENTION_TYPE_CAUSAL(0),
    /**
     * non causal type.
     */
    LLAMA_ATTENTION_TYPE_NON_CAUSAL(1);


    // Immutable lookup table: native type code -> enum constant.
    private static final Map<Integer, LlamaAttentionType> TYPES;

    static {
        // Plain java.util.HashMap suffices here; Guava is not needed.
        Map<Integer, LlamaAttentionType> map = new HashMap<>();

        for (LlamaAttentionType type : values()) {
            if (map.put(type.type, type) != null) {
                throw new IllegalStateException("Duplicated key found: " + type.name());
            }
        }
        TYPES = Collections.unmodifiableMap(map);
    }

    private final int type;

    LlamaAttentionType(int type) {
        this.type = type;
    }

    /**
     * Returns the numeric attention type code passed to the native layer.
     *
     * @return int, the native type code.
     */
    public int getType() {
        return type;
    }

    /**
     * Looks up the enum constant for a native attention type code.
     *
     * @param type Native attention type code.
     * @return The matching constant, or {@code null} if the code is unknown.
     */
    public static LlamaAttentionType valueOfType(int type) {
        return TYPES.get(type);
    }

    @Override
    public String toString() {
        return this.name();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public enum ModelFileType {
/**
* tok_embeddings.weight and output.weight are F16
*/
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16(4),
//LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16(4),
LLAMA_FTYPE_MOSTLY_Q8_0(7),
LLAMA_FTYPE_MOSTLY_Q5_0(8),
LLAMA_FTYPE_MOSTLY_Q5_1(9),
Expand All @@ -47,6 +47,9 @@ public enum ModelFileType {
LLAMA_FTYPE_MOSTLY_IQ4_XS(30),
LLAMA_FTYPE_MOSTLY_IQ1_M(31),
LLAMA_FTYPE_MOSTLY_BF16(32),
LLAMA_FTYPE_MOSTLY_Q4_0_4_4(33),
LLAMA_FTYPE_MOSTLY_Q4_0_4_8(34),
LLAMA_FTYPE_MOSTLY_Q4_0_8_8(35),
/**
* not specified in the model file
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
import chat.octet.model.beans.LlamaContextParams;
import chat.octet.model.beans.LlamaModelParams;
import chat.octet.model.components.prompt.ChatTemplateFormatter;
import chat.octet.model.enums.LlamaNumaStrategy;
import chat.octet.model.enums.LlamaPoolingType;
import chat.octet.model.enums.LlamaRoPEScalingType;
import chat.octet.model.enums.LlamaSplitMode;
import chat.octet.model.enums.*;
import chat.octet.model.utils.JsonUtils;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
Expand Down Expand Up @@ -137,6 +134,14 @@ public class ModelParameter {
@Builder.Default
private int poolingType = LlamaPoolingType.LLAMA_POOLING_TYPE_UNSPECIFIED.getType();

/**
* Attention type for embeddings, use model default if unspecified.
*
* @see LlamaAttentionType
*/
@Builder.Default
private int attentionType = LlamaAttentionType.LLAMA_ATTENTION_TYPE_UNSPECIFIED.getType();

/**
* Base frequency for RoPE sampling.
*/
Expand Down

0 comments on commit 84142e2

Please sign in to comment.