Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions cli/decompose/decompose.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import keyword
import re
import shutil
from enum import Enum
from graphlib import TopologicalSorter
from pathlib import Path
Expand All @@ -16,7 +17,8 @@
class DecompVersion(str, Enum):
latest = "latest"
v1 = "v1"
# v2 = "v2"
v2 = "v2"
# v3 = "v3"


this_file_dir = Path(__file__).resolve().parent
Expand Down Expand Up @@ -307,27 +309,33 @@ def run(
backend_api_key=backend_api_key,
)

# Verify that all user variables are properly defined before use
# This may reorder subtasks if dependencies are out of order
decomp_data = verify_user_variables(decomp_data, input_var)
decomp_dir = out_dir / out_name
val_fn_dir = decomp_dir / "validations"
val_fn_dir.mkdir(parents=True)

with open(out_dir / f"{out_name}.json", "w") as f:
(val_fn_dir / "__init__.py").touch()

for constraint in decomp_data["identified_constraints"]:
if constraint["val_fn"] is not None:
with open(val_fn_dir / f"{constraint['val_fn_name']}.py", "w") as f:
f.write(constraint["val_fn"] + "\n")

with open(decomp_dir / f"{out_name}.json", "w") as f:
json.dump(decomp_data, f, indent=2)

with open(out_dir / f"{out_name}.py", "w") as f:
with open(decomp_dir / f"{out_name}.py", "w") as f:
f.write(
m_template.render(
subtasks=decomp_data["subtasks"], user_inputs=input_var
subtasks=decomp_data["subtasks"],
user_inputs=input_var,
identified_constraints=decomp_data["identified_constraints"],
)
+ "\n"
)
except Exception:
created_json = Path(out_dir / f"{out_name}.json")
created_py = Path(out_dir / f"{out_name}.py")

if created_json.exists() and created_json.is_file():
created_json.unlink()
if created_py.exists() and created_py.is_file():
created_py.unlink()
decomp_dir = out_dir / out_name
if decomp_dir.exists() and decomp_dir.is_dir():
shutil.rmtree(decomp_dir)

raise Exception
15 changes: 15 additions & 0 deletions cli/decompose/m_decomp_result_v1.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ import os
import textwrap

import mellea
{%- if "code" in identified_constraints | map(attribute="val_strategy") %}
from mellea.stdlib.requirement import req
{% for c in identified_constraints %}
{%- if c.val_fn %}
from validations.{{ c.val_fn_name }} import validate_input as {{ c.val_fn_name }}
{%- endif %}
{%- endfor %}
{%- endif %}

m = mellea.start_session()
{%- if user_inputs %}
Expand All @@ -30,7 +38,14 @@ except KeyError as e:
{%- if item.constraints %}
requirements=[
{%- for c in item.constraints %}
{%- if c.val_fn %}
req(
{{ c.constraint | tojson}},
validation_fn={{ c.val_fn_name }},
),
{%- else %}
{{ c.constraint | tojson}},
{%- endif %}
{%- endfor %}
],
{%- else %}
Expand Down
91 changes: 91 additions & 0 deletions cli/decompose/m_decomp_result_v2.py.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
{% if user_inputs -%}
import os
{% endif -%}
import textwrap

import mellea
{%- if "code" in identified_constraints | map(attribute="val_strategy") %}
from mellea.stdlib.requirement import req
{% for c in identified_constraints %}
{%- if c.val_fn %}
from validations.{{ c.val_fn_name }} import validate_input as {{ c.val_fn_name }}
{%- endif %}
{%- endfor %}
{%- endif %}

m = mellea.start_session()
{%- if user_inputs %}


# User Input Variables
try:
{%- for var in user_inputs %}
{{ var | lower }} = os.environ["{{ var | upper }}"]
{%- endfor %}
except KeyError as e:
print(f"ERROR: One or more required environment variables are not set; {e}")
exit(1)
{%- endif %}
{%- for item in subtasks %}


{{ item.tag | lower }}_gnrl = textwrap.dedent(
R"""
{{ item.general_instructions | trim | indent(width=4, first=False) }}
""".strip()
)
{{ item.tag | lower }} = m.instruct(
{%- if not item.input_vars_required %}
{{ item.subtask[3:] | trim | tojson }},
{%- else %}
textwrap.dedent(
R"""
{{ item.subtask[3:] | trim }}

Here are the input variables and their content:
{%- for var in item.input_vars_required %}

- {{ var | upper }} = {{ "{{" }}{{ var | upper }}{{ "}}" }}
{%- endfor %}
""".strip()
),
{%- endif %}
{%- if item.constraints %}
requirements=[
{%- for c in item.constraints %}
{%- if c.val_fn %}
req(
{{ c.constraint | tojson}},
validation_fn={{ c.val_fn_name }},
),
{%- else %}
{{ c.constraint | tojson}},
{%- endif %}
{%- endfor %}
],
{%- else %}
requirements=None,
{%- endif %}
{%- if item.input_vars_required %}
user_variables={
{%- for var in item.input_vars_required %}
{{ var | upper | tojson }}: {{ var | lower }},
{%- endfor %}
},
{%- endif %}
grounding_context={
"GENERAL_INSTRUCTIONS": {{ item.tag | lower }}_gnrl,
{%- for var in item.depends_on %}
{{ var | upper | tojson }}: {{ var | lower }}.value,
{%- endfor %}
},
)
assert {{ item.tag | lower }}.value is not None, 'ERROR: task "{{ item.tag | lower }}" execution failed'
{%- if loop.last %}


final_answer = {{ item.tag | lower }}.value

print(final_answer)
{%- endif -%}
{%- endfor -%}
61 changes: 48 additions & 13 deletions cli/decompose/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,36 @@

from .prompt_modules import (
constraint_extractor,
# general_instructions,
general_instructions,
subtask_constraint_assign,
subtask_list,
subtask_prompt_generator,
validation_code_generator,
validation_decision,
)
from .prompt_modules.subtask_constraint_assign import SubtaskPromptConstraintsItem
from .prompt_modules.subtask_list import SubtaskItem
from .prompt_modules.subtask_prompt_generator import SubtaskPromptItem


class ConstraintValData(TypedDict):
val_strategy: Literal["code", "llm"]
val_fn: str | None


class ConstraintResult(TypedDict):
constraint: str
validation_strategy: str
val_strategy: Literal["code", "llm"]
val_fn: str | None
val_fn_name: str


class DecompSubtasksResult(TypedDict):
subtask: str
tag: str
constraints: list[ConstraintResult]
prompt_template: str
# general_instructions: str
general_instructions: str
input_vars_required: list[str]
depends_on: list[str]
generated_response: NotRequired[str]
Expand Down Expand Up @@ -72,7 +80,9 @@ def decompose(
case DecompBackend.ollama:
m_session = MelleaSession(
OllamaModelBackend(
model_id=model_id, model_options={ModelOption.CONTEXT_WINDOW: 16384}
model_id=model_id,
base_url=backend_endpoint,
model_options={ModelOption.CONTEXT_WINDOW: 16384},
)
)
case DecompBackend.openai:
Expand Down Expand Up @@ -115,11 +125,27 @@ def decompose(
m_session, task_prompt, enforce_same_words=False
).parse()

constraint_validation_strategies: dict[str, Literal["code", "llm"]] = {
cons_key: validation_decision.generate(m_session, cons_key).parse()
constraint_val_strategy: dict[
str, dict[Literal["val_strategy"], Literal["code", "llm"]]
] = {
cons_key: {
"val_strategy": validation_decision.generate(m_session, cons_key).parse()
}
for cons_key in task_prompt_constraints
}

constraint_val_data: dict[str, ConstraintValData] = {}

for cons_key in constraint_val_strategy:
constraint_val_data[cons_key] = {
"val_strategy": constraint_val_strategy[cons_key]["val_strategy"],
"val_fn": None,
}
if constraint_val_data[cons_key]["val_strategy"] == "code":
constraint_val_data[cons_key]["val_fn"] = (
validation_code_generator.generate(m_session, cons_key).parse()
)

subtask_prompts: list[SubtaskPromptItem] = subtask_prompt_generator.generate(
m_session,
task_prompt,
Expand All @@ -142,14 +168,21 @@ def decompose(
constraints=[
{
"constraint": cons_str,
"validation_strategy": constraint_validation_strategies[cons_str],
"val_strategy": constraint_val_data[cons_str]["val_strategy"],
"val_fn_name": f"val_fn_{task_prompt_constraints.index(cons_str) + 1}",
# >> Always include generated "val_fn" code (experimental)
"val_fn": constraint_val_data[cons_str]["val_fn"],
# >> Include generated "val_fn" code only for the last subtask (experimental)
# "val_fn": constraint_val_data[cons_str]["val_fn"]
# if subtask_i + 1 == len(subtask_prompts_with_constraints)
# else None,
}
for cons_str in subtask_data.constraints
],
prompt_template=subtask_data.prompt_template,
# general_instructions=general_instructions.generate(
# m_session, input_str=subtask_data.prompt_template
# ).parse(),
general_instructions=general_instructions.generate(
m_session, input_str=subtask_data.prompt_template
).parse(),
input_vars_required=list(
dict.fromkeys( # Remove duplicates while preserving the original order.
[
Expand All @@ -173,7 +206,7 @@ def decompose(
)
),
)
for subtask_data in subtask_prompts_with_constraints
for subtask_i, subtask_data in enumerate(subtask_prompts_with_constraints)
]

return DecompPipelineResult(
Expand All @@ -182,9 +215,11 @@ def decompose(
identified_constraints=[
{
"constraint": cons_str,
"validation_strategy": constraint_validation_strategies[cons_str],
"val_strategy": constraint_val_data[cons_str]["val_strategy"],
"val_fn": constraint_val_data[cons_str]["val_fn"],
"val_fn_name": f"val_fn_{cons_i + 1}",
}
for cons_str in task_prompt_constraints
for cons_i, cons_str in enumerate(task_prompt_constraints)
],
subtasks=decomp_subtask_result,
)
3 changes: 3 additions & 0 deletions cli/decompose/prompt_modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@
from .subtask_prompt_generator import (
subtask_prompt_generator as subtask_prompt_generator,
)
from .validation_code_generator import (
validation_code_generator as validation_code_generator,
)
from .validation_decision import validation_decision as validation_decision
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ You will be provided with the following 4 parameters inside their respective tag
4. <all_constraints> : A list of candidate (possible) constraints that can be assigned to the target task.
</parameters>

The <all_constraints> list contain the constraints of all tasks on the <execution_plan>, your job is to filter and select only the constraints belonging to your target task.
The <all_constraints> is a list of constraints identified for the entire <execution_plan>, your job is to filter and select only the constraints belonging to your target task.
It is possible that none of the constraints in the <all_constraints> are relevant or related to your target task.

Below, enclosed in <general_instructions> tags, are instructions to guide you on how to complete your assignment:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from ._exceptions import (
BackendGenerationError as BackendGenerationError,
TagExtractionError as TagExtractionError,
)
from ._validation_code_generator import (
validation_code_generator as validation_code_generator,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Any


class ValidationCodeGeneratorError(Exception):
def __init__(self, error_message: str, **kwargs: dict[str, Any]):
self.error_message = error_message
self.__dict__.update(kwargs)
super().__init__(
f'Module Error "validation_code_generator"; {self.error_message}'
)


class BackendGenerationError(ValidationCodeGeneratorError):
"""Raised when LLM generation fails in the "validation_code_generator" prompt module."""

def __init__(self, error_message: str, **kwargs: dict[str, Any]):
super().__init__(error_message, **kwargs)


class TagExtractionError(ValidationCodeGeneratorError):
"""Raised when tag extraction fails in the "validation_code_generator" prompt module."""

def __init__(self, error_message: str, **kwargs: dict[str, Any]):
super().__init__(error_message, **kwargs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._icl_examples import icl_examples as default_icl_examples
from ._prompt import (
get_system_prompt as get_system_prompt,
get_user_prompt as get_user_prompt,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from ._icl_examples import icl_examples as icl_examples
from ._types import ICLExample as ICLExample
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._example import example as example
Loading
Loading