diff options
author | Ashwin Bharambe <ashwin@meta.com> | 2024-07-29 23:43:37 -0700 |
---|---|---|
committer | Ashwin Bharambe <ashwin@meta.com> | 2024-07-29 23:43:37 -0700 |
commit | c36031e2e99f3615ce46b4c6a6f4dea0f1574eb3 (patch) | |
tree | 6703ebfa874aac7f80b5b8eae93126284f5fe9eb | |
parent | 1768925a9d343b19459b5bc2acc893c842874179 (diff) |
add recommended sampling params
-rw-r--r-- | models/llama3_1/api/datatypes.py | 1 | ||||
-rw-r--r-- | models/llama3_1/api/sku_list.py | 20 |
2 files changed, 21 insertions, 0 deletions
diff --git a/models/llama3_1/api/datatypes.py b/models/llama3_1/api/datatypes.py index 7d3e964..faaa08d 100644 --- a/models/llama3_1/api/datatypes.py +++ b/models/llama3_1/api/datatypes.py @@ -234,4 +234,5 @@ class ModelDefinition(BaseModel): quantization_format: CheckpointQuantizationFormat = ( CheckpointQuantizationFormat.bf16 ) + recommended_sampling_params: Optional[SamplingParams] = None model_args: Dict[str, Any] diff --git a/models/llama3_1/api/sku_list.py b/models/llama3_1/api/sku_list.py index 372d588..770b3d6 100644 --- a/models/llama3_1/api/sku_list.py +++ b/models/llama3_1/api/sku_list.py @@ -13,6 +13,8 @@ from .datatypes import ( ModelDefinition, ModelFamily, ModelSKU, + SamplingParams, + SamplingStrategy, ) @@ -24,6 +26,14 @@ def llama3_1_model_list() -> List[ModelDefinition]: return base_models() + instruct_models() +def recommended_sampling_params() -> SamplingParams: + return SamplingParams( + strategy=SamplingStrategy.top_p, + temperature=1.0, + top_p=0.9, + ) + + def base_models() -> List[ModelDefinition]: return [ ModelDefinition( @@ -36,6 +46,7 @@ def base_models() -> List[ModelDefinition]: gpu_count=1, memory_gb_per_gpu=20, ), + recommended_sampling_params=recommended_sampling_params(), model_args={ "dim": 4096, "n_layers": 32, @@ -59,6 +70,7 @@ def base_models() -> List[ModelDefinition]: gpu_count=8, memory_gb_per_gpu=20, ), + recommended_sampling_params=recommended_sampling_params(), model_args={ "dim": 8192, "n_layers": 80, @@ -82,6 +94,7 @@ def base_models() -> List[ModelDefinition]: gpu_count=8, memory_gb_per_gpu=120, ), + recommended_sampling_params=recommended_sampling_params(), model_args={ "dim": 16384, "n_layers": 126, @@ -106,6 +119,7 @@ def base_models() -> List[ModelDefinition]: memory_gb_per_gpu=70, ), quantization_format=CheckpointQuantizationFormat.fp8_mixed, + recommended_sampling_params=recommended_sampling_params(), model_args={ "dim": 16384, "n_layers": 126, @@ -129,6 +143,7 @@ def base_models() -> List[ModelDefinition]: gpu_count=16, memory_gb_per_gpu=70, ), + recommended_sampling_params=recommended_sampling_params(), model_args={ "dim": 16384, "n_layers": 126, @@ -157,6 +172,7 @@ def instruct_models() -> List[ModelDefinition]: gpu_count=1, memory_gb_per_gpu=20, ), + recommended_sampling_params=recommended_sampling_params(), model_args={ "dim": 4096, "n_layers": 32, @@ -180,6 +196,7 @@ def instruct_models() -> List[ModelDefinition]: gpu_count=8, memory_gb_per_gpu=20, ), + recommended_sampling_params=recommended_sampling_params(), model_args={ "dim": 8192, "n_layers": 80, @@ -203,6 +220,7 @@ def instruct_models() -> List[ModelDefinition]: gpu_count=8, memory_gb_per_gpu=120, ), + recommended_sampling_params=recommended_sampling_params(), model_args={ "dim": 16384, "n_layers": 126, @@ -227,6 +245,7 @@ def instruct_models() -> List[ModelDefinition]: memory_gb_per_gpu=70, ), quantization_format=CheckpointQuantizationFormat.fp8_mixed, + recommended_sampling_params=recommended_sampling_params(), model_args={ "dim": 16384, "n_layers": 126, @@ -250,6 +269,7 @@ def instruct_models() -> List[ModelDefinition]: gpu_count=16, memory_gb_per_gpu=70, ), + recommended_sampling_params=recommended_sampling_params(), model_args={ "dim": 16384, "n_layers": 126, |