Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update base.py to remove ValueError #8050

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions langchain/chains/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ class Chain(Serializable, ABC):
starting with on_chain_start, ending with on_chain_end or on_chain_error.
Each custom chain can optionally call additional callback methods, see Callback docs
for full details."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
callback_manager: Optional[BaseCallbackManager] = Field(
default=None, exclude=True
)
"""Deprecated, use `callbacks` instead."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
Expand All @@ -91,7 +93,8 @@ class Config:

@property
def _chain_type(self) -> str:
raise NotImplementedError("Saving not supported for this chain type.")
warnings.warn("Saving not supported for this chain type.", UserWarning)
return "NotImplemented"

@root_validator()
def raise_callback_manager_deprecation(cls, values: Dict) -> Dict:
Expand Down Expand Up @@ -180,7 +183,9 @@ async def _acall(
A dict of named outputs. Should contain all outputs specified in
`Chain.output_keys`.
"""
raise NotImplementedError("Async call not supported for this chain type.")
raise NotImplementedError(
"Async call not supported for this chain type."
)

def __call__(
self,
Expand Down Expand Up @@ -227,7 +232,9 @@ def __call__(
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
new_arg_supported = inspect.signature(self._call).parameters.get(
"run_manager"
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
Expand Down Expand Up @@ -294,7 +301,9 @@ async def acall(
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
new_arg_supported = inspect.signature(self._acall).parameters.get(
"run_manager"
)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
inputs,
Expand Down Expand Up @@ -342,7 +351,9 @@ def prep_outputs(
else:
return {**inputs, **outputs}

def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
def prep_inputs(
self, inputs: Union[Dict[str, Any], Any]
) -> Dict[str, str]:
"""Validate and prepare chain inputs, including adding inputs from memory.

Args:
Expand All @@ -359,7 +370,9 @@ def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
_input_keys = _input_keys.difference(
self.memory.memory_variables
)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
Expand Down Expand Up @@ -436,15 +449,17 @@ def run(

if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return self(args[0], callbacks=callbacks, tags=tags, metadata=metadata)[
_output_key
]
raise ValueError(
"`run` supports only one positional argument."
)
return self(
args[0], callbacks=callbacks, tags=tags, metadata=metadata
)[_output_key]

if kwargs and not args:
return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[
_output_key
]
return self(
kwargs, callbacks=callbacks, tags=tags, metadata=metadata
)[_output_key]

if not kwargs and not args:
raise ValueError(
Expand Down Expand Up @@ -512,7 +527,9 @@ async def arun(
)
elif args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
raise ValueError(
"`run` supports only one positional argument."
)
return (
await self.acall(
args[0], callbacks=callbacks, tags=tags, metadata=metadata
Expand Down Expand Up @@ -551,7 +568,11 @@ def dict(self, **kwargs: Any) -> Dict:
# -> {"_type": "foo", "verbose": False, ...}
"""
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
warnings.warn(
"Saving not supported for this chain type.", UserWarning
)
return {}

_dict = super().dict(**kwargs)
_dict["_type"] = self._chain_type
return _dict
Expand Down
Loading