diff --git a/libs/core/langchain_core/_api/beta_decorator.py b/libs/core/langchain_core/_api/beta_decorator.py index d27a27a1c834c..064bca94a8574 100644 --- a/libs/core/langchain_core/_api/beta_decorator.py +++ b/libs/core/langchain_core/_api/beta_decorator.py @@ -154,7 +154,7 @@ def warn_if_direct_instance( _name = _name or obj.fget.__qualname__ old_doc = obj.__doc__ - class _beta_property(property): + class _BetaProperty(property): """A beta property.""" def __init__(self, fget=None, fset=None, fdel=None, doc=None): @@ -185,7 +185,7 @@ def __set_name__(self, owner, set_name): def finalize(wrapper: Callable[..., Any], new_doc: str) -> Any: """Finalize the property.""" - return _beta_property( + return _BetaProperty( fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc ) diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index 7f46ff2802f90..ba40c201e786a 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -265,7 +265,7 @@ def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: _name = _name or cast(Union[Type, Callable], obj.fget).__qualname__ old_doc = obj.__doc__ - class _deprecated_property(property): + class _DeprecatedProperty(property): """A deprecated property.""" def __init__(self, fget=None, fset=None, fdel=None, doc=None): # type: ignore[no-untyped-def] @@ -298,7 +298,7 @@ def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: """Finalize the property.""" return cast( T, - _deprecated_property( + _DeprecatedProperty( fget=obj.fget, fset=obj.fset, fdel=obj.fdel, doc=new_doc ), ) diff --git a/libs/core/langchain_core/exceptions.py b/libs/core/langchain_core/exceptions.py index 3c3f278511245..7c60ccfa4db38 100644 --- a/libs/core/langchain_core/exceptions.py +++ b/libs/core/langchain_core/exceptions.py @@ -3,7 +3,7 @@ from typing import Any, Optional -class LangChainException(Exception): +class LangChainException(Exception): # noqa: N818 """General LangChain exception.""" @@ -11,7 +11,7 @@ class TracerException(LangChainException): """Base class for exceptions in tracers module.""" -class OutputParserException(ValueError, LangChainException): +class OutputParserException(ValueError, LangChainException): # noqa: N818 """Exception that output parsers should raise to signify a parsing error. 
This exists to differentiate parsing errors from other code or execution errors diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index 0705f9e9f582b..06d6aa9ef500e 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -19,7 +19,7 @@ ) from pydantic import BaseModel, ConfigDict, Field, field_validator -from typing_extensions import TypeAlias, TypedDict +from typing_extensions import TypeAlias, TypedDict, override from langchain_core._api import deprecated from langchain_core.messages import ( @@ -148,6 +148,7 @@ def set_verbose(cls, verbose: Optional[bool]) -> bool: return verbose @property + @override def InputType(self) -> TypeAlias: """Get the input type for this runnable.""" from langchain_core.prompt_values import ( diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index f761a8ed69489..477e2cace7e27 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -30,6 +30,7 @@ Field, model_validator, ) +from typing_extensions import override from langchain_core._api import deprecated from langchain_core.caches import BaseCache @@ -255,6 +256,7 @@ def _serialized(self) -> dict[str, Any]: # --- Runnable methods --- @property + @override def OutputType(self) -> Any: """Get the output type for this runnable.""" return AnyMessage diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index d30249337d711..861b2abfb92db 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -37,6 +37,7 @@ stop_after_attempt, wait_exponential, ) +from typing_extensions import override from langchain_core._api import deprecated from langchain_core.caches import BaseCache @@ -324,6 +325,7 @@ def _serialized(self) -> dict[str, Any]: # --- Runnable methods --- @property + @override def OutputType(self) -> Type[str]: """Get the input type for this runnable.""" return str diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index 6e1420d9b714c..0aff251cc96c5 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -13,6 +13,8 @@ Union, ) +from typing_extensions import override + from langchain_core.language_models import LanguageModelOutput from langchain_core.messages import AnyMessage, BaseMessage from langchain_core.outputs import ChatGeneration, Generation @@ -66,11 +68,13 @@ class BaseGenerationOutputParser( """Base class to parse the output of an LLM call.""" @property + @override def InputType(self) -> Any: """Return the input type for the parser.""" return Union[str, AnyMessage] @property + @override def OutputType(self) -> Type[T]: """Return the output type for the parser.""" # even though mypy complains this isn't valid, @@ -151,11 +155,13 @@ def _type(self) -> str: """ # noqa: E501 @property + @override def InputType(self) -> Any: """Return the input type for the parser.""" return Union[str, AnyMessage] @property + @override def OutputType(self) -> Type[T]: """Return the output type for the parser. 
diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index 8d48e98b2b349..76324afa16a43 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -3,7 +3,7 @@ import pydantic from pydantic import SkipValidation -from typing_extensions import Annotated +from typing_extensions import Annotated, override from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import JsonOutputParser @@ -108,6 +108,7 @@ def _type(self) -> str: return "pydantic" @property + @override def OutputType(self) -> Type[TBaseModel]: """Return the pydantic model.""" return self.pydantic_object diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index b8355d97c3e08..b7df58d3963e3 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -1,6 +1,6 @@ import re import xml -import xml.etree.ElementTree as ET +import xml.etree.ElementTree as ET # noqa: N817 from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union from xml.etree.ElementTree import TreeBuilder @@ -45,14 +45,14 @@ def __init__(self, parser: Literal["defusedxml", "xml"]) -> None: """ if parser == "defusedxml": try: - from defusedxml import ElementTree as DET # type: ignore + import defusedxml # type: ignore except ImportError as e: raise ImportError( "defusedxml is not installed. " "Please install it to use the defusedxml parser." "You can install it with `pip install defusedxml` " ) from e - _parser = DET.DefusedXMLParser(target=TreeBuilder()) + _parser = defusedxml.ElementTree.DefusedXMLParser(target=TreeBuilder()) else: _parser = None self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser) @@ -188,7 +188,7 @@ def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]: # likely if you're reading this you can move them to the top of the file if self.parser == "defusedxml": try: - from defusedxml import ElementTree as DET # type: ignore + import defusedxml # type: ignore except ImportError as e: raise ImportError( "defusedxml is not installed. 
" @@ -196,9 +196,9 @@ def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]: "You can install it with `pip install defusedxml`" "See https://github.com/tiran/defusedxml for more details" ) from e - _ET = DET # Use the defusedxml parser + _et = defusedxml.ElementTree # Use the defusedxml parser else: - _ET = ET # Use the standard library parser + _et = ET # Use the standard library parser match = re.search(r"```(xml)?(.*)```", text, re.DOTALL) if match is not None: diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index 62ce4cece8117..a4ae8dccaea84 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -20,7 +20,7 @@ import yaml from pydantic import BaseModel, ConfigDict, Field, model_validator -from typing_extensions import Self +from typing_extensions import Self, override from langchain_core.load import dumpd from langchain_core.output_parsers.base import BaseOutputParser @@ -109,6 +109,7 @@ def _serialized(self) -> dict[str, Any]: return dumpd(self) @property + @override def OutputType(self) -> Any: """Return the output type of the prompt.""" return Union[StringPromptValue, ChatPromptValueConcrete] diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index a61f3c79a415b..3704403054a32 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -37,7 +37,7 @@ ) from pydantic import BaseModel, ConfigDict, Field, RootModel -from typing_extensions import Literal, get_args, get_type_hints +from typing_extensions import Literal, get_args, get_type_hints, override from langchain_core._api import beta_decorator from langchain_core.load.serializable import ( @@ -273,7 +273,7 @@ def get_name( return name_ @property - def InputType(self) -> Type[Input]: + def InputType(self) -> Type[Input]: # noqa: N802 """The type of input this Runnable accepts specified as a type annotation.""" # First loop through all parent classes and if any of them is # a pydantic model, we will pick up the generic parameterization @@ -298,7 +298,7 @@ def InputType(self) -> Type[Input]: ) @property - def OutputType(self) -> Type[Output]: + def OutputType(self) -> Type[Output]: # noqa: N802 """The type of output this Runnable produces specified as a type annotation.""" # First loop through bases -- this will help generic # any pydantic models. 
@@ -2804,11 +2804,13 @@ def is_lc_serializable(cls) -> bool: ) @property + @override def InputType(self) -> Type[Input]: """The type of the input to the Runnable.""" return self.first.InputType @property + @override def OutputType(self) -> Type[Output]: """The type of the output of the Runnable.""" return self.last.OutputType @@ -3557,6 +3559,7 @@ def get_name( return super().get_name(suffix, name=name) @property + @override def InputType(self) -> Any: """The type of the input to the Runnable.""" for step in self.steps__.values(): @@ -4050,6 +4053,7 @@ def __init__( self.name = "RunnableGenerator" @property + @override def InputType(self) -> Any: func = getattr(self, "_transform", None) or self._atransform try: @@ -4086,6 +4090,7 @@ def get_input_schema( ) @property + @override def OutputType(self) -> Any: func = getattr(self, "_transform", None) or self._atransform try: @@ -4331,6 +4336,7 @@ def __init__( pass @property + @override def InputType(self) -> Any: """The type of the input to this Runnable.""" func = getattr(self, "func", None) or self.afunc @@ -4390,6 +4396,7 @@ def get_input_schema( return super().get_input_schema(config) @property + @override def OutputType(self) -> Any: """The type of the output of this Runnable as a type annotation. @@ -4939,6 +4946,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): ) @property + @override def InputType(self) -> Any: return List[self.bound.InputType] # type: ignore[name-defined] @@ -4962,6 +4970,7 @@ def get_input_schema( ) @property + @override def OutputType(self) -> Type[List[Output]]: return List[self.bound.OutputType] # type: ignore[name-defined] @@ -5255,6 +5264,7 @@ def get_name( return self.bound.get_name(suffix, name=name) @property + @override def InputType(self) -> Type[Input]: return ( cast(Type[Input], self.custom_input_type) @@ -5263,6 +5273,7 @@ def InputType(self) -> Type[Input]: ) @property + @override def OutputType(self) -> Type[Output]: return ( cast(Type[Output], self.custom_output_type) diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index fdcd7d0e95928..c6aa7a1f44222 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -22,6 +22,7 @@ from weakref import WeakValueDictionary from pydantic import BaseModel, ConfigDict +from typing_extensions import override from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.config import ( @@ -74,10 +75,12 @@ def get_lc_namespace(cls) -> List[str]: return ["langchain", "schema", "runnable"] @property + @override def InputType(self) -> Type[Input]: return self.default.InputType @property + @override def OutputType(self) -> Type[Output]: return self.default.OutputType diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index ef97aaf770c39..bb67e2de11854 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -19,6 +19,7 @@ ) from pydantic import BaseModel, ConfigDict +from typing_extensions import override from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.config import ( @@ -112,10 +113,12 @@ def when_all_is_lost(inputs): ) @property + @override def InputType(self) -> Type[Input]: return self.runnable.InputType @property + @override def OutputType(self) -> Type[Output]: return self.runnable.OutputType 
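The companion change, repeated across the files above and below, decorates subclass implementations of these properties with `typing_extensions.override`. At runtime the decorator only sets `__override__ = True`; its payoff is static: a type checker now verifies that the decorated member really overrides something in a base class. A hedged sketch with invented class names:

```python
from typing import Any

from typing_extensions import override


class Base:
    @property
    def OutputType(self) -> Any:  # noqa: N802
        return str


class Child(Base):
    @property
    @override  # a type checker confirms Base really defines OutputType
    def OutputType(self) -> Any:  # noqa: N802
        return int

    # If this were misspelled as `OutputTyp`, mypy or pyright would report
    # that it does not override any base class member, instead of the typo
    # silently creating a new, unrelated property.
```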
diff --git a/libs/core/langchain_core/runnables/graph_ascii.py b/libs/core/langchain_core/runnables/graph_ascii.py index 46677213f81c3..7e9225c275d99 100644 --- a/libs/core/langchain_core/runnables/graph_ascii.py +++ b/libs/core/langchain_core/runnables/graph_ascii.py @@ -245,27 +245,27 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str: # NOTE: coordinates might me negative, so we need to shift # everything to the positive plane before we actually draw it. - Xs = [] - Ys = [] + xlist = [] + ylist = [] sug = _build_sugiyama_layout(vertices, edges) for vertex in sug.g.sV: # NOTE: moving boxes w/2 to the left - Xs.append(vertex.view.xy[0] - vertex.view.w / 2.0) - Xs.append(vertex.view.xy[0] + vertex.view.w / 2.0) - Ys.append(vertex.view.xy[1]) - Ys.append(vertex.view.xy[1] + vertex.view.h) + xlist.append(vertex.view.xy[0] - vertex.view.w / 2.0) + xlist.append(vertex.view.xy[0] + vertex.view.w / 2.0) + ylist.append(vertex.view.xy[1]) + ylist.append(vertex.view.xy[1] + vertex.view.h) for edge in sug.g.sE: for x, y in edge.view._pts: - Xs.append(x) - Ys.append(y) + xlist.append(x) + ylist.append(y) - minx = min(Xs) - miny = min(Ys) - maxx = max(Xs) - maxy = max(Ys) + minx = min(xlist) + miny = min(ylist) + maxx = max(xlist) + maxy = max(ylist) canvas_cols = int(math.ceil(math.ceil(maxx) - math.floor(minx))) + 1 canvas_lines = int(round(maxy - miny)) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index da297ec3bca00..ef2b7ea06b130 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -14,6 +14,7 @@ ) from pydantic import BaseModel +from typing_extensions import override from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.load.load import load @@ -398,6 +399,7 @@ def get_input_schema( ) @property + @override def OutputType(self) -> Type[Output]: output_type = self._history_chain.OutputType return output_type diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index a613533674d73..288c426848531 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -22,6 +22,7 @@ ) from pydantic import BaseModel, RootModel +from typing_extensions import override from langchain_core.runnables.base import ( Other, @@ -199,10 +200,12 @@ def get_lc_namespace(cls) -> List[str]: return ["langchain", "schema", "runnable"] @property + @override def InputType(self) -> Any: return self.input_type or Any @property + @override def OutputType(self) -> Any: return self.input_type or Any diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index 20d82be3fd62b..be66e1ce1d860 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -29,7 +29,7 @@ Union, ) -from typing_extensions import TypeGuard +from typing_extensions import TypeGuard, override from langchain_core.runnables.schema import StreamEvent @@ -136,6 +136,7 @@ def __init__(self, name: str, keys: Set[str]) -> None: self.name = name self.keys = keys + @override def visit_Subscript(self, node: ast.Subscript) -> Any: """Visit a subscript node. 
@@ -155,6 +156,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: # we've found a subscript access on the name we're looking for self.keys.add(node.slice.value) + @override def visit_Call(self, node: ast.Call) -> Any: """Visit a call node. @@ -183,6 +185,7 @@ class IsFunctionArgDict(ast.NodeVisitor): def __init__(self) -> None: self.keys: Set[str] = set() + @override def visit_Lambda(self, node: ast.Lambda) -> Any: """Visit a lambda function. @@ -197,6 +200,7 @@ def visit_Lambda(self, node: ast.Lambda) -> Any: input_arg_name = node.args.args[0].arg IsLocalDict(input_arg_name, self.keys).visit(node.body) + @override def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: """Visit a function definition. @@ -211,6 +215,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: input_arg_name = node.args.args[0].arg IsLocalDict(input_arg_name, self.keys).visit(node) + @override def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: """Visit an async function definition. @@ -233,6 +238,7 @@ def __init__(self) -> None: self.loads: Set[str] = set() self.stores: Set[str] = set() + @override def visit_Name(self, node: ast.Name) -> Any: """Visit a name node. @@ -247,6 +253,7 @@ def visit_Name(self, node: ast.Name) -> Any: elif isinstance(node.ctx, ast.Store): self.stores.add(node.id) + @override def visit_Attribute(self, node: ast.Attribute) -> Any: """Visit an attribute node. @@ -273,6 +280,7 @@ class FunctionNonLocals(ast.NodeVisitor): def __init__(self) -> None: self.nonlocals: Set[str] = set() + @override def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: """Visit a function definition. @@ -286,6 +294,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: visitor.visit(node) self.nonlocals.update(visitor.loads - visitor.stores) + @override def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: """Visit an async function definition. @@ -299,6 +308,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: visitor.visit(node) self.nonlocals.update(visitor.loads - visitor.stores) + @override def visit_Lambda(self, node: ast.Lambda) -> Any: """Visit a lambda function. @@ -321,6 +331,7 @@ def __init__(self) -> None: self.source: Optional[str] = None self.count = 0 + @override def visit_Lambda(self, node: ast.Lambda) -> Any: """Visit a lambda function. diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index f528a80e565b2..3a4c6d62eaceb 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -315,7 +315,7 @@ def create_schema_from_function( ) -class ToolException(Exception): +class ToolException(Exception): # noqa: N818 """Optional exception that tool throws when execution error occurs. When this exception is thrown, the agent will not stop working, diff --git a/libs/core/langchain_core/tracers/langchain_v1.py b/libs/core/langchain_core/tracers/langchain_v1.py index bf1237d66abbe..ea1c882ea67da 100644 --- a/libs/core/langchain_core/tracers/langchain_v1.py +++ b/libs/core/langchain_core/tracers/langchain_v1.py @@ -9,7 +9,7 @@ def get_headers(*args: Any, **kwargs: Any) -> Any: ) -def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: +def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: # noqa: N802 """Throw an error because this has been replaced by LangChainTracer.""" raise RuntimeError( "LangChainTracerV1 is no longer supported. Please use LangChainTracer instead." 
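The same back-compat reasoning drives the suppressions above: `LangChainException`, `OutputParserException`, and `ToolException` violate N818 ("exception name should end in Error"), and `LangChainTracerV1` violates N802, but all of these names are public, so renaming would break existing `except` clauses and import paths. A sketch of both patterns with invented names:

```python
class WidgetException(Exception):  # noqa: N818  # public name predates the rule
    """Illustrative exception kept un-renamed for backwards compatibility."""


def WidgetTracerV1(*args: object, **kwargs: object) -> None:  # noqa: N802
    """Illustrative shim: a removed class kept as a same-named function."""
    raise RuntimeError(
        "WidgetTracerV1 is no longer supported. Please use WidgetTracer instead."
    )
```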
diff --git a/libs/core/langchain_core/tracers/schemas.py b/libs/core/langchain_core/tracers/schemas.py index 5f6c8ed8c6b7d..c0573fb6c19de 100644 --- a/libs/core/langchain_core/tracers/schemas.py +++ b/libs/core/langchain_core/tracers/schemas.py @@ -18,7 +18,7 @@ @deprecated("0.1.0", alternative="Use string instead.", removal="1.0") -def RunTypeEnum() -> Type[RunTypeEnumDep]: +def RunTypeEnum() -> Type[RunTypeEnumDep]: # noqa: N802 """RunTypeEnum.""" warnings.warn( "RunTypeEnum is deprecated. Please directly use a string instead" diff --git a/libs/core/langchain_core/utils/aiter.py b/libs/core/langchain_core/utils/aiter.py index b2bc92699466b..6478134457aec 100644 --- a/libs/core/langchain_core/utils/aiter.py +++ b/libs/core/langchain_core/utils/aiter.py @@ -238,7 +238,7 @@ async def aclose(self) -> None: atee = Tee -class aclosing(AbstractAsyncContextManager): +class aclosing(AbstractAsyncContextManager): # noqa: N801 """Async context manager for safely finalizing an asynchronously cleaned-up resource such as an async generator, calling its ``aclose()`` method. diff --git a/libs/core/langchain_core/vectorstores/utils.py b/libs/core/langchain_core/vectorstores/utils.py index 73b9aac9cbe22..a850ce947410c 100644 --- a/libs/core/langchain_core/vectorstores/utils.py +++ b/libs/core/langchain_core/vectorstores/utils.py @@ -17,12 +17,12 @@ logger = logging.getLogger(__name__) -def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: +def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray: """Row-wise cosine similarity between two equal-width matrices. Args: - X: A matrix of shape (n, m). - Y: A matrix of shape (k, m). + x: A matrix of shape (n, m). + y: A matrix of shape (k, m). Returns: A matrix of shape (n, k) where each element (i, j) is the cosine similarity @@ -40,33 +40,33 @@ def _cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: "Please install numpy with `pip install numpy`." ) from e - if len(X) == 0 or len(Y) == 0: + if len(x) == 0 or len(y) == 0: return np.array([]) - X = np.array(X) - Y = np.array(Y) - if X.shape[1] != Y.shape[1]: + x = np.array(x) + y = np.array(y) + if x.shape[1] != y.shape[1]: raise ValueError( - f"Number of columns in X and Y must be the same. X has shape {X.shape} " - f"and Y has shape {Y.shape}." + f"Number of columns in X and Y must be the same. X has shape {x.shape} " + f"and Y has shape {y.shape}." ) try: import simsimd as simd # type: ignore[import-not-found] - X = np.array(X, dtype=np.float32) - Y = np.array(Y, dtype=np.float32) - Z = 1 - np.array(simd.cdist(X, Y, metric="cosine")) - return Z + x = np.array(x, dtype=np.float32) + y = np.array(y, dtype=np.float32) + z = 1 - np.array(simd.cdist(x, y, metric="cosine")) + return z except ImportError: logger.debug( "Unable to import simsimd, defaulting to NumPy implementation. If you want " "to use simsimd please install with `pip install simsimd`." ) - X_norm = np.linalg.norm(X, axis=1) - Y_norm = np.linalg.norm(Y, axis=1) + x_norm = np.linalg.norm(x, axis=1) + y_norm = np.linalg.norm(y, axis=1) # Ignore divide by zero errors run time warnings as those are handled below. 
with np.errstate(divide="ignore", invalid="ignore"): - similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) + similarity = np.dot(x, y.T) / np.outer(x_norm, y_norm) similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 return similarity diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index b185380b7e016..a5bf8ead1b4dd 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -46,9 +46,17 @@ pydantic = [ [tool.poetry.extras] [tool.ruff.lint] -select = ["B", "E", "F", "I", "T201", "UP"] +select = ["B", "E", "F", "I", "N", "T201", "UP"] ignore = ["UP006", "UP007"] +[tool.ruff.lint.pep8-naming] +classmethod-decorators = [ + "classmethod", + "langchain_core.utils.pydantic.pre_init", + "pydantic.field_validator", + "pydantic.v1.root_validator", +] + [tool.coverage.run] omit = ["tests/*"] diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index 873ca5deec3be..a6ade13c3442b 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -9,9 +9,9 @@ from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from tests.unit_tests.stubs import ( - _AnyIdAIMessage, - _AnyIdAIMessageChunk, - _AnyIdHumanMessage, + _any_id_ai_message, + _any_id_ai_message_chunk, + _any_id_human_message, ) @@ -20,11 +20,11 @@ def test_generic_fake_chat_model_invoke() -> None: infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")]) model = GenericFakeChatModel(messages=infinite_cycle) response = model.invoke("meow") - assert response == _AnyIdAIMessage(content="hello") + assert response == _any_id_ai_message(content="hello") response = model.invoke("kitty") - assert response == _AnyIdAIMessage(content="goodbye") + assert response == _any_id_ai_message(content="goodbye") response = model.invoke("meow") - assert response == _AnyIdAIMessage(content="hello") + assert response == _any_id_ai_message(content="hello") async def test_generic_fake_chat_model_ainvoke() -> None: @@ -32,11 +32,11 @@ async def test_generic_fake_chat_model_ainvoke() -> None: infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")]) model = GenericFakeChatModel(messages=infinite_cycle) response = await model.ainvoke("meow") - assert response == _AnyIdAIMessage(content="hello") + assert response == _any_id_ai_message(content="hello") response = await model.ainvoke("kitty") - assert response == _AnyIdAIMessage(content="goodbye") + assert response == _any_id_ai_message(content="goodbye") response = await model.ainvoke("meow") - assert response == _AnyIdAIMessage(content="hello") + assert response == _any_id_ai_message(content="hello") async def test_generic_fake_chat_model_stream() -> None: @@ -49,17 +49,17 @@ async def test_generic_fake_chat_model_stream() -> None: model = GenericFakeChatModel(messages=infinite_cycle) chunks = [chunk async for chunk in model.astream("meow")] assert chunks == [ - _AnyIdAIMessageChunk(content="hello"), - _AnyIdAIMessageChunk(content=" "), - _AnyIdAIMessageChunk(content="goodbye"), + _any_id_ai_message_chunk(content="hello"), + _any_id_ai_message_chunk(content=" "), + _any_id_ai_message_chunk(content="goodbye"), ] assert len({chunk.id for chunk in chunks}) == 1 chunks = [chunk for chunk in model.stream("meow")] assert chunks == [ - _AnyIdAIMessageChunk(content="hello"), - _AnyIdAIMessageChunk(content=" 
"), - _AnyIdAIMessageChunk(content="goodbye"), + _any_id_ai_message_chunk(content="hello"), + _any_id_ai_message_chunk(content=" "), + _any_id_ai_message_chunk(content="goodbye"), ] assert len({chunk.id for chunk in chunks}) == 1 @@ -69,8 +69,8 @@ async def test_generic_fake_chat_model_stream() -> None: model = GenericFakeChatModel(messages=cycle([message])) chunks = [chunk async for chunk in model.astream("meow")] assert chunks == [ - _AnyIdAIMessageChunk(content="", additional_kwargs={"foo": 42}), - _AnyIdAIMessageChunk(content="", additional_kwargs={"bar": 24}), + _any_id_ai_message_chunk(content="", additional_kwargs={"foo": 42}), + _any_id_ai_message_chunk(content="", additional_kwargs={"bar": 24}), ] assert len({chunk.id for chunk in chunks}) == 1 @@ -88,19 +88,19 @@ async def test_generic_fake_chat_model_stream() -> None: chunks = [chunk async for chunk in model.astream("meow")] assert chunks == [ - _AnyIdAIMessageChunk( + _any_id_ai_message_chunk( content="", additional_kwargs={"function_call": {"name": "move_file"}} ), - _AnyIdAIMessageChunk( + _any_id_ai_message_chunk( content="", additional_kwargs={ "function_call": {"arguments": '{\n "source_path": "foo"'}, }, ), - _AnyIdAIMessageChunk( + _any_id_ai_message_chunk( content="", additional_kwargs={"function_call": {"arguments": ","}} ), - _AnyIdAIMessageChunk( + _any_id_ai_message_chunk( content="", additional_kwargs={ "function_call": {"arguments": '\n "destination_path": "bar"\n}'}, @@ -138,9 +138,9 @@ async def test_generic_fake_chat_model_astream_log() -> None: ] final = log_patches[-1] assert final.state["streamed_output"] == [ - _AnyIdAIMessageChunk(content="hello"), - _AnyIdAIMessageChunk(content=" "), - _AnyIdAIMessageChunk(content="goodbye"), + _any_id_ai_message_chunk(content="hello"), + _any_id_ai_message_chunk(content=" "), + _any_id_ai_message_chunk(content="goodbye"), ] assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1 @@ -189,9 +189,9 @@ async def on_llm_new_token( # New model results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]})) assert results == [ - _AnyIdAIMessageChunk(content="hello"), - _AnyIdAIMessageChunk(content=" "), - _AnyIdAIMessageChunk(content="goodbye"), + _any_id_ai_message_chunk(content="hello"), + _any_id_ai_message_chunk(content=" "), + _any_id_ai_message_chunk(content="goodbye"), ] assert tokens == ["hello", " ", "goodbye"] assert len({chunk.id for chunk in results}) == 1 @@ -200,6 +200,8 @@ async def on_llm_new_token( def test_chat_model_inputs() -> None: fake = ParrotFakeChatModel() - assert fake.invoke("hello") == _AnyIdHumanMessage(content="hello") - assert fake.invoke([("ai", "blah")]) == _AnyIdAIMessage(content="blah") - assert fake.invoke([AIMessage(content="blah")]) == _AnyIdAIMessage(content="blah") + assert fake.invoke("hello") == _any_id_human_message(content="hello") + assert fake.invoke([("ai", "blah")]) == _any_id_ai_message(content="blah") + assert fake.invoke([AIMessage(content="blah")]) == _any_id_ai_message( + content="blah" + ) diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index f14c2a1b8d04d..a9006a647a4b0 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -26,7 +26,7 @@ FakeAsyncCallbackHandler, FakeCallbackHandler, ) -from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk +from 
tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk @pytest.fixture @@ -146,10 +146,10 @@ def _llm_type(self) -> str: model = ModelWithGenerate() chunks = [chunk for chunk in model.stream("anything")] - assert chunks == [_AnyIdAIMessage(content="hello")] + assert chunks == [_any_id_ai_message(content="hello")] chunks = [chunk async for chunk in model.astream("anything")] - assert chunks == [_AnyIdAIMessage(content="hello")] + assert chunks == [_any_id_ai_message(content="hello")] async def test_astream_implementation_fallback_to_stream() -> None: @@ -184,15 +184,15 @@ def _llm_type(self) -> str: model = ModelWithSyncStream() chunks = [chunk for chunk in model.stream("anything")] assert chunks == [ - _AnyIdAIMessageChunk(content="a"), - _AnyIdAIMessageChunk(content="b"), + _any_id_ai_message_chunk(content="a"), + _any_id_ai_message_chunk(content="b"), ] assert len({chunk.id for chunk in chunks}) == 1 assert type(model)._astream == BaseChatModel._astream astream_chunks = [chunk async for chunk in model.astream("anything")] assert astream_chunks == [ - _AnyIdAIMessageChunk(content="a"), - _AnyIdAIMessageChunk(content="b"), + _any_id_ai_message_chunk(content="a"), + _any_id_ai_message_chunk(content="b"), ] assert len({chunk.id for chunk in astream_chunks}) == 1 @@ -229,8 +229,8 @@ def _llm_type(self) -> str: model = ModelWithAsyncStream() chunks = [chunk async for chunk in model.astream("anything")] assert chunks == [ - _AnyIdAIMessageChunk(content="a"), - _AnyIdAIMessageChunk(content="b"), + _any_id_ai_message_chunk(content="a"), + _any_id_ai_message_chunk(content="b"), ] assert len({chunk.id for chunk in chunks}) == 1 diff --git a/libs/core/tests/unit_tests/prompts/test_loading.py b/libs/core/tests/unit_tests/prompts/test_loading.py index d7092aa94c630..76af9259e80b2 100644 --- a/libs/core/tests/unit_tests/prompts/test_loading.py +++ b/libs/core/tests/unit_tests/prompts/test_loading.py @@ -25,7 +25,7 @@ def change_directory(dir: Path) -> Iterator: os.chdir(origin) -def test_loading_from_YAML() -> None: +def test_loading_from_yaml() -> None: """Test loading from yaml file.""" prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.yaml") expected_prompt = PromptTemplate( @@ -36,7 +36,7 @@ def test_loading_from_YAML() -> None: assert prompt == expected_prompt -def test_loading_from_JSON() -> None: +def test_loading_from_json() -> None: """Test loading from json file.""" prompt = load_prompt(EXAMPLE_DIR / "simple_prompt.json") expected_prompt = PromptTemplate( @@ -46,14 +46,14 @@ def test_loading_from_JSON() -> None: assert prompt == expected_prompt -def test_loading_jinja_from_JSON() -> None: +def test_loading_jinja_from_json() -> None: """Test that loading jinja2 format prompts from JSON raises ValueError.""" prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.json" with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"): load_prompt(prompt_path) -def test_loading_jinja_from_YAML() -> None: +def test_loading_jinja_from_yaml() -> None: """Test that loading jinja2 format prompts from YAML raises ValueError.""" prompt_path = EXAMPLE_DIR / "jinja_injection_prompt.yaml" with pytest.raises(ValueError, match=".*can lead to arbitrary code execution.*"): diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 9f75682cdc31f..2699e6cab48f8 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -2,6 +2,7 @@ from pydantic import 
BaseModel from syrupy import SnapshotAssertion +from typing_extensions import override from langchain_core.language_models import FakeListLLM from langchain_core.output_parsers.list import CommaSeparatedListOutputParser @@ -353,9 +354,11 @@ def test_runnable_get_graph_with_invalid_input_type() -> None: class InvalidInputTypeRunnable(Runnable[int, int]): @property + @override def InputType(self) -> type: raise TypeError() + @override def invoke( self, input: int, @@ -375,9 +378,11 @@ def test_runnable_get_graph_with_invalid_output_type() -> None: class InvalidOutputTypeRunnable(Runnable[int, int]): @property + @override def OutputType(self) -> type: raise TypeError() + @override def invoke( self, input: int, diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 377e7a72457f1..bdb7f5fec7174 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -492,7 +492,7 @@ def test_get_output_schema() -> None: def test_get_input_schema_input_messages() -> None: from pydantic import RootModel - RunnableWithMessageHistoryInput = RootModel[Sequence[BaseMessage]] + runnable_with_message_history_input = RootModel[Sequence[BaseMessage]] runnable = RunnableLambda( lambda messages: { @@ -514,7 +514,7 @@ def test_get_input_schema_input_messages() -> None: with_history = RunnableWithMessageHistory( runnable, get_session_history, output_messages_key="output" ) - expected_schema = _schema(RunnableWithMessageHistoryInput) + expected_schema = _schema(runnable_with_message_history_input) expected_schema["title"] = "RunnableWithChatHistoryInput" assert _schema(with_history.get_input_schema()) == expected_schema diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index cb6d16dd1cc9e..a4bcdc0aec6e4 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -92,7 +92,7 @@ ) from langchain_core.tracers.context import collect_runs from tests.unit_tests.pydantic_utils import _normalize_schema, _schema -from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk +from tests.unit_tests.stubs import AnyStr, _any_id_ai_message, _any_id_ai_message_chunk PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split("."))) @@ -1704,7 +1704,7 @@ def test_prompt_with_chat_model( tracer = FakeTracer() assert chain.invoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) - ) == _AnyIdAIMessage(content="foo") + ) == _any_id_ai_message(content="foo") assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( messages=[ @@ -1729,8 +1729,8 @@ def test_prompt_with_chat_model( ], dict(callbacks=[tracer]), ) == [ - _AnyIdAIMessage(content="foo"), - _AnyIdAIMessage(content="foo"), + _any_id_ai_message(content="foo"), + _any_id_ai_message(content="foo"), ] assert prompt_spy.call_args.args[1] == [ {"question": "What is your name?"}, @@ -1770,9 +1770,9 @@ def test_prompt_with_chat_model( assert [ *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer])) ] == [ - _AnyIdAIMessageChunk(content="f"), - _AnyIdAIMessageChunk(content="o"), - _AnyIdAIMessageChunk(content="o"), + _any_id_ai_message_chunk(content="f"), + _any_id_ai_message_chunk(content="o"), + _any_id_ai_message_chunk(content="o"), ] assert prompt_spy.call_args.args[1] == {"question": 
"What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( @@ -1810,7 +1810,7 @@ async def test_prompt_with_chat_model_async( tracer = FakeTracer() assert await chain.ainvoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) - ) == _AnyIdAIMessage(content="foo") + ) == _any_id_ai_message(content="foo") assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( messages=[ @@ -1835,8 +1835,8 @@ async def test_prompt_with_chat_model_async( ], dict(callbacks=[tracer]), ) == [ - _AnyIdAIMessage(content="foo"), - _AnyIdAIMessage(content="foo"), + _any_id_ai_message(content="foo"), + _any_id_ai_message(content="foo"), ] assert prompt_spy.call_args.args[1] == [ {"question": "What is your name?"}, @@ -1879,9 +1879,9 @@ async def test_prompt_with_chat_model_async( {"question": "What is your name?"}, dict(callbacks=[tracer]) ) ] == [ - _AnyIdAIMessageChunk(content="f"), - _AnyIdAIMessageChunk(content="o"), - _AnyIdAIMessageChunk(content="o"), + _any_id_ai_message_chunk(content="f"), + _any_id_ai_message_chunk(content="o"), + _any_id_ai_message_chunk(content="o"), ] assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( @@ -2553,7 +2553,7 @@ def test_prompt_with_chat_model_and_parser( HumanMessage(content="What is your name?"), ] ) - assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar") + assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar") assert tracer.runs == snapshot @@ -2690,7 +2690,7 @@ def test_seq_dict_prompt_llm( ), ] ) - assert parser_spy.call_args.args[1] == _AnyIdAIMessage(content="foo, bar") + assert parser_spy.call_args.args[1] == _any_id_ai_message(content="foo, bar") assert len([r for r in tracer.runs if r.parent_run_id is None]) == 1 parent_run = next(r for r in tracer.runs if r.parent_run_id is None) assert len(parent_run.child_runs) == 4 @@ -2736,7 +2736,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) -> assert chain.invoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) ) == { - "chat": _AnyIdAIMessage(content="i'm a chatbot"), + "chat": _any_id_ai_message(content="i'm a chatbot"), "llm": "i'm a textbot", } assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} @@ -2946,7 +2946,7 @@ def test_seq_prompt_map(mocker: MockerFixture, snapshot: SnapshotAssertion) -> N assert chain.invoke( {"question": "What is your name?"}, dict(callbacks=[tracer]) ) == { - "chat": _AnyIdAIMessage(content="i'm a chatbot"), + "chat": _any_id_ai_message(content="i'm a chatbot"), "llm": "i'm a textbot", "passthrough": ChatPromptValue( messages=[ @@ -3010,7 +3010,7 @@ def test_map_stream() -> None: assert streamed_chunks[0] in [ {"passthrough": prompt.invoke({"question": "What is your name?"})}, {"llm": "i"}, - {"chat": _AnyIdAIMessageChunk(content="i")}, + {"chat": _any_id_ai_message_chunk(content="i")}, ] assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1 assert all(len(c.keys()) == 1 for c in streamed_chunks) @@ -3069,11 +3069,11 @@ def test_map_stream() -> None: assert streamed_chunks[0] in [ {"llm": "i"}, - {"chat": _AnyIdAIMessageChunk(content="i")}, + {"chat": _any_id_ai_message_chunk(content="i")}, ] if not ( # TODO(Rewrite properly) statement above streamed_chunks[0] == {"llm": "i"} - or {"chat": _AnyIdAIMessageChunk(content="i")} + or {"chat": _any_id_ai_message_chunk(content="i")} ): raise 
AssertionError(f"Got an unexpected chunk: {streamed_chunks[0]}") @@ -3118,7 +3118,7 @@ def test_map_stream_iterator_input() -> None: assert streamed_chunks[0] in [ {"passthrough": "i"}, {"llm": "i"}, - {"chat": _AnyIdAIMessageChunk(content="i")}, + {"chat": _any_id_ai_message_chunk(content="i")}, ] assert len(streamed_chunks) == len(chat_res) + len(llm_res) + len(llm_res) assert all(len(c.keys()) == 1 for c in streamed_chunks) @@ -3162,7 +3162,7 @@ async def test_map_astream() -> None: assert streamed_chunks[0] in [ {"passthrough": prompt.invoke({"question": "What is your name?"})}, {"llm": "i"}, - {"chat": _AnyIdAIMessageChunk(content="i")}, + {"chat": _any_id_ai_message_chunk(content="i")}, ] assert len(streamed_chunks) == len(chat_res) + len(llm_res) + 1 assert all(len(c.keys()) == 1 for c in streamed_chunks) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index f8876c1bfba13..2d6d652afe4e8 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -31,7 +31,7 @@ from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.schema import StreamEvent from langchain_core.tools import tool -from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk +from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]: @@ -502,7 +502,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": {"a": "b"}, "name": "my_model", @@ -511,7 +511,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": {"a": "b"}, "name": "my_model", @@ -520,7 +520,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": {"a": "b"}, "name": "my_model", @@ -529,7 +529,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"output": _AnyIdAIMessageChunk(content="hello world!")}, + "data": {"output": _any_id_ai_message_chunk(content="hello world!")}, "event": "on_chat_model_end", "metadata": {"a": "b"}, "name": "my_model", @@ -574,7 +574,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -587,7 +587,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -600,7 +600,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": 
_any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -620,7 +620,9 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: [ { "generation_info": None, - "message": _AnyIdAIMessage(content="hello world!"), + "message": _any_id_ai_message( + content="hello world!" + ), "text": "hello world!", "type": "ChatGeneration", } @@ -643,7 +645,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, + "data": {"chunk": _any_id_ai_message(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "i_dont_stream", @@ -652,7 +654,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": [], }, { - "data": {"output": _AnyIdAIMessage(content="hello world!")}, + "data": {"output": _any_id_ai_message(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "i_dont_stream", @@ -697,7 +699,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -710,7 +712,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -723,7 +725,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -743,7 +745,9 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: [ { "generation_info": None, - "message": _AnyIdAIMessage(content="hello world!"), + "message": _any_id_ai_message( + content="hello world!" 
+ ), "text": "hello world!", "type": "ChatGeneration", } @@ -766,7 +770,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, + "data": {"chunk": _any_id_ai_message(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "ai_dont_stream", @@ -775,7 +779,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": [], }, { - "data": {"output": _AnyIdAIMessage(content="hello world!")}, + "data": {"output": _any_id_ai_message(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "ai_dont_stream", diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index fba80d291a49a..84b06885db7f1 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -52,7 +52,7 @@ from tests.unit_tests.runnables.test_runnable_events_v1 import ( _assert_events_equal_allow_superset_metadata, ) -from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk +from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]: @@ -538,7 +538,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -551,7 +551,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -564,7 +564,7 @@ async def test_astream_events_from_model() -> None: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -578,7 +578,7 @@ async def test_astream_events_from_model() -> None: }, { "data": { - "output": _AnyIdAIMessageChunk(content="hello world!"), + "output": _any_id_ai_message_chunk(content="hello world!"), }, "event": "on_chat_model_end", "metadata": { @@ -645,7 +645,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -658,7 +658,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -671,7 +671,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -686,7 +686,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: { "data": { "input": {"messages": [[HumanMessage(content="hello")]]}, - "output": _AnyIdAIMessage(content="hello world!"), + "output": 
_any_id_ai_message(content="hello world!"), }, "event": "on_chat_model_end", "metadata": { @@ -700,7 +700,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, + "data": {"chunk": _any_id_ai_message(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "i_dont_stream", @@ -709,7 +709,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": [], }, { - "data": {"output": _AnyIdAIMessage(content="hello world!")}, + "data": {"output": _any_id_ai_message(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "i_dont_stream", @@ -754,7 +754,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="hello")}, + "data": {"chunk": _any_id_ai_message_chunk(content="hello")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -767,7 +767,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content=" ")}, + "data": {"chunk": _any_id_ai_message_chunk(content=" ")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -780,7 +780,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessageChunk(content="world!")}, + "data": {"chunk": _any_id_ai_message_chunk(content="world!")}, "event": "on_chat_model_stream", "metadata": { "a": "b", @@ -795,7 +795,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: { "data": { "input": {"messages": [[HumanMessage(content="hello")]]}, - "output": _AnyIdAIMessage(content="hello world!"), + "output": _any_id_ai_message(content="hello world!"), }, "event": "on_chat_model_end", "metadata": { @@ -809,7 +809,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": ["my_model"], }, { - "data": {"chunk": _AnyIdAIMessage(content="hello world!")}, + "data": {"chunk": _any_id_ai_message(content="hello world!")}, "event": "on_chain_stream", "metadata": {}, "name": "ai_dont_stream", @@ -818,7 +818,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: "tags": [], }, { - "data": {"output": _AnyIdAIMessage(content="hello world!")}, + "data": {"output": _any_id_ai_message(content="hello world!")}, "event": "on_chain_end", "metadata": {}, "name": "ai_dont_stream", diff --git a/libs/core/tests/unit_tests/stubs.py b/libs/core/tests/unit_tests/stubs.py index b752364e3af5d..95f36b72b0fbf 100644 --- a/libs/core/tests/unit_tests/stubs.py +++ b/libs/core/tests/unit_tests/stubs.py @@ -16,28 +16,28 @@ def __eq__(self, other: Any) -> bool: # subclassed strings. 
-def _AnyIdDocument(**kwargs: Any) -> Document: +def _any_id_document(**kwargs: Any) -> Document: """Create a document with an id field.""" message = Document(**kwargs) message.id = AnyStr() return message -def _AnyIdAIMessage(**kwargs: Any) -> AIMessage: +def _any_id_ai_message(**kwargs: Any) -> AIMessage: """Create ai message with an any id field.""" message = AIMessage(**kwargs) message.id = AnyStr() return message -def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk: +def _any_id_ai_message_chunk(**kwargs: Any) -> AIMessageChunk: """Create ai message with an any id field.""" message = AIMessageChunk(**kwargs) message.id = AnyStr() return message -def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage: +def _any_id_human_message(**kwargs: Any) -> HumanMessage: """Create a human with an any id field.""" message = HumanMessage(**kwargs) message.id = AnyStr() diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index c6c47396df08e..53d6942e6722f 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -781,7 +781,7 @@ def test_convert_to_messages() -> None: @pytest.mark.parametrize( - "MessageClass", + "message_class", [ AIMessage, AIMessageChunk, @@ -790,39 +790,39 @@ def test_convert_to_messages() -> None: SystemMessage, ], ) -def test_message_name(MessageClass: Type) -> None: - msg = MessageClass(content="foo", name="bar") +def test_message_name(message_class: Type) -> None: + msg = message_class(content="foo", name="bar") assert msg.name == "bar" - msg2 = MessageClass(content="foo", name=None) + msg2 = message_class(content="foo", name=None) assert msg2.name is None - msg3 = MessageClass(content="foo") + msg3 = message_class(content="foo") assert msg3.name is None @pytest.mark.parametrize( - "MessageClass", + "message_class", [FunctionMessage, FunctionMessageChunk], ) -def test_message_name_function(MessageClass: Type) -> None: +def test_message_name_function(message_class: Type) -> None: # functionmessage doesn't support name=None - msg = MessageClass(name="foo", content="bar") + msg = message_class(name="foo", content="bar") assert msg.name == "foo" @pytest.mark.parametrize( - "MessageClass", + "message_class", [ChatMessage, ChatMessageChunk], ) -def test_message_name_chat(MessageClass: Type) -> None: - msg = MessageClass(content="foo", role="user", name="bar") +def test_message_name_chat(message_class: Type) -> None: + msg = message_class(content="foo", role="user", name="bar") assert msg.name == "bar" - msg2 = MessageClass(content="foo", role="user", name=None) + msg2 = message_class(content="foo", role="user", name=None) assert msg2.name is None - msg3 = MessageClass(content="foo", role="user") + msg3 = message_class(content="foo", role="user") assert msg3.name is None diff --git a/libs/core/tests/unit_tests/test_pydantic_serde.py b/libs/core/tests/unit_tests/test_pydantic_serde.py index c19d14cb24946..87af2fa5a611b 100644 --- a/libs/core/tests/unit_tests/test_pydantic_serde.py +++ b/libs/core/tests/unit_tests/test_pydantic_serde.py @@ -51,14 +51,14 @@ def test_serde_any_message() -> None: ), ] - Model = RootModel[AnyMessage] + model = RootModel[AnyMessage] for lc_object in lc_objects: d = lc_object.model_dump() assert "type" in d, f"Missing key `type` for {type(lc_object)}" - obj1 = Model.model_validate(d) + obj1 = model.model_validate(d) assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}" with pytest.raises((TypeError, ValidationError)): # Make sure 
diff --git a/libs/core/tests/unit_tests/test_pydantic_serde.py b/libs/core/tests/unit_tests/test_pydantic_serde.py
index c19d14cb24946..87af2fa5a611b 100644
--- a/libs/core/tests/unit_tests/test_pydantic_serde.py
+++ b/libs/core/tests/unit_tests/test_pydantic_serde.py
@@ -51,14 +51,14 @@ def test_serde_any_message() -> None:
         ),
     ]
 
-    Model = RootModel[AnyMessage]
+    model = RootModel[AnyMessage]
 
     for lc_object in lc_objects:
         d = lc_object.model_dump()
         assert "type" in d, f"Missing key `type` for {type(lc_object)}"
-        obj1 = Model.model_validate(d)
+        obj1 = model.model_validate(d)
         assert type(obj1.root) is type(lc_object), f"failed for {type(lc_object)}"
 
     with pytest.raises((TypeError, ValidationError)):
         # Make sure that specifically a validation error is raised
-        Model.model_validate({})
+        model.model_validate({})
diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py
index 83a5b5956d9ee..e20c7a964c49f 100644
--- a/libs/core/tests/unit_tests/test_tools.py
+++ b/libs/core/tests/unit_tests/test_tools.py
@@ -1424,7 +1424,7 @@ def _run(self, x: int, y: Annotated[str, InjectedToolArg]) -> Any:
         return y
 
 
-class fooSchema(BaseModel):
+class fooSchema(BaseModel):  # noqa: N801
     """foo."""
 
     x: int = Field(..., description="abc")
@@ -1571,14 +1571,14 @@ def test_tool_injected_arg() -> None:
 
 
 def test_tool_inherited_injected_arg() -> None:
-    class barSchema(BaseModel):
+    class BarSchema(BaseModel):
         """bar."""
 
         y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(
             ..., description="123"
         )
 
-    class fooSchema(barSchema):
+    class FooSchema(BarSchema):
         """foo."""
 
         x: int = Field(..., description="abc")
@@ -1586,14 +1586,14 @@ class fooSchema(barSchema):
 
     class InheritedInjectedArgTool(BaseTool):
         name: str = "foo"
         description: str = "foo."
-        args_schema: Type[BaseModel] = fooSchema
+        args_schema: Type[BaseModel] = FooSchema
 
         def _run(self, x: int, y: str) -> Any:
             return y
 
     tool_ = InheritedInjectedArgTool()
     assert tool_.get_input_schema().model_json_schema() == {
-        "title": "fooSchema",  # Matches the title from the provided schema
+        "title": "FooSchema",  # Matches the title from the provided schema
         "description": "foo.",
         "type": "object",
         "properties": {
@@ -1880,15 +1880,15 @@ def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
     A = TypeVar("A")
 
     if use_v1_namespace:
-        from pydantic.v1 import BaseModel as BM1
+        from pydantic.v1 import BaseModel as BaseModel1
 
-        class ModelA(BM1, Generic[A], extra="allow"):
+        class ModelA(BaseModel1, Generic[A], extra="allow"):
             a: A
     else:
-        from pydantic import BaseModel as BM2
+        from pydantic import BaseModel as BaseModel2
         from pydantic import ConfigDict
 
-        class ModelA(BM2, Generic[A]):  # type: ignore[no-redef]
+        class ModelA(BaseModel2, Generic[A]):  # type: ignore[no-redef]
             a: A
 
             model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
@@ -2084,13 +2084,13 @@ def test_structured_tool_direct_init() -> None:
     def foo(bar: str) -> str:
         return bar
 
-    async def asyncFoo(bar: str) -> str:
+    async def async_foo(bar: str) -> str:
         return bar
 
-    class fooSchema(BaseModel):
+    class FooSchema(BaseModel):
         bar: str = Field(..., description="The bar")
 
-    tool = StructuredTool(name="foo", args_schema=fooSchema, coroutine=asyncFoo)
+    tool = StructuredTool(name="foo", args_schema=FooSchema, coroutine=async_foo)
 
     with pytest.raises(NotImplementedError):
         assert tool.invoke("hello") == "hello"
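Note on the expected-schema change above: Pydantic derives a model's JSON schema title from the class name, which is why renaming fooSchema to FooSchema also changes the asserted "title" value. A standalone sketch of that behavior (the class here is invented for illustration, not taken from the PR):

    from pydantic import BaseModel, Field


    class FooSchema(BaseModel):
        """foo."""

        x: int = Field(..., description="abc")


    schema = FooSchema.model_json_schema()
    # Title and description track the class name and docstring.
    assert schema["title"] == "FooSchema"
    assert schema["description"] == "foo."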
diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py
index ec6ce542732bb..65293516d0edb 100644
--- a/libs/core/tests/unit_tests/utils/test_function_calling.py
+++ b/libs/core/tests/unit_tests/utils/test_function_calling.py
@@ -48,7 +48,7 @@
 
 @pytest.fixture()
 def pydantic() -> Type[BaseModel]:
-    class dummy_function(BaseModel):
+    class dummy_function(BaseModel):  # noqa: N801
         """dummy function"""
 
         arg1: int = Field(..., description="foo")
@@ -58,7 +58,7 @@ class dummy_function(BaseModel):
 
 
 @pytest.fixture()
-def Annotated_function() -> Callable:
+def annotated_function() -> Callable:
     def dummy_function(
         arg1: ExtensionsAnnotated[int, "foo"],
         arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
@@ -128,7 +128,7 @@ class Schema(BaseModel):
 
 @pytest.fixture()
 def dummy_pydantic() -> Type[BaseModel]:
-    class dummy_function(BaseModel):
+    class dummy_function(BaseModel):  # noqa: N801
         """dummy function"""
 
         arg1: int = Field(..., description="foo")
@@ -139,7 +139,7 @@ class dummy_function(BaseModel):
 
 
 @pytest.fixture()
 def dummy_pydantic_v2() -> Type[BaseModelV2Maybe]:
-    class dummy_function(BaseModelV2Maybe):
+    class dummy_function(BaseModelV2Maybe):  # noqa: N801
         """dummy function"""
 
         arg1: int = FieldV2Maybe(..., description="foo")
@@ -152,7 +152,7 @@ class dummy_function(BaseModelV2Maybe):
 
 
 @pytest.fixture()
 def dummy_typing_typed_dict() -> Type:
-    class dummy_function(TypingTypedDict):
+    class dummy_function(TypingTypedDict):  # noqa: N801
         """dummy function"""
 
         arg1: TypingAnnotated[int, ..., "foo"]  # noqa: F821
@@ -163,7 +163,7 @@ class dummy_function(TypingTypedDict):
 
 
 @pytest.fixture()
 def dummy_typing_typed_dict_docstring() -> Type:
-    class dummy_function(TypingTypedDict):
+    class dummy_function(TypingTypedDict):  # noqa: N801
         """dummy function
 
         Args:
@@ -179,7 +179,7 @@ class dummy_function(TypingTypedDict):
 
 
 @pytest.fixture()
 def dummy_extensions_typed_dict() -> Type:
-    class dummy_function(ExtensionsTypedDict):
+    class dummy_function(ExtensionsTypedDict):  # noqa: N801
         """dummy function"""
 
         arg1: ExtensionsAnnotated[int, ..., "foo"]
@@ -190,7 +190,7 @@ class dummy_function(ExtensionsTypedDict):
 
 
 @pytest.fixture()
 def dummy_extensions_typed_dict_docstring() -> Type:
-    class dummy_function(ExtensionsTypedDict):
+    class dummy_function(ExtensionsTypedDict):  # noqa: N801
         """dummy function
 
         Args:
@@ -251,7 +251,7 @@ def test_convert_to_openai_function(
     dummy_structured_tool: StructuredTool,
     dummy_tool: BaseTool,
     json_schema: Dict,
-    Annotated_function: Callable,
+    annotated_function: Callable,
     dummy_pydantic: Type[BaseModel],
     runnable: Runnable,
     dummy_typing_typed_dict: Type,
@@ -285,7 +285,7 @@ def test_convert_to_openai_function(
         expected,
         Dummy.dummy_function,
         DummyWithClassMethod.dummy_function,
-        Annotated_function,
+        annotated_function,
         dummy_pydantic,
         dummy_typing_typed_dict,
         dummy_typing_typed_dict_docstring,
@@ -533,20 +533,20 @@ def test__convert_typed_dict_to_openai_function(
     use_extension_typed_dict: bool, use_extension_annotated: bool
 ) -> None:
     if use_extension_typed_dict:
-        TypedDict = ExtensionsTypedDict
+        typed_dict = ExtensionsTypedDict
     else:
-        TypedDict = TypingTypedDict
+        typed_dict = TypingTypedDict
     if use_extension_annotated:
-        Annotated = TypingAnnotated
+        annotated = ExtensionsAnnotated
     else:
-        Annotated = TypingAnnotated
+        annotated = TypingAnnotated
 
-    class SubTool(TypedDict):
+    class SubTool(typed_dict):
         """Subtool docstring"""
 
-        args: Annotated[Dict[str, Any], {}, "this does bar"]  # noqa: F722 # type: ignore
+        args: annotated[Dict[str, Any], {}, "this does bar"]  # noqa: F722 # type: ignore
 
-    class Tool(TypedDict):
+    class Tool(typed_dict):
         """Docstring
 
         Args:
@@ -556,20 +556,20 @@ class Tool(TypedDict):
         arg1: str
         arg2: Union[int, str, bool]
         arg3: Optional[List[SubTool]]
-        arg4: Annotated[Literal["bar", "baz"], ..., "this does foo"]  # noqa: F722
-        arg5: Annotated[Optional[float], None]
-        arg6: Annotated[
+        arg4: annotated[Literal["bar", "baz"], ..., "this does foo"]  # noqa: F722
+        arg5: annotated[Optional[float], None]
+        arg6: annotated[
             Optional[Sequence[Mapping[str, Tuple[Iterable[Any], SubTool]]]], []
         ]
-        arg7: Annotated[List[SubTool], ...]
-        arg8: Annotated[Tuple[SubTool], ...]
-        arg9: Annotated[Sequence[SubTool], ...]
-        arg10: Annotated[Iterable[SubTool], ...]
-        arg11: Annotated[Set[SubTool], ...]
-        arg12: Annotated[Dict[str, SubTool], ...]
-        arg13: Annotated[Mapping[str, SubTool], ...]
-        arg14: Annotated[MutableMapping[str, SubTool], ...]
-        arg15: Annotated[bool, False, "flag"]  # noqa: F821 # type: ignore
+        arg7: annotated[List[SubTool], ...]
+        arg8: annotated[Tuple[SubTool], ...]
+        arg9: annotated[Sequence[SubTool], ...]
+        arg10: annotated[Iterable[SubTool], ...]
+        arg11: annotated[Set[SubTool], ...]
+        arg12: annotated[Dict[str, SubTool], ...]
+        arg13: annotated[Mapping[str, SubTool], ...]
+        arg14: annotated[MutableMapping[str, SubTool], ...]
+        arg15: annotated[bool, False, "flag"]  # noqa: F821 # type: ignore
 
     expected = {
         "name": "Tool",
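Note on the typed_dict and annotated locals introduced above: binding the chosen implementation to a lowercase variable satisfies ruff's N806 check (variables inside functions should be snake_case) while keeping a single test body for both the typing and typing_extensions variants, because Python resolves base classes at runtime. A reduced sketch of the pattern, assuming both modules are installed; make_subtool is a name invented for this example:

    from typing import TypedDict as TypingTypedDict

    from typing_extensions import TypedDict as ExtensionsTypedDict


    def make_subtool(use_extensions: bool) -> type:
        # Bind the chosen TypedDict to a snake_case local (ruff N806) and
        # subclass through the variable; the base is resolved at runtime.
        typed_dict = ExtensionsTypedDict if use_extensions else TypingTypedDict

        class SubTool(typed_dict):  # type: ignore[misc]
            args: dict

        return SubTool


    # Both variants yield a TypedDict with identical annotations.
    assert make_subtool(True).__annotations__ == make_subtool(False).__annotations__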
diff --git a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py
index c0f9944d07789..5373d022dbf6d 100644
--- a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py
+++ b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py
@@ -10,7 +10,7 @@
 from langchain_core.documents import Document
 from langchain_core.embeddings.fake import DeterministicFakeEmbedding
 from langchain_core.vectorstores import InMemoryVectorStore
-from tests.unit_tests.stubs import _AnyIdDocument
+from tests.unit_tests.stubs import _any_id_document
 
 
 class TestInMemoryReadWriteTestSuite(ReadWriteTestSuite):
@@ -33,13 +33,13 @@ async def test_inmemory_similarity_search() -> None:
 
     # Check sync version
     output = store.similarity_search("foo", k=1)
-    assert output == [_AnyIdDocument(page_content="foo")]
+    assert output == [_any_id_document(page_content="foo")]
 
     # Check async version
     output = await store.asimilarity_search("bar", k=2)
     assert output == [
-        _AnyIdDocument(page_content="bar"),
-        _AnyIdDocument(page_content="baz"),
+        _any_id_document(page_content="bar"),
+        _any_id_document(page_content="baz"),
     ]
 
 
@@ -80,16 +80,16 @@ async def test_inmemory_mmr() -> None:
     # make sure we can k > docstore size
     output = docsearch.max_marginal_relevance_search("foo", k=10, lambda_mult=0.1)
     assert len(output) == len(texts)
-    assert output[0] == _AnyIdDocument(page_content="foo")
-    assert output[1] == _AnyIdDocument(page_content="foy")
+    assert output[0] == _any_id_document(page_content="foo")
+    assert output[1] == _any_id_document(page_content="foy")
 
     # Check async version
     output = await docsearch.amax_marginal_relevance_search(
         "foo", k=10, lambda_mult=0.1
     )
     assert len(output) == len(texts)
-    assert output[0] == _AnyIdDocument(page_content="foo")
-    assert output[1] == _AnyIdDocument(page_content="foy")
+    assert output[0] == _any_id_document(page_content="foo")
+    assert output[1] == _any_id_document(page_content="foy")
 
 
 async def test_inmemory_dump_load(tmp_path: Path) -> None:
@@ -117,7 +117,7 @@ async def test_inmemory_filter() -> None:
 
     # Check sync version
     output = store.similarity_search("fee", filter=lambda doc: doc.metadata["id"] == 1)
-    assert output == [_AnyIdDocument(page_content="foo", metadata={"id": 1})]
+    assert output == [_any_id_document(page_content="foo", metadata={"id": 1})]
 
     # filter with not stored document id
     output = await store.asimilarity_search(
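Note on why these assertions go through _any_id_document: InMemoryVectorStore assigns a fresh UUID to every text it stores, so a literal Document(page_content="foo") would never compare equal to a search result. A small usage sketch (API shapes as of the langchain-core version this diff targets; treat the details as assumptions):

    from langchain_core.embeddings.fake import DeterministicFakeEmbedding
    from langchain_core.vectorstores import InMemoryVectorStore

    # DeterministicFakeEmbedding gives stable vectors without a real model.
    store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=6))
    store.add_texts(["foo", "bar"])

    (top,) = store.similarity_search("foo", k=1)
    assert top.page_content == "foo"
    # The id is auto-generated, which is what _any_id_document papers over.
    assert top.id is not None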
diff --git a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py
index 971315752b8c2..52ff3685a97d7 100644
--- a/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py
+++ b/libs/core/tests/unit_tests/vectorstores/test_vectorstore.py
@@ -49,6 +49,7 @@ def add_texts(
     def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
         return [self.store[id] for id in ids if id in self.store]
 
+
     @classmethod
     def from_texts(  # type: ignore
         cls,
         texts: List[str],