summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAshwin Bharambe <ashwin@meta.com>2024-07-29 23:43:37 -0700
committerAshwin Bharambe <ashwin@meta.com>2024-07-29 23:43:37 -0700
commitc36031e2e99f3615ce46b4c6a6f4dea0f1574eb3 (patch)
tree6703ebfa874aac7f80b5b8eae93126284f5fe9eb
parent1768925a9d343b19459b5bc2acc893c842874179 (diff)
add recommended sampling params
-rw-r--r--models/llama3_1/api/datatypes.py1
-rw-r--r--models/llama3_1/api/sku_list.py20
2 files changed, 21 insertions, 0 deletions
diff --git a/models/llama3_1/api/datatypes.py b/models/llama3_1/api/datatypes.py
index 7d3e964..faaa08d 100644
--- a/models/llama3_1/api/datatypes.py
+++ b/models/llama3_1/api/datatypes.py
@@ -234,4 +234,5 @@ class ModelDefinition(BaseModel):
quantization_format: CheckpointQuantizationFormat = (
CheckpointQuantizationFormat.bf16
)
+ recommended_sampling_params: Optional[SamplingParams] = None
model_args: Dict[str, Any]
diff --git a/models/llama3_1/api/sku_list.py b/models/llama3_1/api/sku_list.py
index 372d588..770b3d6 100644
--- a/models/llama3_1/api/sku_list.py
+++ b/models/llama3_1/api/sku_list.py
@@ -13,6 +13,8 @@ from .datatypes import (
ModelDefinition,
ModelFamily,
ModelSKU,
+ SamplingParams,
+ SamplingStrategy,
)
@@ -24,6 +26,14 @@ def llama3_1_model_list() -> List[ModelDefinition]:
return base_models() + instruct_models()
+def recommended_sampling_params() -> SamplingParams:
+ return SamplingParams(
+ strategy=SamplingStrategy.top_p,
+ temperature=1.0,
+ top_p=0.9,
+ )
+
+
def base_models() -> List[ModelDefinition]:
return [
ModelDefinition(
@@ -36,6 +46,7 @@ def base_models() -> List[ModelDefinition]:
gpu_count=1,
memory_gb_per_gpu=20,
),
+ recommended_sampling_params=recommended_sampling_params(),
model_args={
"dim": 4096,
"n_layers": 32,
@@ -59,6 +70,7 @@ def base_models() -> List[ModelDefinition]:
gpu_count=8,
memory_gb_per_gpu=20,
),
+ recommended_sampling_params=recommended_sampling_params(),
model_args={
"dim": 8192,
"n_layers": 80,
@@ -82,6 +94,7 @@ def base_models() -> List[ModelDefinition]:
gpu_count=8,
memory_gb_per_gpu=120,
),
+ recommended_sampling_params=recommended_sampling_params(),
model_args={
"dim": 16384,
"n_layers": 126,
@@ -106,6 +119,7 @@ def base_models() -> List[ModelDefinition]:
memory_gb_per_gpu=70,
),
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
+ recommended_sampling_params=recommended_sampling_params(),
model_args={
"dim": 16384,
"n_layers": 126,
@@ -129,6 +143,7 @@ def base_models() -> List[ModelDefinition]:
gpu_count=16,
memory_gb_per_gpu=70,
),
+ recommended_sampling_params=recommended_sampling_params(),
model_args={
"dim": 16384,
"n_layers": 126,
@@ -157,6 +172,7 @@ def instruct_models() -> List[ModelDefinition]:
gpu_count=1,
memory_gb_per_gpu=20,
),
+ recommended_sampling_params=recommended_sampling_params(),
model_args={
"dim": 4096,
"n_layers": 32,
@@ -180,6 +196,7 @@ def instruct_models() -> List[ModelDefinition]:
gpu_count=8,
memory_gb_per_gpu=20,
),
+ recommended_sampling_params=recommended_sampling_params(),
model_args={
"dim": 8192,
"n_layers": 80,
@@ -203,6 +220,7 @@ def instruct_models() -> List[ModelDefinition]:
gpu_count=8,
memory_gb_per_gpu=120,
),
+ recommended_sampling_params=recommended_sampling_params(),
model_args={
"dim": 16384,
"n_layers": 126,
@@ -227,6 +245,7 @@ def instruct_models() -> List[ModelDefinition]:
memory_gb_per_gpu=70,
),
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
+ recommended_sampling_params=recommended_sampling_params(),
model_args={
"dim": 16384,
"n_layers": 126,
@@ -250,6 +269,7 @@ def instruct_models() -> List[ModelDefinition]:
gpu_count=16,
memory_gb_per_gpu=70,
),
+ recommended_sampling_params=recommended_sampling_params(),
model_args={
"dim": 16384,
"n_layers": 126,