author     Ashwin Bharambe <ashwin@meta.com>  2024-09-14 08:03:12 -0700
committer  Ashwin Bharambe <ashwin@meta.com>  2024-09-14 08:03:12 -0700
commit     a9776356c332b73a50926456f89413ef9f26b3f6 (patch)
tree       bf79a78de632e9d82577c432171b909243bce600
parent     2b79f4e2c7d9e58d3720cdaf7064910063dc8043 (diff)
Add an explicit `pth_file_count` field
-rw-r--r--  models/datatypes.py   5
-rw-r--r--  models/sku_list.py   43
2 files changed, 31 insertions, 17 deletions
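
In brief: pth_file_count (the number of consolidated .pth checkpoint shards, which also determines the "mp<N>" model-parallel suffix) is promoted from an optional metadata entry to a required first-class field on Model, so every SKU must now declare it explicitly. A minimal sketch of the resulting change to variant -- variant_before and variant_after are hypothetical stand-ins, not repo code, and the quantization format is stood in by a plain string:

def variant_before(fmt: str, metadata: dict) -> str:
    parts = [fmt]
    # old behavior: the mp suffix was appended only if metadata carried a count
    if pth_count := metadata.get("pth_file_count", 0):
        parts.append(f"mp{pth_count}")
    return "-".join(parts)

def variant_after(fmt: str, pth_file_count: int) -> str:
    # new behavior: the mp suffix is always part of the variant string
    return "-".join([fmt, f"mp{pth_file_count}"])

print(variant_before("bf16", {}))   # "bf16"     (suffix silently absent)
print(variant_after("bf16", 1))     # "bf16-mp1" (suffix always present)
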
diff --git a/models/datatypes.py b/models/datatypes.py
index e92ba67..80a473d 100644
--- a/models/datatypes.py
+++ b/models/datatypes.py
@@ -174,10 +174,8 @@ class Model(BaseModel):
     def variant(self) -> str:
         parts = [
             self.quantization_format.value,
+            f"mp{self.pth_file_count}",
         ]
-        # really ad-hoc, we should probably drop these SKUs
-        if pth_count := self.metadata.get("pth_file_count", 0):
-            parts.append(f"mp{pth_count}")
         return "-".join(parts)
@@ -195,6 +193,7 @@ class Model(BaseModel):
     )
     recommended_sampling_params: Optional[SamplingParams] = None
     model_args: Dict[str, Any]
+    pth_file_count: int
     metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)

     @property
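
Because the new field is declared without a default, pydantic treats it as required; any Model(...) construction that omits it now fails validation, which is why the sku_list.py diff below threads pth_file_count=... through every SKU. A quick sketch, using a hypothetical minimal model M and assuming pydantic:

from pydantic import BaseModel, ValidationError

class M(BaseModel):
    pth_file_count: int  # no default, so the field is required

try:
    M()  # omitting the field raises
except ValidationError as err:
    print(err)  # reports pth_file_count as missing
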
diff --git a/models/sku_list.py b/models/sku_list.py
index 2df33da..2811754 100644
--- a/models/sku_list.py
+++ b/models/sku_list.py
@@ -5,6 +5,7 @@
 # top-level folder for each specific model found within the models/ directory at
 # the top-level of this source tree.

+from dataclasses import dataclass
 from functools import lru_cache
 from typing import List, Optional
@@ -85,6 +86,7 @@ def llama2_base_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": False,
             },
+            pth_file_count=1,
         ),
         Model(
             core_model_id=CoreModelId.meta_llama2_13b,
@@ -104,6 +106,7 @@ def llama2_base_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": False,
             },
+            pth_file_count=1,
         ),
         Model(
             core_model_id=CoreModelId.meta_llama2_70b,
@@ -123,6 +126,7 @@ def llama2_base_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": False,
             },
+            pth_file_count=8,
         ),
     ]
@@ -147,6 +151,7 @@ def llama3_base_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": False,
             },
+            pth_file_count=1,
         ),
         Model(
             core_model_id=CoreModelId.meta_llama3_70b,
@@ -166,6 +171,7 @@ def llama3_base_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": False,
             },
+            pth_file_count=8,
         ),
     ]
@@ -190,6 +196,7 @@ def llama3_1_base_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": True,
             },
+            pth_file_count=1,
         ),
         Model(
             core_model_id=CoreModelId.meta_llama3_1_70b,
@@ -209,6 +216,7 @@ def llama3_1_base_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": True,
             },
+            pth_file_count=8,
         ),
         Model(
             core_model_id=CoreModelId.meta_llama3_1_405b,
@@ -228,6 +236,7 @@ def llama3_1_base_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": True,
             },
+            pth_file_count=8,
         ),
         Model(
             core_model_id=CoreModelId.meta_llama3_1_405b,
@@ -248,6 +257,7 @@ def llama3_1_base_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": True,
             },
+            pth_file_count=8,
         ),
         Model(
             core_model_id=CoreModelId.meta_llama3_1_405b,
@@ -267,9 +277,7 @@ def llama3_1_base_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": True,
             },
-            metadata={
-                "pth_file_count": 16,
-            },
+            pth_file_count=16,
         ),
     ]
@@ -294,6 +302,7 @@ def llama2_instruct_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": False,
             },
+            pth_file_count=1,
         ),
         Model(
             core_model_id=CoreModelId.meta_llama2_13b_chat,
@@ -313,6 +322,7 @@ def llama2_instruct_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": False,
             },
+            pth_file_count=1,
         ),
         Model(
             core_model_id=CoreModelId.meta_llama2_70b_chat,
@@ -332,6 +342,7 @@ def llama2_instruct_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": False,
             },
+            pth_file_count=8,
         ),
     ]
@@ -356,6 +367,7 @@ def llama3_instruct_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": False,
             },
+            pth_file_count=1,
         ),
         Model(
             core_model_id=CoreModelId.meta_llama3_70b_instruct,
@@ -375,6 +387,7 @@ def llama3_instruct_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": False,
             },
+            pth_file_count=8,
         ),
     ]
@@ -399,6 +412,7 @@ def llama3_1_instruct_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": True,
             },
+            pth_file_count=1,
         ),
         Model(
             core_model_id=CoreModelId.meta_llama3_1_70b_instruct,
@@ -418,6 +432,7 @@ def llama3_1_instruct_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": True,
             },
+            pth_file_count=8,
         ),
         Model(
             core_model_id=CoreModelId.meta_llama3_1_405b_instruct,
@@ -437,6 +452,7 @@ def llama3_1_instruct_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": True,
             },
+            pth_file_count=8,
         ),
         Model(
             core_model_id=CoreModelId.meta_llama3_1_405b_instruct,
@@ -457,6 +473,7 @@ def llama3_1_instruct_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": True,
             },
+            pth_file_count=8,
         ),
         Model(
             core_model_id=CoreModelId.meta_llama3_1_405b_instruct,
@@ -476,9 +493,7 @@ def llama3_1_instruct_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": True,
             },
-            metadata={
-                "pth_file_count": 16,
-            },
+            pth_file_count=16,
         ),
     ]
@@ -503,6 +518,7 @@ def safety_models() -> List[Model]:
                 "use_scaled_rope": False,
                 "vocab_size": LLAMA3_VOCAB_SIZE,
             },
+            pth_file_count=1,
         ),
         Model(
             core_model_id=CoreModelId.llama_guard_3_8b,
@@ -522,6 +538,7 @@ def safety_models() -> List[Model]:
                 "use_scaled_rope": False,
                 "vocab_size": 128256,
             },
+            pth_file_count=1,
         ),
         Model(
             core_model_id=CoreModelId.prompt_guard_86m,
@@ -529,6 +546,7 @@ def safety_models() -> List[Model]:
             description_markdown="Prompt Guard 86M injection safety model",
             huggingface_repo="meta-llama/Prompt-Guard-86M",
             model_args={},
+            pth_file_count=1,
         ),
         Model(
             core_model_id=CoreModelId.llama_guard_2_8b,
@@ -547,13 +565,11 @@ def safety_models() -> List[Model]:
                 "rope_theta": 500000.0,
                 "use_scaled_rope": False,
             },
+            pth_file_count=1,
         ),
     ]

-from dataclasses import dataclass
-
-
 @dataclass
 class LlamaDownloadInfo:
     folder: str
@@ -564,7 +580,7 @@ class LlamaDownloadInfo:

 def llama_meta_net_info(model: Model) -> LlamaDownloadInfo:
     """Information needed to download model from llamameta.net"""
-    pth_count = model.metadata.get("pth_file_count", 16)
+    pth_count = model.pth_file_count
     if model.core_model_id == CoreModelId.meta_llama3_1_405b:
         if pth_count == 16:
             folder = "Meta-Llama-3.1-405B-MP16"
@@ -626,8 +642,8 @@ def llama_meta_net_info(model: Model) -> LlamaDownloadInfo:
         ]
     )
     if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
-        files.extend([f"fp8_scales_{i}.pt" for i in range(gpu)])
-    files.extend([f"consolidated.{i:02d}.pth" for i in range(gpu)])
+        files.extend([f"fp8_scales_{i}.pt" for i in range(pth_count)])
+    files.extend([f"consolidated.{i:02d}.pth" for i in range(pth_count)])

     return LlamaDownloadInfo(
         folder=folder,
@@ -644,8 +660,7 @@ def llama_meta_pth_size(model: Model) -> int:
     ):
         return 0

-    pth_count = model.metadata.get("pth_file_count", 0)
-    if pth_count == 16:
+    if model.pth_file_count == 16:
         return 51268302389
     elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
         return 60903742309
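
Net effect on the download helpers: both llama_meta_net_info and llama_meta_pth_size now read the shard count from the explicit field instead of a metadata lookup with a magic default (16 in one call site, 0 in the other), and the file lists range over pth_count rather than the previous gpu variable. A rough sketch of the resulting file-list computation -- checkpoint_files is a hypothetical helper assuming the filename conventions visible in the diff, with fp8 scales gated on the quantization format and consolidated shards always listed:

from typing import List

def checkpoint_files(pth_count: int, fp8_mixed: bool) -> List[str]:
    files = []
    if fp8_mixed:
        # one fp8 scales file per shard
        files.extend(f"fp8_scales_{i}.pt" for i in range(pth_count))
    # one consolidated weight shard per model-parallel rank
    files.extend(f"consolidated.{i:02d}.pth" for i in range(pth_count))
    return files

print(checkpoint_files(2, fp8_mixed=True))
# ['fp8_scales_0.pt', 'fp8_scales_1.pt', 'consolidated.00.pth', 'consolidated.01.pth']
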