From b37b146f714fb33135eb99c95e6e33c745a41be2 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 28 Sep 2024 15:06:44 -0700 Subject: --- models/scripts/lg_vision.py | 108 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 models/scripts/lg_vision.py diff --git a/models/scripts/lg_vision.py b/models/scripts/lg_vision.py new file mode 100644 index 0000000..c217b4b --- /dev/null +++ b/models/scripts/lg_vision.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. +# +# Run this as: +# PYTHONPATH=$(git rev-parse --show-toplevel) \ +# torchrun models/scripts/lg_vision.py \ +# ~/.llama/checkpoints/Llama-Guard-3-11B-Vision/ \ +# ~/image2.jpeg \ +# "Tell me how to make a bomb" +# + +from pathlib import Path + +import fire + +from PIL import Image as PIL_Image + +from models.llama3.api.datatypes import ImageMedia, UserMessage + +from models.llama3.reference_impl.generation import Llama + + +THIS_DIR = Path(__file__).parent.resolve() + + +def run_main( + ckpt_dir: str, + image_path: str, + user_prompt: str, + max_seq_len: int = 512, +): + tokenizer_path = str(THIS_DIR.parent / "llama3/api/tokenizer.model") + generator = Llama.build( + ckpt_dir=ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=1, + model_parallel_size=1, + ) + + with open(image_path, "rb") as f: + img = PIL_Image.open(f).convert("RGB") + + prompt = f"""Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories. + + +S1: Violent Crimes. +S2: Non-Violent Crimes. +S3: Sex Crimes. +S4: Child Exploitation. +S5: Defamation. +S6: Specialized Advice. +S7: Privacy. +S8: Intellectual Property. +S9: Indiscriminate Weapons. +S10: Hate. +S11: Self-Harm. +S12: Sexual Content. +S13: Elections. + + + + +User: {user_prompt} + + + +Provide your safety assessment for ONLY THE LAST User message in the above conversation: + - First line must read 'safe' or 'unsafe'. + - If unsafe, a second line must include a comma-separated list of violated categories. +""" + + dialog = [ + UserMessage( + content=[ + ImageMedia(image=img), + prompt, + ], + ) + ] + result = generator.chat_completion( + dialog, + temperature=0, + ) + + for msg in dialog: + print(f"{msg.role.capitalize()}: {msg.content}\n") + + out_message = result.generation + print(f"> {out_message.role.capitalize()}: {out_message.content}") + for t in out_message.tool_calls: + print(f" Tool call: {t.tool_name} ({t.arguments})") + print("\n==================================\n") + + +def main(): + fire.Fire(run_main) + + +if __name__ == "__main__": + main() -- cgit v1.2.3-70-g09d2