-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #18 from VinciGit00/graph_implementation
Graph implementation
- Loading branch information
Showing
14 changed files
with
636 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from yosoai.graphs import SmartScraper | ||
|
||
OPENAI_API_KEY = '' | ||
|
||
llm_config = { | ||
"api_key": OPENAI_API_KEY, | ||
"model_name": "gpt-3.5-turbo", | ||
} | ||
|
||
url = "https://perinim.github.io/projects/" | ||
prompt = "List me all the titles and project descriptions" | ||
|
||
smart_scraper = SmartScraper(prompt, url, llm_config) | ||
|
||
answer = smart_scraper.run() | ||
print(answer) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
wheel==0.42.0 | ||
setuptools==65.5.1 | ||
twine==4.0.2 | ||
sphinx==7.1.2 | ||
sphinx-rtd-theme==2.0.0 | ||
pytest==8.0.0 | ||
pytest==8.0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .base_graph import BaseGraph | ||
from .smart_scraper_graph import SmartScraper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
class BaseGraph: | ||
""" | ||
BaseGraph manages the execution flow of a graph composed of interconnected nodes. | ||
Attributes: | ||
nodes (dict): A dictionary mapping each node's name to its corresponding node instance. | ||
edges (dict): A dictionary representing the directed edges of the graph where each | ||
key-value pair corresponds to the from-node and to-node relationship. | ||
entry_point (str): The name of the entry point node from which the graph execution begins. | ||
Methods: | ||
execute(initial_state): Executes the graph's nodes starting from the entry point and | ||
traverses the graph based on the provided initial state. | ||
Args: | ||
nodes (iterable): An iterable of node instances that will be part of the graph. | ||
edges (iterable): An iterable of tuples where each tuple represents a directed edge | ||
in the graph, defined by a pair of nodes (from_node, to_node). | ||
entry_point (BaseNode): The node instance that represents the entry point of the graph. | ||
""" | ||
|
||
def __init__(self, nodes, edges, entry_point): | ||
""" | ||
Initializes the graph with nodes, edges, and the entry point. | ||
""" | ||
self.nodes = {node.node_name: node for node in nodes} | ||
self.edges = self._create_edges(edges) | ||
self.entry_point = entry_point.node_name | ||
|
||
def _create_edges(self, edges): | ||
""" | ||
Helper method to create a dictionary of edges from the given iterable of tuples. | ||
Args: | ||
edges (iterable): An iterable of tuples representing the directed edges. | ||
Returns: | ||
dict: A dictionary of edges with the from-node as keys and to-node as values. | ||
""" | ||
edge_dict = {} | ||
for from_node, to_node in edges: | ||
edge_dict[from_node.node_name] = to_node.node_name | ||
return edge_dict | ||
|
||
def execute(self, initial_state): | ||
""" | ||
Executes the graph by traversing nodes starting from the entry point. The execution | ||
follows the edges based on the result of each node's execution and continues until | ||
it reaches a node with no outgoing edges. | ||
Args: | ||
initial_state (dict): The initial state to pass to the entry point node. | ||
Returns: | ||
dict: The state after execution has completed, which may have been altered by the nodes. | ||
""" | ||
current_node_name = self.entry_point | ||
state = initial_state | ||
|
||
while current_node_name is not None: | ||
current_node = self.nodes[current_node_name] | ||
result = current_node.execute(state) | ||
|
||
if current_node.node_type == "conditional_node": | ||
# For ConditionalNode, result is the next node based on the condition | ||
current_node_name = result | ||
elif current_node_name in self.edges: | ||
# For regular nodes, move to the next node based on the defined edges | ||
current_node_name = self.edges[current_node_name] | ||
else: | ||
# No further edges, end the execution | ||
current_node_name = None | ||
|
||
return state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from langchain_openai import ChatOpenAI | ||
from .base_graph import BaseGraph | ||
from ..nodes import FetchHTMLNode,ConditionalNode, GetProbableTagsNode, GenerateAnswerNode, ParseHTMLNode | ||
|
||
class SmartScraper: | ||
""" | ||
SmartScraper is a comprehensive web scraping tool that automates the process of extracting | ||
information from web pages using a natural language model to interpret and answer prompts. | ||
Attributes: | ||
prompt (str): The user's natural language prompt for the information to be extracted. | ||
url (str): The URL of the web page to scrape. | ||
llm_config (dict): Configuration parameters for the language model, with 'api_key' being mandatory. | ||
llm (ChatOpenAI): An instance of the ChatOpenAI class configured with llm_config. | ||
graph (BaseGraph): An instance of the BaseGraph class representing the scraping workflow. | ||
Methods: | ||
run(): Executes the web scraping process and returns the answer to the prompt. | ||
Args: | ||
prompt (str): The user's natural language prompt for the information to be extracted. | ||
url (str): The URL of the web page to scrape. | ||
llm_config (dict): A dictionary containing configuration options for the language model. | ||
Must include 'api_key', may also specify 'model_name', 'temperature', and 'streaming'. | ||
""" | ||
|
||
def __init__(self, prompt, url, llm_config): | ||
""" | ||
Initializes the SmartScraper with a prompt, URL, and language model configuration. | ||
""" | ||
self.prompt = prompt | ||
self.url = url | ||
self.llm_config = llm_config | ||
self.llm = self._create_llm() | ||
self.graph = self._create_graph() | ||
|
||
def _create_llm(self): | ||
""" | ||
Creates an instance of the ChatOpenAI class with the provided language model configuration. | ||
Returns: | ||
ChatOpenAI: An instance of the ChatOpenAI class. | ||
Raises: | ||
ValueError: If 'api_key' is not provided in llm_config. | ||
""" | ||
llm_defaults = { | ||
"model_name": "gpt-3.5-turbo", | ||
"temperature": 0, | ||
"streaming": True | ||
} | ||
# Update defaults with any LLM parameters that were provided | ||
llm_params = {**llm_defaults, **self.llm_config} | ||
# Ensure the api_key is set, raise an error if it's not | ||
if "api_key" not in llm_params: | ||
raise ValueError("LLM configuration must include an 'api_key'.") | ||
# Create the ChatOpenAI instance with the provided and default parameters | ||
return ChatOpenAI(**llm_params) | ||
|
||
def _create_graph(self): | ||
""" | ||
Creates the graph of nodes representing the workflow for web scraping. | ||
Returns: | ||
BaseGraph: An instance of the BaseGraph class. | ||
""" | ||
fetch_html_node = FetchHTMLNode("fetch_html") | ||
get_probable_tags_node = GetProbableTagsNode(self.llm, "get_probable_tags") | ||
parse_document_node = ParseHTMLNode("parse_document") | ||
generate_answer_node = GenerateAnswerNode(self.llm, "generate_answer") | ||
conditional_node = ConditionalNode("conditional", [parse_document_node, generate_answer_node]) | ||
|
||
return BaseGraph( | ||
nodes={ | ||
fetch_html_node, | ||
get_probable_tags_node, | ||
conditional_node, | ||
parse_document_node, | ||
generate_answer_node, | ||
}, | ||
edges={ | ||
(fetch_html_node, get_probable_tags_node), | ||
(get_probable_tags_node, conditional_node), | ||
(parse_document_node, generate_answer_node) | ||
}, | ||
entry_point=fetch_html_node | ||
) | ||
|
||
def run(self): | ||
""" | ||
Executes the scraping process by running the graph and returns the extracted information. | ||
Returns: | ||
str: The answer extracted from the web page, corresponding to the given prompt. | ||
""" | ||
inputs = {"keys": {"user_input": self.prompt, "url": self.url}} | ||
final_state = self.graph.execute(inputs) | ||
return final_state["keys"].get("answer", "No answer found.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .fetch_html_node import FetchHTMLNode | ||
from .conditional_node import ConditionalNode | ||
from .get_probable_tags_node import GetProbableTagsNode | ||
from .generate_answer_node import GenerateAnswerNode | ||
from .parse_html_node import ParseHTMLNode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
class BaseNode(ABC): | ||
""" | ||
An abstract base class for nodes in a graph-based workflow. Each node is | ||
intended to perform a specific action when executed as part of the graph's | ||
processing flow. | ||
Attributes: | ||
node_name (str): A unique identifier for the node. | ||
node_type (str): Specifies the node's type, which influences how the | ||
node interacts within the graph. Valid values are | ||
"node" for standard nodes and "conditional_node" for | ||
nodes that determine the flow based on conditions. | ||
Methods: | ||
execute(state): An abstract method that subclasses must implement. This | ||
method should contain the logic that the node executes | ||
when it is reached in the graph's flow. It takes the | ||
graph's current state as input and returns the updated | ||
state after execution. | ||
Args: | ||
node_name (str): The unique identifier name for the node. This name is | ||
used to reference the node within the graph. | ||
node_type (str): The type of the node, limited to "node" or | ||
"conditional_node". This categorization helps in | ||
determining the node's role and behavior within the | ||
graph. | ||
Raises: | ||
ValueError: If the provided `node_type` is not one of the allowed | ||
values ("node" or "conditional_node"), a ValueError is | ||
raised to indicate the incorrect usage. | ||
""" | ||
|
||
def __init__(self, node_name: str, node_type: str): | ||
""" | ||
Initialize the node with a unique identifier and a specified node type. | ||
Args: | ||
node_name (str): The unique identifier name for the node. | ||
node_type (str): The type of the node, limited to "node" or "conditional_node". | ||
Raises: | ||
ValueError: If node_type is not "node" or "conditional_node". | ||
""" | ||
self.node_name = node_name | ||
if node_type not in ["node", "conditional_node"]: | ||
raise ValueError(f"node_type must be 'node' or 'conditional_node', got '{node_type}'") | ||
self.node_type = node_type | ||
|
||
@abstractmethod | ||
def execute(self, state): | ||
""" | ||
Execute the node's logic and return the updated state. | ||
:param state: The current state of the graph. | ||
:return: The updated state after executing this node. | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from .base_node import BaseNode | ||
|
||
class ConditionalNode(BaseNode): | ||
""" | ||
A node that determines the next step in the graph's execution flow based on | ||
the presence and content of a specified key in the graph's state. It extends | ||
the BaseNode by adding condition-based logic to the execution process. | ||
This node type is used to implement branching logic within the graph, allowing | ||
for dynamic paths based on the data available in the current state. | ||
Attributes: | ||
key_name (str): The name of the key in the state to check for its presence. | ||
next_nodes (list): A list of two node instances. The first node is chosen | ||
for execution if the key exists and has a non-empty value, | ||
and the second node is chosen if the key does not exist or | ||
is empty. | ||
Args: | ||
key_name (str): The name of the key to check in the graph's state. This is | ||
used to determine the path the graph's execution should take. | ||
next_nodes (list): A list containing exactly two node instances, specifying | ||
the next nodes to execute based on the condition's outcome. | ||
node_name (str, optional): The unique identifier name for the node. Defaults | ||
to "ConditionalNode". | ||
Raises: | ||
ValueError: If next_nodes does not contain exactly two elements, indicating | ||
a misconfiguration in specifying the conditional paths. | ||
""" | ||
|
||
def __init__(self, key_name, next_nodes, node_name="ConditionalNode"): | ||
""" | ||
Initializes the node with the key to check and the next node names based on the condition. | ||
Args: | ||
key_name (str): The name of the key to check in the state. | ||
next_nodes (list): A list containing exactly two names of the next nodes. | ||
The first is used if the key exists, the second if it does not. | ||
Raises: | ||
ValueError: If next_nodes does not contain exactly two elements. | ||
""" | ||
|
||
super().__init__(node_name, "conditional_node") | ||
self.key_name = key_name | ||
if len(next_nodes) != 2: | ||
raise ValueError("next_nodes must contain exactly two elements.") | ||
self.next_nodes = next_nodes | ||
|
||
def execute(self, state): | ||
""" | ||
Checks if the specified key is present in the state and decides the next node accordingly. | ||
Args: | ||
state (dict): The current state of the graph. | ||
Returns: | ||
str: The name of the next node to execute based on the presence of the key. | ||
""" | ||
|
||
if self.key_name in state.get("keys", {}) and len(state["keys"][self.key_name]) > 0: | ||
return self.next_nodes[0].node_name | ||
else: | ||
return self.next_nodes[1].node_name |
Oops, something went wrong.