fix(AbstractGraph): model selection bug
f-aguzzi committed Aug 28, 2024
1 parent 4eccc76 commit 4f120e2
Showing 2 changed files with 78 additions and 24 deletions.
19 changes: 11 additions & 8 deletions scrapegraphai/graphs/abstract_graph.py
@@ -131,15 +131,15 @@ def _create_llm(self, llm_config: dict) -> object:
"ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai",
"hugging_face", "deepseek", "ernie", "fireworks"}

-        split_model_provider = llm_params["model"].split("/")
+        split_model_provider = llm_params["model"].split("/", 1)
         llm_params["model_provider"] = split_model_provider[0]
-        llm_params["model"] = split_model_provider[1:]
+        llm_params["model"] = split_model_provider[1]

         if llm_params["model_provider"] not in known_providers:
             raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.")

         try:
-            self.model_token = models_tokens[llm_params["model_provider"]].get(llm_params["model"][0])
+            self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
         except KeyError:
             print("Model not found, using default token size (8192)")
             self.model_token = 8192
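Note on this hunk: `split("/", 1)` splits only on the first slash, so a model id that itself contains slashes survives intact, and indexing with `[1]` yields a string where the old `[1:]` produced a list. That string is what makes the direct `models_tokens[provider][model]` lookup work. A minimal standalone sketch of the before/after parsing (the slash-heavy model string is an illustrative example, not taken from this diff):

# Illustrative sketch of the parsing behavior this hunk fixes; the model
# string below is a hypothetical example, not part of the commit.
model_string = "fireworks/accounts/fireworks/models/llama-v3p1-70b-instruct"

# Old behavior: split on every "/" and keep the tail as a list.
old_parts = model_string.split("/")
old_provider, old_model = old_parts[0], old_parts[1:]
# old_model == ["accounts", "fireworks", "models", "llama-v3p1-70b-instruct"]
# -> a list, so a dict lookup keyed by the full model name can never match.

# New behavior: split once, keep the full model id as a single string.
provider, model = model_string.split("/", 1)
assert provider == "fireworks"
assert model == "accounts/fireworks/models/llama-v3p1-70b-instruct"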
@@ -150,18 +150,21 @@ def _create_llm(self, llm_config: dict) -> object:
                     warnings.simplefilter("ignore")
                     return init_chat_model(**llm_params)
             else:
-                if "deepseek" in llm_params["model"]:
+                if llm_params["model_provider"] == "deepseek":
                     return DeepSeek(**llm_params)

-                if "ernie" in llm_params["model"]:
+                if llm_params["model_provider"] == "ernie":
                     from langchain_community.chat_models import ErnieBotChat
                     return ErnieBotChat(**llm_params)

-                if "oneapi" in llm_params["model"]:
+                if llm_params["model_provider"] == "oneapi":
                     return OneApi(**llm_params)

-                if "nvidia" in llm_params["model"]:
-                    from langchain_nvidia_ai_endpoints import ChatNVIDIA
+                if llm_params["model_provider"] == "nvidia":
+                    try:
+                        from langchain_nvidia_ai_endpoints import ChatNVIDIA
+                    except ImportError:
+                        raise ImportError("The langchain_nvidia_ai_endpoints module is not installed. Please install it using `pip install langchain_nvidia_ai_endpoints`.")
                     return ChatNVIDIA(**llm_params)

         except Exception as e:
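Note on this hunk: dispatching on the parsed `model_provider` instead of substring-matching the model name avoids false positives (a model name that merely contains a provider keyword no longer routes to the wrong wrapper), and the NVIDIA import is now guarded so a missing optional dependency fails with an actionable message. A small sketch of the difference, using a made-up model name:

# Hypothetical example: an OpenAI-hosted fine-tune whose name happens to
# contain the substring "ernie".
llm_params = {"model_provider": "openai", "model": "ft:gpt-3.5-turbo:my-ernie-bot"}

# Old check: substring match on the model name misfires here.
if "ernie" in llm_params["model"]:
    print("old logic: wrongly routed to ErnieBotChat")

# New check: exact match on the provider parsed from the prefix.
if llm_params["model_provider"] == "ernie":
    print("new logic: never reached for an openai/ model")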
83 changes: 67 additions & 16 deletions tests/graphs/abstract_graph_test.py
@@ -3,29 +3,80 @@
"""
import pytest
from unittest.mock import patch
from scrapegraphai.graphs import AbstractGraph
from scrapegraphai.graphs import AbstractGraph, BaseGraph
from scrapegraphai.nodes import (
FetchNode,
ParseNode
)
from scrapegraphai.models import OneApi, DeepSeek
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_community.chat_models import ChatOllama
from langchain_google_genai import ChatGoogleGenerativeAI



+class TestGraph(AbstractGraph):
+    def __init__(self, prompt: str, config: dict):
+        super().__init__(prompt, config)
+
+    def _create_graph(self) -> BaseGraph:
+        fetch_node = FetchNode(
+            input="url| local_dir",
+            output=["doc", "link_urls", "img_urls"],
+            node_config={
+                "llm_model": self.llm_model,
+                "force": self.config.get("force", False),
+                "cut": self.config.get("cut", True),
+                "loader_kwargs": self.config.get("loader_kwargs", {}),
+                "browser_base": self.config.get("browser_base")
+            }
+        )
+        parse_node = ParseNode(
+            input="doc",
+            output=["parsed_doc"],
+            node_config={
+                "chunk_size": self.model_token
+            }
+        )
+        return BaseGraph(
+            nodes=[
+                fetch_node,
+                parse_node
+            ],
+            edges=[
+                (fetch_node, parse_node),
+            ],
+            entry_point=fetch_node,
+            graph_name=self.__class__.__name__
+        )
+
+    def run(self) -> str:
+        inputs = {"user_prompt": self.prompt, self.input_key: self.source}
+        self.final_state, self.execution_info = self.graph.execute(inputs)
+
+        return self.final_state.get("answer", "No answer found.")


 class TestAbstractGraph:
     @pytest.mark.parametrize("llm_config, expected_model", [
-        ({"model": "openai/gpt-3.5-turbo"}, "ChatOpenAI"),
-        ({"model": "azure_openai/gpt-3.5-turbo"}, "AzureChatOpenAI"),
-        ({"model": "google_genai/gemini-pro"}, "ChatGoogleGenerativeAI"),
-        ({"model": "google_vertexai/chat-bison"}, "ChatVertexAI"),
-        ({"model": "ollama/llama2"}, "Ollama"),
-        ({"model": "oneapi/text-davinci-003"}, "OneApi"),
-        ({"model": "nvidia/clara-instant-1-base"}, "ChatNVIDIA"),
-        ({"model": "deepseek/deepseek-coder-6.7b-instruct"}, "DeepSeek"),
-        ({"model": "ernie/ernie-bot"}, "ErnieBotChat"),
+        ({"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-randomtest001"}, ChatOpenAI),
+        ({
+            "model": "azure_openai/gpt-3.5-turbo",
+            "api_key": "random-api-key",
+            "api_version": "no version",
+            "azure_endpoint": "https://www.example.com/"},
+         AzureChatOpenAI),
+        ({"model": "google_genai/gemini-pro", "google_api_key": "google-key-test"}, ChatGoogleGenerativeAI),
+        ({"model": "ollama/llama2"}, ChatOllama),
+        ({"model": "oneapi/qwen-turbo"}, OneApi),
+        ({"model": "deepseek/deepseek-coder"}, DeepSeek),
     ])

     def test_create_llm(self, llm_config, expected_model):
-        graph = AbstractGraph("Test prompt", {"llm": llm_config})
+        graph = TestGraph("Test prompt", {"llm": llm_config})
         assert isinstance(graph.llm_model, expected_model)

     def test_create_llm_unknown_provider(self):
         with pytest.raises(ValueError):
-            AbstractGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}})
+            TestGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}})

-    def test_create_llm_error(self):
-        with patch("your_module.init_chat_model", side_effect=Exception("Test error")):
-            with pytest.raises(Exception):
-                AbstractGraph("Test prompt", {"llm": {"model": "openai/gpt-3.5-turbo"}})
