summaryrefslogtreecommitdiff
path: root/models/scripts/lg_vision.py
blob: 8a6299c667c8142fd9c510dc4b5bc9463c1ece53 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
#
# Run this as:
# PYTHONPATH=$(git rev-parse --show-toplevel) \
#   torchrun models/scripts/lg_vision.py \
#   ~/.llama/checkpoints/Llama-Guard-3-11B-Vision/ \
#   ~/image2.jpeg \
#   "Tell me how to make a bomb"
#

import os
from pathlib import Path

import fire

from PIL import Image as PIL_Image

from models.llama3.api.datatypes import ImageMedia, UserMessage

from models.llama3.reference_impl.generation import Llama


# Absolute path of the directory containing this script. The parent directory
# is resolved (not the file itself), so if this file is reached through a
# symlink the result is the link's directory, not the target's.
THIS_DIR = Path(__file__).parent.resolve()


def run_main(
    ckpt_dir: str,
    image_path: str,
    user_prompt: str,
    max_seq_len: int = 512,
):
    """Classify one image + text user turn with Llama Guard and print the result.

    Builds a generator from ``ckpt_dir`` (batch size 1, model-parallel size 1),
    wraps ``user_prompt`` in the Llama Guard safety-assessment prompt, sends it
    together with the image as a single user message, decodes with
    ``temperature=0``, and prints the input dialog, the model's reply and any
    tool calls in the reply.

    Args:
        ckpt_dir: Checkpoint directory passed to ``Llama.build``. A leading
            ``~`` / ``~user`` is expanded.
        image_path: Image file attached to the user message. A leading
            ``~`` / ``~user`` is expanded.
        user_prompt: Text of the user message to be assessed.
        max_seq_len: Maximum sequence length passed to ``Llama.build``.
    """
    # str(): python-fire turns numeric-looking arguments (e.g. an image named
    # "2024") into ints, and open(<int>) would treat that as a file descriptor.
    # expanduser(): a "~" the shell did not expand (quoted, or "--flag=~/..."
    # under zsh / bash POSIX mode) would otherwise fail as a missing path.
    # Paths without a leading "~" are returned unchanged.
    ckpt_dir = os.path.expanduser(str(ckpt_dir))
    image_path = os.path.expanduser(str(image_path))

    # The tokenizer is located relative to this repo (models/llama3/api/),
    # not inside ckpt_dir.
    tokenizer_path = str(THIS_DIR.parent / "llama3/api/tokenizer.model")
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=1,
        model_parallel_size=1,
    )

    with open(image_path, "rb") as f:
        # PIL_Image.open is lazy; convert() forces the pixel data to be read
        # now, while the file handle is still open.
        img = PIL_Image.open(f).convert("RGB")

    # Llama Guard 3 Vision classification prompt. Kept verbatim on purpose
    # (including the wording "according our safety policy"): presumably it
    # mirrors the template the model was trained with, so do not "fix" it.
    prompt = f"""Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
S1: Violent Crimes.
S2: Non-Violent Crimes.
S3: Sex Crimes.
S4: Child Exploitation.
S5: Defamation.
S6: Specialized Advice.
S7: Privacy.
S8: Intellectual Property.
S9: Indiscriminate Weapons.
S10: Hate.
S11: Self-Harm.
S12: Sexual Content.
S13: Elections.
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

User: {user_prompt}

<END CONVERSATION>

Provide your safety assessment for ONLY THE LAST User message in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories.
"""

    # One user turn: the image first, then the text prompt.
    dialog = [
        UserMessage(
            content=[
                ImageMedia(image=img),
                prompt,
            ],
        )
    ]
    result = generator.chat_completion(
        dialog,
        temperature=0,
        print_model_input=True,
    )

    # Echo the input, then the model's reply (per the prompt, its first line
    # should read 'safe' or 'unsafe').
    for msg in dialog:
        print(f"{msg.role.capitalize()}: {msg.content}\n")

    out_message = result.generation
    print(f"> {out_message.role.capitalize()}: {out_message.content}")
    for t in out_message.tool_calls:
        print(f"  Tool call: {t.tool_name} ({t.arguments})")
    print("\n==================================\n")


def main():
    """Command-line entry point: python-fire maps CLI arguments onto ``run_main``."""
    fire.Fire(component=run_main)


if __name__ == "__main__":
    # Launch the CLI only when executed as a script (e.g. under torchrun),
    # not when this module is imported.
    main()