author     Ashwin Bharambe <ashwin.bharambe@gmail.com>   2024-09-28 15:10:34 -0700
committer  Ashwin Bharambe <ashwin.bharambe@gmail.com>   2024-09-30 13:12:41 -0700
commit     47eee078abfc695284fa559abf94cb44f830f51c (patch)
tree       6c24068d6925cef604f2551891a6046a401ae300
parent     b37b146f714fb33135eb99c95e6e33c745a41be2 (diff)
Add images
-rw-r--r--   models/llama3/reference_impl/generation.py   4
-rw-r--r--   models/scripts/image1.jpeg                   bin 0 -> 262859 bytes
-rw-r--r--   models/scripts/image2.jpeg                   bin 0 -> 95361 bytes
-rw-r--r--   models/scripts/lg_vision.py                  1
4 files changed, 5 insertions, 0 deletions
diff --git a/models/llama3/reference_impl/generation.py b/models/llama3/reference_impl/generation.py
index 7f7eddd..dd58b33 100644
--- a/models/llama3/reference_impl/generation.py
+++ b/models/llama3/reference_impl/generation.py
@@ -305,6 +305,7 @@ class Llama:
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
+ print_model_input: bool = False,
) -> CompletionPrediction:
if (
max_gen_len is None
@@ -325,6 +326,7 @@ class Llama:
top_p=top_p,
logprobs=logprobs,
echo=echo,
+ print_model_input=print_model_input,
):
tokens.append(result.token)
if logprobs:
@@ -350,6 +352,7 @@ class Llama:
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False,
+ print_model_input: bool = False,
) -> ChatPrediction:
if (
max_gen_len is None
@@ -372,6 +375,7 @@ class Llama:
top_p=top_p,
logprobs=logprobs,
echo=echo,
+ print_model_input=print_model_input,
):
tokens.append(result.token)
if result.text == "<|eot_id|>":
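
The hunks above only thread print_model_input from text_completion and chat_completion into the inner generate call; this diff does not show how generate consumes the flag. As a hedged illustration of what such a flag typically does (the helper name, its signature, and the use of tokenizer.decode are assumptions, not code from this commit), it usually just echoes the tokenized prompt before the forward pass:

from typing import List

def maybe_print_model_input(tokenizer, prompt_tokens: List[List[int]], print_model_input: bool) -> None:
    # Hypothetical helper, not part of this commit: when the flag is set,
    # dump each prompt's raw token ids and their decoded text so the exact
    # model input (including special tokens) can be inspected.
    if not print_model_input:
        return
    for tokens in prompt_tokens:
        print("Input to model:")
        print(tokens)                    # raw token ids
        print(tokenizer.decode(tokens))  # assumes tokenizer.decode(List[int]) -> str
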
diff --git a/models/scripts/image1.jpeg b/models/scripts/image1.jpeg
new file mode 100644
index 0000000..1a18e5c
--- /dev/null
+++ b/models/scripts/image1.jpeg
Binary files differ
diff --git a/models/scripts/image2.jpeg b/models/scripts/image2.jpeg
new file mode 100644
index 0000000..9a0cfe9
--- /dev/null
+++ b/models/scripts/image2.jpeg
Binary files differ
diff --git a/models/scripts/lg_vision.py b/models/scripts/lg_vision.py
index c217b4b..8a6299c 100644
--- a/models/scripts/lg_vision.py
+++ b/models/scripts/lg_vision.py
@@ -88,6 +88,7 @@ Provide your safety assessment for ONLY THE LAST User message in the above conve
result = generator.chat_completion(
dialog,
temperature=0,
+ print_model_input=True,
)
for msg in dialog:
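
The lg_vision.py change above opts into the flag on the chat path. The same keyword is also available on text_completion after this commit; a hedged caller-side sketch follows, where the prompt string, sampling values, and the single-prompt argument shape are assumptions (only the print_model_input keyword itself comes from this diff), and generator stands for an already-built Llama instance whose construction the diff does not show.

# Hedged sketch: the equivalent opt-in on the completion API.
# Prompt text and sampling parameters are placeholders/assumptions.
result = generator.text_completion(
    "The capital of France is",  # placeholder prompt; argument shape is an assumption
    temperature=0.6,
    top_p=0.9,
    print_model_input=True,      # new in this commit: echo the exact model input before generating
)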