Skip to content

Commit

Permalink
WIP adding image class
Browse files Browse the repository at this point in the history
  • Loading branch information
Isaac Miller committed Sep 18, 2024
1 parent 14ad003 commit f0e8afa
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 27 deletions.
2 changes: 1 addition & 1 deletion dsp/utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __getattr__(self, name):
if name in self.config:
return self.config[name]

super().__getattr__(name)
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

def __append(self, config):
thread_id = threading.get_ident()
Expand Down
2 changes: 1 addition & 1 deletion dspy/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs):
try:
value = self.parse(signature, output)
except Exception as e:
print("Failed to parse", messages, output)
print("Failed to parse", inputs, output)
raise e
assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
values.append(value)
Expand Down
39 changes: 24 additions & 15 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,30 +58,39 @@ def format_chat_turn(field_names, values):
# if not set(values).issuperset(set(field_names)):
# raise ValueError(f"Expected {field_names} but got {values.keys()}")

text_content = format_fields({k: values[k] for k in field_names if 'image' not in k and ('rationale' not in k or 'rationale' in values)})
# text_content = format_fields({k: values[k] for k in field_names if 'image' not in k})

request = []

for k in field_names:
if 'image' in k:
image = values[k]
image = values.get(k)
if not image:
continue
raise ValueError(f"Image not found for field {k}")

image_base64 = encode_image(image)
if image_base64:
if not image_base64:
raise ValueError(f"Failed to encode image for field {k}")

if request and request[-1]["type"] == "text":
request[-1]["text"] += f"\n\n[[[ ### {k} ### ]]]\n"
else:
request.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
}
"type": "text",
"text": f"\n\n[[[ ### {k} ### ]]]\n"
})
else:
raise ValueError(f"Failed to encode image for field {k}")

request.append({
"type": "text",
"text": text_content
})

request.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
}
})
else:
request.append({
"type": "text",
"text": format_fields({k: values[k]})
})

return request

Expand Down
10 changes: 9 additions & 1 deletion dspy/clients/base_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,15 @@ def _inspect_history(lm, n: int = 1):
print("\n\n\n")
for msg in messages:
print(_red(f"{msg['role'].capitalize()} message:"))
print(msg['content'].strip())
if isinstance(msg['content'], str):
print(msg['content'].strip())
else:
if isinstance(msg['content'], list):
for c in msg['content']:
if c["type"] == "text":
print(c["text"].strip())
elif c["type"] == "image_url":
print("<IMAGE URL>\n")
print("\n")

print(_red("Response:"))
Expand Down
51 changes: 42 additions & 9 deletions examples/vlm/mmmu.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"metadata": {},
"outputs": [],
"source": [
"lm = dspy.LM(model=\"openai/Qwen/Qwen2-VL-7B-Instruct\", api_base=\"http://localhost:8000/v1\", api_key=\"sk-proj-1234567890\")\n",
"lm = dspy.LM(model=\"openai/Qwen/Qwen2-VL-72B-Instruct\", api_base=\"http://localhost:8000/v1\", api_key=\"sk-proj-1234567890\")\n",
"\n",
"# adapter = dspy.ChatAdapter()\n",
"dspy.settings.configure(lm=lm)"
Expand All @@ -48,8 +48,10 @@
"devset = []\n",
"valset = []\n",
"for subset in subsets:\n",
" devset.extend(DataLoader().from_huggingface(\"MMMU/MMMU\", subset, split=[\"dev\"], input_keys=input_keys)[\"dev\"])\n",
" valset.extend(DataLoader().from_huggingface(\"MMMU/MMMU\", subset, split=[\"validation\"], input_keys=input_keys)[\"validation\"])"
" dataset = DataLoader().from_huggingface(\"MMMU/MMMU\", subset, split=[\"dev\", \"validation\"], input_keys=input_keys)\n",
" devset.extend(dataset[\"dev\"])\n",
" valset.extend(dataset[\"validation\"])\n",
" "
]
},
{
Expand All @@ -66,8 +68,9 @@
" image_counts[count] += 1\n",
" return image_counts\n",
"\n",
"devset = [example for example in devset if sum(1 for key in example.inputs().keys() if key.startswith('image_') and example.inputs()[key] is not None) <= 2]\n",
"valset = [example for example in valset if sum(1 for key in example.inputs().keys() if key.startswith('image_') and example.inputs()[key] is not None) <= 2]\n",
"max_images = 1\n",
"devset = [example for example in devset if sum(1 for key in example.inputs().keys() if key.startswith('image_') and example.inputs()[key] is not None) <= max_images]\n",
"valset = [example for example in valset if sum(1 for key in example.inputs().keys() if key.startswith('image_') and example.inputs()[key] is not None) <= max_images]\n",
"\n",
"devset_image_counts = count_images(devset)\n",
"valset_image_counts = count_images(valset)\n",
Expand All @@ -92,8 +95,8 @@
" \"\"\"Output a rationale and the answer to a multiple choice question about an image.\"\"\"\n",
"\n",
" question: str = dspy.InputField(desc=\"A question about the image(s)\")\n",
" image_1: Optional[Image] = dspy.InputField(desc=\"An image of a math problem\")\n",
" image_2: Optional[Image] = dspy.InputField(desc=\"An image of a math problem\")\n",
" image_1: Image = dspy.InputField(desc=\"An image of a math problem\")\n",
" # image_2: Image = dspy.InputField(desc=\"An image of a math problem\")\n",
" options: List[str] = dspy.InputField(desc=\"The options to the question\")\n",
" answer: str = dspy.OutputField(desc=\"The answer to the question\")\n",
"\n",
Expand All @@ -109,10 +112,40 @@
"source": [
"# sample_input=devset[0]\n",
"results = []\n",
"for sample_input in devset[:10]:\n",
"for sample_input in devset[10:20]:\n",
" x = predictor(**sample_input.inputs())\n",
" results.append(x)\n",
" print(x)"
" print(x)\n",
"\n",
"# predictor(**devset[0].inputs())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lm.inspect_history()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# raise ValueError(\"stop here\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"evaluate = Evaluate(metric=answer_exact_match, num_threads=300, devset= valset, display_progress=True)\n",
"print(evaluate(predictor))"
]
},
{
Expand Down

0 comments on commit f0e8afa

Please sign in to comment.