feat: conditional_node #723

Merged 2 commits on Oct 5, 2024

Changes from all commits
42 changes: 42 additions & 0 deletions examples/groq/smart_scraper_multi_cond_groq.py
@@ -0,0 +1,42 @@
"""
Basic example of scraping pipeline using SmartScraperMultiCondGraph with Groq
"""

import os
import json
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperMultiCondGraph

load_dotenv()

# ************************************************
# Define the configuration for the graph
# ************************************************

groq_key = os.getenv("GROQ_APIKEY")

graph_config = {
    "llm": {
        "model": "groq/gemma-7b-it",
        "api_key": groq_key,
        "temperature": 0
    },
    "headless": False
}

# *******************************************************
# Create the SmartScraperMultiCondGraph instance and run it
# *******************************************************

multiple_search_graph = SmartScraperMultiCondGraph(
    prompt="Who is Marco Perini?",
    source=[
        "https://perinim.github.io/",
        "https://perinim.github.io/cv/"
    ],
    schema=None,
    config=graph_config
)

result = multiple_search_graph.run()
print(json.dumps(result, indent=4))
1 change: 1 addition & 0 deletions requirements.txt
@@ -18,3 +18,4 @@ undetected-playwright>=0.3.0
google>=3.0.0
semchunk>=1.0.1
langchain-ollama>=0.1.3
simpleeval>=0.9.13
1 change: 1 addition & 0 deletions scrapegraphai/graphs/__init__.py
@@ -26,4 +26,5 @@
from .screenshot_scraper_graph import ScreenshotScraperGraph
from .smart_scraper_multi_concat_graph import SmartScraperMultiConcatGraph
from .code_generator_graph import CodeGeneratorGraph
from .smart_scraper_multi_cond_graph import SmartScraperMultiCondGraph
from .depth_search_graph import DepthSearchGraph
26 changes: 24 additions & 2 deletions scrapegraphai/graphs/base_graph.py
@@ -59,6 +59,8 @@ def __init__(self, nodes: list, edges: list, entry_point: str,
             # raise a warning if the entry point is not the first node in the list
             warnings.warn(
                 "Careful! The entry point node is different from the first node in the graph.")

+        self._set_conditional_node_edges()
+
         # Burr configuration
         self.use_burr = use_burr
@@ -77,9 +79,24 @@ def _create_edges(self, edges: list) -> dict:

         edge_dict = {}
         for from_node, to_node in edges:
-            edge_dict[from_node.node_name] = to_node.node_name
+            if from_node.node_type != 'conditional_node':
+                edge_dict[from_node.node_name] = to_node.node_name
         return edge_dict

+    def _set_conditional_node_edges(self):
+        """
+        Sets the true_node_name and false_node_name for each ConditionalNode.
+        """
+        for node in self.nodes:
+            if node.node_type == 'conditional_node':
+                # Find outgoing edges from this ConditionalNode
+                outgoing_edges = [(from_node, to_node) for from_node, to_node in self.raw_edges if from_node.node_name == node.node_name]
+                if len(outgoing_edges) != 2:
+                    raise ValueError(f"ConditionalNode '{node.node_name}' must have exactly two outgoing edges.")
+                # Assign true_node_name and false_node_name
+                node.true_node_name = outgoing_edges[0][1].node_name
+                node.false_node_name = outgoing_edges[1][1].node_name

     def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
         """
         Executes the graph by traversing nodes starting from the
@@ -201,7 +218,12 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
                 cb_total["total_cost_USD"] += cb_data["total_cost_USD"]

             if current_node.node_type == "conditional_node":
-                current_node_name = result
+                node_names = {node.node_name for node in self.nodes}
+                if result in node_names:
+                    current_node_name = result
+                else:
+                    raise ValueError(f"Conditional Node returned a node name '{result}' that does not exist in the graph")
+
             elif current_node_name in self.edges:
                 current_node_name = self.edges[current_node_name]
             else:
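A note on the branch-ordering convention introduced above: _set_conditional_node_edges treats the first outgoing edge of a ConditionalNode as the true branch and the second as the false branch, so the order of the tuples in the edges list is significant. The standalone sketch below is illustrative only (it uses SimpleNamespace stand-ins rather than the library's node classes, and the node names are made up); it just replays that pairing logic in isolation.

from types import SimpleNamespace

# Stand-ins for graph nodes: only node_name / node_type are needed for the pairing.
cond = SimpleNamespace(node_name="ConditionalNode", node_type="conditional_node",
                       true_node_name=None, false_node_name=None)
merge = SimpleNamespace(node_name="MergeAnswersNode", node_type="node")
concat = SimpleNamespace(node_name="ConcatNode", node_type="node")

# Order matters: first tuple becomes the true branch, second the false branch.
raw_edges = [(cond, merge), (cond, concat)]

outgoing = [(f, t) for f, t in raw_edges if f.node_name == cond.node_name]
cond.true_node_name = outgoing[0][1].node_name   # "MergeAnswersNode"
cond.false_node_name = outgoing[1][1].node_name  # "ConcatNode"
print(cond.true_node_name, cond.false_node_name)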
130 changes: 130 additions & 0 deletions scrapegraphai/graphs/smart_scraper_multi_cond_graph.py
@@ -0,0 +1,130 @@
"""
SmartScraperMultiCondGraph Module with ConditionalNode
"""
from copy import deepcopy
from typing import List, Optional
from pydantic import BaseModel
from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from .smart_scraper_graph import SmartScraperGraph
from ..nodes import (
GraphIteratorNode,
MergeAnswersNode,
ConcatAnswersNode,
ConditionalNode
)
from ..utils.copy import safe_deepcopy

class SmartScraperMultiCondGraph(AbstractGraph):
    """
    SmartScraperMultiCondGraph is a scraping pipeline that scrapes a
    list of URLs and generates answers to a given prompt.

    Attributes:
        prompt (str): The user prompt to search the internet.
        llm_model (dict): The configuration for the language model.
        embedder_model (dict): The configuration for the embedder model.
        headless (bool): A flag to run the browser in headless mode.
        verbose (bool): A flag to display the execution information.
        model_token (int): The token limit for the language model.

    Args:
        prompt (str): The user prompt to search the internet.
        source (List[str]): The source of the graph.
        config (dict): Configuration parameters for the graph.
        schema (Optional[BaseModel]): The schema for the graph output.

    Example:
        >>> search_graph = SmartScraperMultiCondGraph(
        ...     "What is Chioggia famous for?",
        ...     {"llm": {"model": "openai/gpt-3.5-turbo"}}
        ... )
        >>> result = search_graph.run()
    """

    def __init__(self, prompt: str, source: List[str],
                 config: dict, schema: Optional[BaseModel] = None):

        self.max_results = config.get("max_results", 3)
        self.copy_config = safe_deepcopy(config)
        self.copy_schema = deepcopy(schema)

        super().__init__(prompt, config, source, schema)

    def _create_graph(self) -> BaseGraph:
        """
        Creates the graph of nodes representing the workflow for web scraping and searching,
        including a ConditionalNode to decide between merging or concatenating the results.

        Returns:
            BaseGraph: A graph instance representing the web scraping and searching workflow.
        """

        # Node that iterates over the URLs and collects results
        graph_iterator_node = GraphIteratorNode(
            input="user_prompt & urls",
            output=["results"],
            node_config={
                "graph_instance": SmartScraperGraph,
                "scraper_config": self.copy_config,
            },
            schema=self.copy_schema,
            node_name="GraphIteratorNode"
        )

        # ConditionalNode to check if len(results) > 2
        conditional_node = ConditionalNode(
            input="results",
            output=["results"],
            node_name="ConditionalNode",
            node_config={
                'key_name': 'results',
                'condition': 'len(results) > 2'
            }
        )

        merge_answers_node = MergeAnswersNode(
            input="user_prompt & results",
            output=["answer"],
            node_config={
                "llm_model": self.llm_model,
                "schema": self.copy_schema
            },
            node_name="MergeAnswersNode"
        )

        concat_node = ConcatAnswersNode(
            input="results",
            output=["answer"],
            node_config={},
            node_name="ConcatNode"
        )

        # Build the graph
        return BaseGraph(
            nodes=[
                graph_iterator_node,
                conditional_node,
                merge_answers_node,
                concat_node,
            ],
            edges=[
                (graph_iterator_node, conditional_node),
                (conditional_node, merge_answers_node),  # True node (len(results) > 2)
                (conditional_node, concat_node),         # False node (len(results) <= 2)
            ],
            entry_point=graph_iterator_node,
            graph_name=self.__class__.__name__
        )

    def run(self) -> str:
        """
        Executes the web scraping and searching process.

        Returns:
            str: The answer to the prompt.
        """
        inputs = {"user_prompt": self.prompt, "urls": self.source}
        self.final_state, self.execution_info = self.graph.execute(inputs)

        return self.final_state.get("answer", "No answer found.")
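For comparison, a hypothetical configuration (not part of this PR) that omits the 'condition' expression: in that case ConditionalNode falls back to checking that state[key_name] exists and is not an empty string, as implemented in conditional_node.py further down in this diff. The node name below is made up.

from scrapegraphai.nodes import ConditionalNode

has_results_node = ConditionalNode(
    input="results",
    output=["results"],
    node_name="HasResultsNode",           # hypothetical node name
    node_config={"key_name": "results"}   # no 'condition': route on whether state["results"] exists and is not ''
)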
1 change: 1 addition & 0 deletions scrapegraphai/nodes/__init__.py
@@ -27,6 +27,7 @@
from .html_analyzer_node import HtmlAnalyzerNode
from .generate_code_node import GenerateCodeNode
from .search_node_with_context import SearchLinksWithContext
from .conditional_node import ConditionalNode
from .reasoning_node import ReasoningNode
from .fetch_node_level_k import FetchNodeLevelK
from .generate_answer_node_k_level import GenerateAnswerNodeKLevel
65 changes: 61 additions & 4 deletions scrapegraphai/nodes/conditional_node.py
@@ -3,6 +3,7 @@
 """
 from typing import Optional, List
 from .base_node import BaseNode
+from simpleeval import simple_eval, EvalWithCompoundTypes

 class ConditionalNode(BaseNode):
     """
@@ -28,13 +29,28 @@ class ConditionalNode(BaseNode):

     """

-    def __init__(self):
+    def __init__(self,
+                 input: str,
+                 output: List[str],
+                 node_config: Optional[dict] = None,
+                 node_name: str = "Cond",):
         """
         Initializes an empty ConditionalNode.
         """
-        #super().__init__(node_name, "node", input, output, 2, node_config)
-        pass
+        super().__init__(node_name, "conditional_node", input, output, 2, node_config)
+
+        try:
+            self.key_name = self.node_config["key_name"]
+        except:
+            raise NotImplementedError("You need to provide key_name inside the node config")
+
+        self.true_node_name = None
+        self.false_node_name = None
+
+        self.condition = self.node_config.get("condition", None)
+
+        self.eval_instance = EvalWithCompoundTypes()
+        self.eval_instance.functions = {'len': len}

     def execute(self, state: dict) -> dict:
         """
@@ -47,4 +63,45 @@ def execute(self, state: dict) -> dict:
             str: The name of the next node to execute based on the presence of the key.
         """

-        pass
+        if self.true_node_name is None or self.false_node_name is None:
+            raise ValueError("ConditionalNode's next nodes are not set properly.")
+
+        # Evaluate the condition
+        if self.condition:
+            condition_result = self._evaluate_condition(state, self.condition)
+        else:
+            # Default behavior: check existence and non-emptiness of key_name
+            value = state.get(self.key_name)
+            condition_result = value is not None and value != ''
+
+        # Return the appropriate next node name
+        if condition_result:
+            return self.true_node_name
+        else:
+            return self.false_node_name
+
+    def _evaluate_condition(self, state: dict, condition: str) -> bool:
+        """
+        Parses and evaluates the condition expression against the state.
+
+        Args:
+            state (dict): The current state of the graph.
+            condition (str): The condition expression to evaluate.
+
+        Returns:
+            bool: The result of the condition evaluation.
+        """
+        # Combine state and allowed functions for evaluation context
+        eval_globals = self.eval_instance.functions.copy()
+        eval_globals.update(state)
+
+        try:
+            result = simple_eval(
+                condition,
+                names=eval_globals,
+                functions=self.eval_instance.functions,
+                operators=self.eval_instance.operators
+            )
+            return bool(result)
+        except Exception as e:
+            raise ValueError(f"Error evaluating condition '{condition}' in {self.node_name}: {e}")