diff options
author | Ashwin Bharambe <ashwin@meta.com> | 2024-09-14 08:03:12 -0700 |
---|---|---|
committer | Ashwin Bharambe <ashwin@meta.com> | 2024-09-14 08:03:12 -0700 |
commit | a9776356c332b73a50926456f89413ef9f26b3f6 (patch) | |
tree | bf79a78de632e9d82577c432171b909243bce600 | |
parent | 2b79f4e2c7d9e58d3720cdaf7064910063dc8043 (diff) |
Add an explicit `pth_file_count` field
-rw-r--r-- | models/datatypes.py | 5 | ||||
-rw-r--r-- | models/sku_list.py | 43 |
2 files changed, 31 insertions, 17 deletions
diff --git a/models/datatypes.py b/models/datatypes.py index e92ba67..80a473d 100644 --- a/models/datatypes.py +++ b/models/datatypes.py @@ -174,10 +174,8 @@ class Model(BaseModel): def variant(self) -> str: parts = [ self.quantization_format.value, + f"mp{self.pth_file_count}", ] - # really ad-hoc, we should probably drop these SKUs - if pth_count := self.metadata.get("pth_file_count", 0): - parts.append(f"mp{pth_count}") return "-".join(parts) @@ -195,6 +193,7 @@ class Model(BaseModel): ) recommended_sampling_params: Optional[SamplingParams] = None model_args: Dict[str, Any] + pth_file_count: int metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) @property diff --git a/models/sku_list.py b/models/sku_list.py index 2df33da..2811754 100644 --- a/models/sku_list.py +++ b/models/sku_list.py @@ -5,6 +5,7 @@ # top-level folder for each specific model found within the models/ directory at # the top-level of this source tree. +from dataclasses import dataclass from functools import lru_cache from typing import List, Optional @@ -85,6 +86,7 @@ def llama2_base_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": False, }, + pth_file_count=1, ), Model( core_model_id=CoreModelId.meta_llama2_13b, @@ -104,6 +106,7 @@ def llama2_base_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": False, }, + pth_file_count=1, ), Model( core_model_id=CoreModelId.meta_llama2_70b, @@ -123,6 +126,7 @@ def llama2_base_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": False, }, + pth_file_count=8, ), ] @@ -147,6 +151,7 @@ def llama3_base_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": False, }, + pth_file_count=1, ), Model( core_model_id=CoreModelId.meta_llama3_70b, @@ -166,6 +171,7 @@ def llama3_base_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": False, }, + pth_file_count=8, ), ] @@ -190,6 +196,7 @@ def llama3_1_base_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": True, }, + pth_file_count=1, ), Model( core_model_id=CoreModelId.meta_llama3_1_70b, @@ -209,6 +216,7 @@ def llama3_1_base_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": True, }, + pth_file_count=8, ), Model( core_model_id=CoreModelId.meta_llama3_1_405b, @@ -228,6 +236,7 @@ def llama3_1_base_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": True, }, + pth_file_count=8, ), Model( core_model_id=CoreModelId.meta_llama3_1_405b, @@ -248,6 +257,7 @@ def llama3_1_base_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": True, }, + pth_file_count=8, ), Model( core_model_id=CoreModelId.meta_llama3_1_405b, @@ -267,9 +277,7 @@ def llama3_1_base_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": True, }, - metadata={ - "pth_file_count": 16, - }, + pth_file_count=16, ), ] @@ -294,6 +302,7 @@ def llama2_instruct_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": False, }, + pth_file_count=1, ), Model( core_model_id=CoreModelId.meta_llama2_13b_chat, @@ -313,6 +322,7 @@ def llama2_instruct_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": False, }, + pth_file_count=1, ), Model( core_model_id=CoreModelId.meta_llama2_70b_chat, @@ -332,6 +342,7 @@ def llama2_instruct_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": False, }, + pth_file_count=8, ), ] @@ -356,6 +367,7 @@ def llama3_instruct_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": False, }, + pth_file_count=1, ), Model( core_model_id=CoreModelId.meta_llama3_70b_instruct, @@ -375,6 +387,7 @@ def llama3_instruct_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": False, }, + pth_file_count=8, ), ] @@ -399,6 +412,7 @@ def llama3_1_instruct_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": True, }, + pth_file_count=1, ), Model( core_model_id=CoreModelId.meta_llama3_1_70b_instruct, @@ -418,6 +432,7 @@ def llama3_1_instruct_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": True, }, + pth_file_count=8, ), Model( core_model_id=CoreModelId.meta_llama3_1_405b_instruct, @@ -437,6 +452,7 @@ def llama3_1_instruct_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": True, }, + pth_file_count=8, ), Model( core_model_id=CoreModelId.meta_llama3_1_405b_instruct, @@ -457,6 +473,7 @@ def llama3_1_instruct_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": True, }, + pth_file_count=8, ), Model( core_model_id=CoreModelId.meta_llama3_1_405b_instruct, @@ -476,9 +493,7 @@ def llama3_1_instruct_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": True, }, - metadata={ - "pth_file_count": 16, - }, + pth_file_count=16, ), ] @@ -503,6 +518,7 @@ def safety_models() -> List[Model]: "use_scaled_rope": False, "vocab_size": LLAMA3_VOCAB_SIZE, }, + pth_file_count=1, ), Model( core_model_id=CoreModelId.llama_guard_3_8b, @@ -522,6 +538,7 @@ def safety_models() -> List[Model]: "use_scaled_rope": False, "vocab_size": 128256, }, + pth_file_count=1, ), Model( core_model_id=CoreModelId.prompt_guard_86m, @@ -529,6 +546,7 @@ def safety_models() -> List[Model]: description_markdown="Prompt Guard 86M injection safety model", huggingface_repo="meta-llama/Prompt-Guard-86M", model_args={}, + pth_file_count=1, ), Model( core_model_id=CoreModelId.llama_guard_2_8b, @@ -547,13 +565,11 @@ def safety_models() -> List[Model]: "rope_theta": 500000.0, "use_scaled_rope": False, }, + pth_file_count=1, ), ] -from dataclasses import dataclass - - @dataclass class LlamaDownloadInfo: folder: str @@ -564,7 +580,7 @@ class LlamaDownloadInfo: def llama_meta_net_info(model: Model) -> LlamaDownloadInfo: """Information needed to download model from llamameta.net""" - pth_count = model.metadata.get("pth_file_count", 16) + pth_count = model.pth_file_count if model.core_model_id == CoreModelId.meta_llama3_1_405b: if pth_count == 16: folder = "Meta-Llama-3.1-405B-MP16" @@ -626,8 +642,8 @@ def llama_meta_net_info(model: Model) -> LlamaDownloadInfo: ] ) if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: - files.extend([f"fp8_scales_{i}.pt" for i in range(gpu)]) - files.extend([f"consolidated.{i:02d}.pth" for i in range(gpu)]) + files.extend([f"fp8_scales_{i}.pt" for i in range(pth_count)]) + files.extend([f"consolidated.{i:02d}.pth" for i in range(pth_count)]) return LlamaDownloadInfo( folder=folder, @@ -644,8 +660,7 @@ def llama_meta_pth_size(model: Model) -> int: ): return 0 - pth_count = model.metadata.get("pth_file_count", 0) - if pth_count == 16: + if model.pth_file_count == 16: return 51268302389 elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed: return 60903742309 |