Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ollama provider #95

Merged
merged 1 commit into from
Jan 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions LangChain.sln
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Google.
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Azure", "src\libs\Providers\LangChain.Providers.Azure\LangChain.Providers.Azure.csproj", "{18F5AAB1-1750-41BD-B623-6339CA5754D9}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LangChain.Providers.Ollama.IntegrationTests", "src\tests\LangChain.Providers.Ollama.IntegrationTests\LangChain.Providers.Ollama.IntegrationTests.csproj", "{72B1E2CC-1A34-470E-A579-034CB0972BB7}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Ollama", "src\libs\Providers\LangChain.Providers.Ollama\LangChain.Providers.Ollama.csproj", "{4913844F-74EC-4E74-AE8A-EA825569E6BA}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -378,6 +382,14 @@ Global
{18F5AAB1-1750-41BD-B623-6339CA5754D9}.Debug|Any CPU.Build.0 = Debug|Any CPU
{18F5AAB1-1750-41BD-B623-6339CA5754D9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{18F5AAB1-1750-41BD-B623-6339CA5754D9}.Release|Any CPU.Build.0 = Release|Any CPU
{72B1E2CC-1A34-470E-A579-034CB0972BB7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{72B1E2CC-1A34-470E-A579-034CB0972BB7}.Debug|Any CPU.Build.0 = Debug|Any CPU
{72B1E2CC-1A34-470E-A579-034CB0972BB7}.Release|Any CPU.ActiveCfg = Release|Any CPU
{72B1E2CC-1A34-470E-A579-034CB0972BB7}.Release|Any CPU.Build.0 = Release|Any CPU
{4913844F-74EC-4E74-AE8A-EA825569E6BA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{4913844F-74EC-4E74-AE8A-EA825569E6BA}.Debug|Any CPU.Build.0 = Debug|Any CPU
{4913844F-74EC-4E74-AE8A-EA825569E6BA}.Release|Any CPU.ActiveCfg = Release|Any CPU
{4913844F-74EC-4E74-AE8A-EA825569E6BA}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -441,6 +453,8 @@ Global
{B953ABEC-50DD-4A63-A12A-E82F124C7D5B} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
{32FC123E-F269-4352-848C-0161B53093CC} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
{DEAFA0CB-462D-4D74-B16F-68FD83FE3858} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
{72B1E2CC-1A34-470E-A579-034CB0972BB7} = {FDEE2E22-C239-4921-83B2-9797F765FD6A}
{4913844F-74EC-4E74-AE8A-EA825569E6BA} = {E55391DE-F8F3-4CC2-A0E3-2406C76E9C68}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {5C00D0F1-6138-4ED9-846B-97E43D6DFF1C}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
using System.Text.Json.Serialization;

namespace OllamaTest;

/// <summary>
/// Request body for the Ollama <c>/api/generate</c> endpoint.
/// Optional fields are nullable so they are omitted from the JSON payload
/// when serialized with <c>JsonIgnoreCondition.WhenWritingNull</c>.
/// </summary>
public class GenerateCompletionRequest
{
    /// <summary>
    /// The model name (required).
    /// </summary>
    [JsonPropertyName("model")]
    public string? Model { get; set; }

    /// <summary>
    /// The prompt to generate a response for.
    /// </summary>
    [JsonPropertyName("prompt")]
    public string? Prompt { get; set; }

    /// <summary>
    /// Additional model parameters listed in the documentation for the Modelfile, such as temperature.
    /// </summary>
    [JsonPropertyName("options")]
    public OllamaLanguageModelOptions? Options { get; set; }

    /// <summary>
    /// Base64-encoded images (for multimodal models such as llava).
    /// </summary>
    [JsonPropertyName("images")]
    public string[]? Images { get; set; }

    /// <summary>
    /// System prompt (overrides what is defined in the Modelfile).
    /// </summary>
    [JsonPropertyName("system")]
    public string? System { get; set; }

    /// <summary>
    /// The full prompt or prompt template (overrides what is defined in the Modelfile).
    /// </summary>
    [JsonPropertyName("template")]
    public string? Template { get; set; }

    /// <summary>
    /// The context parameter returned from a previous request to /generate;
    /// this can be used to keep a short conversational memory.
    /// </summary>
    [JsonPropertyName("context")]
    public long[]? Context { get; set; }

    /// <summary>
    /// If false the response will be returned as a single response object,
    /// rather than a stream of objects. Defaults to true.
    /// </summary>
    [JsonPropertyName("stream")]
    public bool Stream { get; set; } = true;

    /// <summary>
    /// In some cases you may wish to bypass the templating system and provide a full prompt.
    /// In this case, you can use the raw parameter to disable formatting.
    /// </summary>
    [JsonPropertyName("raw")]
    public bool Raw { get; set; }
}

/// <summary>
/// One streamed chunk returned by the Ollama <c>/api/generate</c> endpoint.
/// When <see cref="Done"/> is true the server sends a final chunk
/// (deserialized elsewhere with additional statistics fields).
/// </summary>
public class GenerateCompletionResponseStream
{
    /// <summary>The model that produced this chunk.</summary>
    [JsonPropertyName("model")]
    public string? Model { get; set; }

    /// <summary>Creation timestamp string as reported by the server.</summary>
    [JsonPropertyName("created_at")]
    public string? CreatedAt { get; set; }

    /// <summary>The generated text fragment carried by this chunk.</summary>
    [JsonPropertyName("response")]
    public string? Response { get; set; }

    /// <summary>True when this is the final chunk of the stream.</summary>
    [JsonPropertyName("done")]
    public bool Done { get; set; }
}

/// <summary>
/// The final chunk of a <c>/api/generate</c> stream (<c>done == true</c>);
/// carries the conversation context and timing statistics in addition to the base fields.
/// Duration values are raw integers as reported by the server.
/// </summary>
public class GenerateCompletionDoneResponseStream : GenerateCompletionResponseStream
{
    /// <summary>
    /// Conversation context; pass it back in the next request's <c>context</c>
    /// field to keep a short conversational memory.
    /// </summary>
    [JsonPropertyName("context")]
    public long[]? Context { get; set; }

    /// <summary>Total time spent on the request.</summary>
    [JsonPropertyName("total_duration")]
    public long TotalDuration { get; set; }

    /// <summary>Time spent loading the model.</summary>
    [JsonPropertyName("load_duration")]
    public long LoadDuration { get; set; }

    /// <summary>Number of tokens in the prompt that were evaluated.</summary>
    [JsonPropertyName("prompt_eval_count")]
    public int PromptEvalCount { get; set; }

    /// <summary>Time spent evaluating the prompt.</summary>
    [JsonPropertyName("prompt_eval_duration")]
    public long PromptEvalDuration { get; set; }

    /// <summary>Number of tokens generated in the response.</summary>
    [JsonPropertyName("eval_count")]
    public int EvalCount { get; set; }

    /// <summary>Time spent generating the response.</summary>
    [JsonPropertyName("eval_duration")]
    public long EvalDuration { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using System.Text.Json.Serialization;

namespace OllamaTest;

/// <summary>
/// Request body for the Ollama embeddings endpoint. Optional fields are nullable
/// so they are omitted when serialized with <c>JsonIgnoreCondition.WhenWritingNull</c>.
/// </summary>
public class GenerateEmbeddingRequest
{
    /// <summary>The model name (required).</summary>
    [JsonPropertyName("model")]
    public string? Model { get; set; }

    /// <summary>The text to compute an embedding for.</summary>
    [JsonPropertyName("prompt")]
    public string? Prompt { get; set; }

    /// <summary>Additional model parameters (see the Modelfile documentation).</summary>
    [JsonPropertyName("options")]
    public OllamaLanguageModelOptions? Options { get; set; }
}

/// <summary>
/// Response body of the Ollama embeddings endpoint: a single embedding vector.
/// </summary>
public class GenerateEmbeddingResponse
{
    /// <summary>The embedding vector for the supplied prompt.</summary>
    [JsonPropertyName("embedding")]
    public double[]? Embedding { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <!-- net462 is the canonical short TFM for .NET Framework 4.6.2 (the dotted
         form "net4.6.2" is non-standard and rejected by some tooling). -->
    <TargetFrameworks>net462;netstandard2.0;net6.0;net7.0;net8.0</TargetFrameworks>
    <SignAssembly>false</SignAssembly>
    <NoWarn>$(NoWarn);CA1003;CA1307</NoWarn>
  </PropertyGroup>

  <ItemGroup>
    <!-- HttpClient for the older targets (net462 / netstandard2.0). -->
    <PackageReference Include="System.Net.Http" />
  </ItemGroup>

  <ItemGroup Label="Usings">
    <Using Include="System.Net.Http" />
  </ItemGroup>

  <PropertyGroup Label="NuGet">
    <Description>Ollama Chat model provider.</Description>
    <PackageTags>$(PackageTags);Ollama;api</PackageTags>
  </PropertyGroup>

  <ItemGroup>
    <ProjectReference Include="..\..\LangChain.Core\LangChain.Core.csproj" />
  </ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using System.Diagnostics;
using System.Text.Json.Serialization;

namespace OllamaTest;

/// <summary>
/// Response body of <c>GET /api/tags</c>: the models available on the local Ollama server.
/// </summary>
public class ListModelsResponse
{
    /// <summary>The locally available models.</summary>
    [JsonPropertyName("models")]
    public Model[]? Models { get; set; }
}

/// <summary>
/// Metadata for one locally installed Ollama model, as returned by <c>GET /api/tags</c>.
/// </summary>
[DebuggerDisplay("{Name}")]
public class Model
{
    /// <summary>The model name, e.g. "llama2:latest".</summary>
    [JsonPropertyName("name")]
    public string? Name { get; set; }

    /// <summary>When the model was last modified on disk.</summary>
    [JsonPropertyName("modified_at")]
    public DateTime ModifiedAt { get; set; }

    /// <summary>Model size in bytes.</summary>
    [JsonPropertyName("size")]
    public long Size { get; set; }

    /// <summary>Content digest of the model.</summary>
    [JsonPropertyName("digest")]
    public string? Digest { get; set; }
}
106 changes: 106 additions & 0 deletions src/libs/Providers/LangChain.Providers.Ollama/OllamaApiClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
using System.IO;
using System.Net;
using System.Numerics;
using System.Reflection;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;


namespace OllamaTest;

/// <summary>
/// Minimal HTTP client for the Ollama REST API
/// (list models, pull models, generate completions, generate embeddings).
/// </summary>
public class OllamaApiClient
{
    // Cached: avoids allocating a new JsonSerializerOptions (and rebuilding its
    // metadata cache) on every request.
    private static readonly JsonSerializerOptions SerializerOptions = new JsonSerializerOptions
    {
        DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
    };

    private readonly HttpClient _client;

    /// <summary>
    /// Creates a client that talks to the Ollama server at <paramref name="url"/>
    /// (e.g. "http://localhost:11434").
    /// </summary>
    public OllamaApiClient(string url)
        : this(new HttpClient { BaseAddress = new Uri(url) })
    {
    }

    /// <summary>
    /// Creates a client over an existing <see cref="HttpClient"/>.
    /// Its <c>BaseAddress</c> must point at the Ollama server.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="client"/> is null.</exception>
    public OllamaApiClient(HttpClient client)
    {
        _client = client ?? throw new ArgumentNullException(nameof(client));
    }

    /// <summary>
    /// Lists the models available locally on the server (GET /api/tags).
    /// </summary>
    public async Task<IEnumerable<Model>> ListLocalModels()
    {
        var data = await GetAsync<ListModelsResponse>("/api/tags").ConfigureAwait(false);
        return data.Models ?? Array.Empty<Model>();
    }

    /// <summary>
    /// Downloads a model to the server (POST /api/pull).
    /// The streamed progress response is drained so the task only completes
    /// once the pull has finished.
    /// </summary>
    public async Task PullModel(string name)
    {
        var request = new HttpRequestMessage(HttpMethod.Post, "/api/pull")
        {
            Content = new StringContent(
                JsonSerializer.Serialize(new { name, stream = false }),
                Encoding.UTF8,
                "application/json"),
        };
        using var response = await _client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead).ConfigureAwait(false);
        response.EnsureSuccessStatusCode();

        using var stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
        using var reader = new StreamReader(stream);
        await reader.ReadToEndAsync().ConfigureAwait(false);
    }

    /// <summary>
    /// GETs <paramref name="endpoint"/> and deserializes the JSON body.
    /// </summary>
    /// <exception cref="InvalidOperationException">The server returned a JSON null body.</exception>
    private async Task<TResponse> GetAsync<TResponse>(string endpoint)
    {
        // Dispose the response message so the underlying connection is released.
        using var response = await _client.GetAsync(endpoint).ConfigureAwait(false);
        response.EnsureSuccessStatusCode();

        var responseBody = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
        return JsonSerializer.Deserialize<TResponse>(responseBody)
               ?? throw new InvalidOperationException($"Empty response from {endpoint}.");
    }

    /// <summary>
    /// Generates a completion (POST /api/generate), yielding one
    /// <see cref="GenerateCompletionResponseStream"/> per newline-delimited JSON chunk.
    /// </summary>
    public async IAsyncEnumerable<GenerateCompletionResponseStream> GenerateCompletion(GenerateCompletionRequest generateRequest)
    {
        generateRequest = generateRequest ?? throw new ArgumentNullException(nameof(generateRequest));

        var content = JsonSerializer.Serialize(generateRequest, SerializerOptions);
        var request = new HttpRequestMessage(HttpMethod.Post, "/api/generate")
        {
            Content = new StringContent(content, Encoding.UTF8, "application/json"),
        };

        // When streaming, start reading as soon as headers arrive instead of
        // buffering the whole response.
        var completion = generateRequest.Stream
            ? HttpCompletionOption.ResponseHeadersRead
            : HttpCompletionOption.ResponseContentRead;

        using var response = await _client.SendAsync(request, completion).ConfigureAwait(false);
        response.EnsureSuccessStatusCode();

        using var stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
        using var reader = new StreamReader(stream);

        // ReadLineAsync-until-null avoids StreamReader.EndOfStream, which can
        // block synchronously on a network stream.
        string line;
        while ((line = await reader.ReadLineAsync().ConfigureAwait(false)) != null)
        {
            if (string.IsNullOrWhiteSpace(line))
            {
                continue; // tolerate blank keep-alive lines
            }

            var streamedResponse = JsonSerializer.Deserialize<GenerateCompletionResponseStream>(line);
            if (streamedResponse != null)
            {
                yield return streamedResponse;
            }
        }
    }

    /// <summary>
    /// Computes embeddings for a prompt (POST /api/embeddings).
    /// </summary>
    /// <exception cref="InvalidOperationException">The server returned a JSON null body.</exception>
    public async Task<GenerateEmbeddingResponse> GenerateEmbeddings(GenerateEmbeddingRequest generateRequest)
    {
        var content = JsonSerializer.Serialize(generateRequest, SerializerOptions);
        // Fix: embeddings live at /api/embeddings; the completion endpoint
        // /api/generate does not return an "embedding" field.
        var request = new HttpRequestMessage(HttpMethod.Post, "/api/embeddings")
        {
            Content = new StringContent(content, Encoding.UTF8, "application/json"),
        };

        using var response = await _client.SendAsync(request, HttpCompletionOption.ResponseContentRead).ConfigureAwait(false);
        response.EnsureSuccessStatusCode();

        var body = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
        return JsonSerializer.Deserialize<GenerateEmbeddingResponse>(body)
               ?? throw new InvalidOperationException("Empty response from /api/embeddings.");
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using LangChain.Providers;
using System.Diagnostics;
using LangChain.Abstractions.Embeddings.Base;

namespace OllamaTest;

/// <summary>
/// <see cref="IEmbeddings"/> implementation backed by an Ollama server.
/// </summary>
public class OllamaLanguageModelEmbeddings : IEmbeddings
{
    private readonly string _modelName;
    private readonly OllamaApiClient _api;

    /// <summary>Model options sent with every embedding request.</summary>
    public OllamaLanguageModelOptions Options { get; }

    /// <summary>
    /// Creates an embeddings provider for the given model.
    /// </summary>
    /// <param name="modelName">Name of the Ollama model to use.</param>
    /// <param name="url">Server URL; defaults to "http://localhost:11434".</param>
    /// <param name="options">Model options; defaults to a fresh <see cref="OllamaLanguageModelOptions"/>.</param>
    public OllamaLanguageModelEmbeddings(string modelName, string? url = null, OllamaLanguageModelOptions? options = null)
    {
        _modelName = modelName ?? throw new ArgumentNullException(nameof(modelName));
        Options = options ?? new OllamaLanguageModelOptions();
        _api = new OllamaApiClient(url ?? "http://localhost:11434");
    }

    /// <summary>
    /// Embeds each text sequentially via the Ollama API.
    /// Async all the way down: the previous version blocked on .Result,
    /// which risks deadlocks and thread-pool starvation.
    /// </summary>
    public async Task<float[][]> EmbedDocumentsAsync(string[] texts, CancellationToken cancellationToken = default)
    {
        texts = texts ?? throw new ArgumentNullException(nameof(texts));

        var result = new float[texts.Length][];
        for (int i = 0; i < texts.Length; i++)
        {
            cancellationToken.ThrowIfCancellationRequested();
            var response = await _api.GenerateEmbeddings(new GenerateEmbeddingRequest
            {
                Prompt = texts[i],
                Model = _modelName,
                Options = Options,
            }).ConfigureAwait(false);
            result[i] = response.Embedding.Select(x => (float)x).ToArray();
        }

        return result;
    }

    /// <summary>
    /// Embeds a single query text via the Ollama API.
    /// </summary>
    public async Task<float[]> EmbedQueryAsync(string text, CancellationToken cancellationToken = default)
    {
        text = text ?? throw new ArgumentNullException(nameof(text));

        cancellationToken.ThrowIfCancellationRequested();
        var response = await _api.GenerateEmbeddings(new GenerateEmbeddingRequest
        {
            Prompt = text,
            Model = _modelName,
            Options = Options,
        }).ConfigureAwait(false);
        return response.Embedding.Select(x => (float)x).ToArray();
    }
}
Loading
Loading