Skip to content

Commit

Permalink
Merge pull request #723 from ekinsenler/cond_node
Browse files Browse the repository at this point in the history
feat: conditional_node
  • Loading branch information
VinciGit00 authored Oct 5, 2024
2 parents 84d7937 + c06cafc commit ae5d2ef
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 6 deletions.
42 changes: 42 additions & 0 deletions examples/groq/smart_scraper_multi_cond_groq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
Basic example of scraping pipeline using SmartScraperMultiConcatGraph with Groq
"""

import os
import json
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperMultiCondGraph

load_dotenv()

# ************************************************
# Define the configuration for the graph
# ************************************************

groq_key = os.getenv("GROQ_APIKEY")

graph_config = {
"llm": {
"model": "groq/gemma-7b-it",
"api_key": groq_key,
"temperature": 0
},
"headless": False
}

# *******************************************************
# Create the SmartScraperMultiCondGraph instance and run it
# *******************************************************

multiple_search_graph = SmartScraperMultiCondGraph(
prompt="Who is Marco Perini?",
source=[
"https://perinim.github.io/",
"https://perinim.github.io/cv/"
],
schema=None,
config=graph_config
)

result = multiple_search_graph.run()
print(json.dumps(result, indent=4))
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ undetected-playwright>=0.3.0
google>=3.0.0
semchunk>=1.0.1
langchain-ollama>=0.1.3
simpleeval>=0.9.13
1 change: 1 addition & 0 deletions scrapegraphai/graphs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@
from .screenshot_scraper_graph import ScreenshotScraperGraph
from .smart_scraper_multi_concat_graph import SmartScraperMultiConcatGraph
from .code_generator_graph import CodeGeneratorGraph
from .smart_scraper_multi_cond_graph import SmartScraperMultiCondGraph
from .depth_search_graph import DepthSearchGraph
26 changes: 24 additions & 2 deletions scrapegraphai/graphs/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def __init__(self, nodes: list, edges: list, entry_point: str,
# raise a warning if the entry point is not the first node in the list
warnings.warn(
"Careful! The entry point node is different from the first node in the graph.")

self._set_conditional_node_edges()

# Burr configuration
self.use_burr = use_burr
Expand All @@ -77,9 +79,24 @@ def _create_edges(self, edges: list) -> dict:

edge_dict = {}
for from_node, to_node in edges:
edge_dict[from_node.node_name] = to_node.node_name
if from_node.node_type != 'conditional_node':
edge_dict[from_node.node_name] = to_node.node_name
return edge_dict

def _set_conditional_node_edges(self):
"""
Sets the true_node_name and false_node_name for each ConditionalNode.
"""
for node in self.nodes:
if node.node_type == 'conditional_node':
# Find outgoing edges from this ConditionalNode
outgoing_edges = [(from_node, to_node) for from_node, to_node in self.raw_edges if from_node.node_name == node.node_name]
if len(outgoing_edges) != 2:
raise ValueError(f"ConditionalNode '{node.node_name}' must have exactly two outgoing edges.")
# Assign true_node_name and false_node_name
node.true_node_name = outgoing_edges[0][1].node_name
node.false_node_name = outgoing_edges[1][1].node_name

def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
"""
Executes the graph by traversing nodes starting from the
Expand Down Expand Up @@ -201,7 +218,12 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]

if current_node.node_type == "conditional_node":
current_node_name = result
node_names = {node.node_name for node in self.nodes}
if result in node_names:
current_node_name = result
else:
raise ValueError(f"Conditional Node returned a node name '{result}' that does not exist in the graph")

elif current_node_name in self.edges:
current_node_name = self.edges[current_node_name]
else:
Expand Down
130 changes: 130 additions & 0 deletions scrapegraphai/graphs/smart_scraper_multi_cond_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
SmartScraperMultiCondGraph Module with ConditionalNode
"""
from copy import deepcopy
from typing import List, Optional
from pydantic import BaseModel
from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from .smart_scraper_graph import SmartScraperGraph
from ..nodes import (
GraphIteratorNode,
MergeAnswersNode,
ConcatAnswersNode,
ConditionalNode
)
from ..utils.copy import safe_deepcopy

class SmartScraperMultiCondGraph(AbstractGraph):
"""
SmartScraperMultiConditionalGraph is a scraping pipeline that scrapes a
list of URLs and generates answers to a given prompt.
Attributes:
prompt (str): The user prompt to search the internet.
llm_model (dict): The configuration for the language model.
embedder_model (dict): The configuration for the embedder model.
headless (bool): A flag to run the browser in headless mode.
verbose (bool): A flag to display the execution information.
model_token (int): The token limit for the language model.
Args:
prompt (str): The user prompt to search the internet.
source (List[str]): The source of the graph.
config (dict): Configuration parameters for the graph.
schema (Optional[BaseModel]): The schema for the graph output.
Example:
>>> search_graph = MultipleSearchGraph(
... "What is Chioggia famous for?",
... {"llm": {"model": "openai/gpt-3.5-turbo"}}
... )
>>> result = search_graph.run()
"""

def __init__(self, prompt: str, source: List[str],
config: dict, schema: Optional[BaseModel] = None):

self.max_results = config.get("max_results", 3)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)

super().__init__(prompt, config, source, schema)

def _create_graph(self) -> BaseGraph:
"""
Creates the graph of nodes representing the workflow for web scraping and searching,
including a ConditionalNode to decide between merging or concatenating the results.
Returns:
BaseGraph: A graph instance representing the web scraping and searching workflow.
"""

# Node that iterates over the URLs and collects results
graph_iterator_node = GraphIteratorNode(
input="user_prompt & urls",
output=["results"],
node_config={
"graph_instance": SmartScraperGraph,
"scraper_config": self.copy_config,
},
schema=self.copy_schema,
node_name="GraphIteratorNode"
)

# ConditionalNode to check if len(results) > 2
conditional_node = ConditionalNode(
input="results",
output=["results"],
node_name="ConditionalNode",
node_config={
'key_name': 'results',
'condition': 'len(results) > 2'
}
)

merge_answers_node = MergeAnswersNode(
input="user_prompt & results",
output=["answer"],
node_config={
"llm_model": self.llm_model,
"schema": self.copy_schema
},
node_name="MergeAnswersNode"
)

concat_node = ConcatAnswersNode(
input="results",
output=["answer"],
node_config={},
node_name="ConcatNode"
)

# Build the graph
return BaseGraph(
nodes=[
graph_iterator_node,
conditional_node,
merge_answers_node,
concat_node,
],
edges=[
(graph_iterator_node, conditional_node),
(conditional_node, merge_answers_node), # True node (len(results) > 2)
(conditional_node, concat_node), # False node (len(results) <= 2)
],
entry_point=graph_iterator_node,
graph_name=self.__class__.__name__
)

def run(self) -> str:
"""
Executes the web scraping and searching process.
Returns:
str: The answer to the prompt.
"""
inputs = {"user_prompt": self.prompt, "urls": self.source}
self.final_state, self.execution_info = self.graph.execute(inputs)

return self.final_state.get("answer", "No answer found.")
1 change: 1 addition & 0 deletions scrapegraphai/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .html_analyzer_node import HtmlAnalyzerNode
from .generate_code_node import GenerateCodeNode
from .search_node_with_context import SearchLinksWithContext
from .conditional_node import ConditionalNode
from .reasoning_node import ReasoningNode
from .fetch_node_level_k import FetchNodeLevelK
from .generate_answer_node_k_level import GenerateAnswerNodeKLevel
Expand Down
65 changes: 61 additions & 4 deletions scrapegraphai/nodes/conditional_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
from typing import Optional, List
from .base_node import BaseNode
from simpleeval import simple_eval, EvalWithCompoundTypes

class ConditionalNode(BaseNode):
"""
Expand All @@ -28,13 +29,28 @@ class ConditionalNode(BaseNode):
"""

def __init__(self):
def __init__(self,
input: str,
output: List[str],
node_config: Optional[dict] = None,
node_name: str = "Cond",):
"""
Initializes an empty ConditionalNode.
"""
#super().__init__(node_name, "node", input, output, 2, node_config)
pass
super().__init__(node_name, "conditional_node", input, output, 2, node_config)

try:
self.key_name = self.node_config["key_name"]
except:
raise NotImplementedError("You need to provide key_name inside the node config")

self.true_node_name = None
self.false_node_name = None

self.condition = self.node_config.get("condition", None)

self.eval_instance = EvalWithCompoundTypes()
self.eval_instance.functions = {'len': len}

def execute(self, state: dict) -> dict:
"""
Expand All @@ -47,4 +63,45 @@ def execute(self, state: dict) -> dict:
str: The name of the next node to execute based on the presence of the key.
"""

pass
if self.true_node_name is None or self.false_node_name is None:
raise ValueError("ConditionalNode's next nodes are not set properly.")

# Evaluate the condition
if self.condition:
condition_result = self._evaluate_condition(state, self.condition)
else:
# Default behavior: check existence and non-emptiness of key_name
value = state.get(self.key_name)
condition_result = value is not None and value != ''

# Return the appropriate next node name
if condition_result:
return self.true_node_name
else:
return self.false_node_name

def _evaluate_condition(self, state: dict, condition: str) -> bool:
"""
Parses and evaluates the condition expression against the state.
Args:
state (dict): The current state of the graph.
condition (str): The condition expression to evaluate.
Returns:
bool: The result of the condition evaluation.
"""
# Combine state and allowed functions for evaluation context
eval_globals = self.eval_instance.functions.copy()
eval_globals.update(state)

try:
result = simple_eval(
condition,
names=eval_globals,
functions=self.eval_instance.functions,
operators=self.eval_instance.operators
)
return bool(result)
except Exception as e:
raise ValueError(f"Error evaluating condition '{condition}' in {self.node_name}: {e}")

0 comments on commit ae5d2ef

Please sign in to comment.