Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Users can add local GenAI models. also remembers execution provider in MRU #263

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions AIDevGallery.Utils/ModelUrl.cs
Original file line number Diff line number Diff line change
@@ -104,9 +104,9 @@ public HuggingFaceUrl(string modelNameOrUrl)

modelNameOrUrl = modelNameOrUrl.Trim();

if (modelNameOrUrl.StartsWith("https://", StringComparison.InvariantCulture))
if (modelNameOrUrl.StartsWith("https://", StringComparison.OrdinalIgnoreCase))
{
if (!modelNameOrUrl.StartsWith("https://huggingface.co", StringComparison.InvariantCulture))
if (!modelNameOrUrl.StartsWith("https://huggingface.co", StringComparison.OrdinalIgnoreCase))
{
throw new ArgumentException("Invalid URL", nameof(modelNameOrUrl));
}
@@ -174,7 +174,7 @@ public GitHubUrl(string url)
url = url.Trim();
FullUrl = url;

if (!url.StartsWith("https://github.com/", StringComparison.InvariantCulture))
if (!url.StartsWith("https://github.com/", StringComparison.OrdinalIgnoreCase))
{
throw new ArgumentException("Invalid URL", nameof(url));
}
@@ -248,7 +248,11 @@ public static class UrlHelpers
/// <returns>The full URL as a string.</returns>
public static string GetFullUrl(string url)
{
if (url.StartsWith("https://github.com", StringComparison.InvariantCulture))
if (url.StartsWith("https://github.com", StringComparison.OrdinalIgnoreCase))
{
return url;
}
else if (url.StartsWith("local", StringComparison.OrdinalIgnoreCase))
{
return url;
}
10 changes: 3 additions & 7 deletions AIDevGallery/AIDevGallery.csproj
Original file line number Diff line number Diff line change
@@ -89,13 +89,8 @@
</ItemGroup>

<ItemGroup Condition="$(Platform) == 'ARM64'">
<PackageReference Include="Microsoft.ML.OnnxRuntime.Qnn" />
<PackageReference Include="Microsoft.ML.OnnxRuntimeGenAI" GeneratePathProperty="true" ExcludeAssets="all" />
<None Include="$(PKGMicrosoft_ML_OnnxRuntimeGenAI)\runtimes\win-arm64\native\onnxruntime-genai.dll">
<Link>onnxruntime-genai.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<PackageReference Include="Microsoft.ML.OnnxRuntime.QNN" />
<PackageReference Include="Microsoft.ML.OnnxRuntimeGenAI.QNN" />
<PackageReference Include="Microsoft.ML.OnnxRuntimeGenAI.Managed" />
</ItemGroup>

@@ -250,6 +245,7 @@
<None Remove="Assets\ModelIcons\HuggingFace.svg" />
<None Remove="Assets\ModelIcons\Microsoft.svg" />
<None Remove="Assets\ModelIcons\Mistral.svg" />
<None Remove="Assets\ModelIcons\onnx.svg" />
<None Remove="Assets\ModelIcons\OpenAI.png" />
<None Remove="Assets\TileImages\Chat.png" />
<None Remove="Assets\TileImages\ClassifyImage.png" />
12 changes: 12 additions & 0 deletions AIDevGallery/Assets/ModelIcons/onnx.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 11 additions & 10 deletions AIDevGallery/Controls/ModelSelectionControl.xaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
<?xml version="1.0" encoding="utf-8" ?>
<UserControl
x:Class="AIDevGallery.Controls.ModelSelectionControl"
xmlns:local="using:AIDevGallery.Controls"
xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation"
xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
xmlns:d="http://schemas.microsoft.com/expression/blend/2008"
@@ -107,7 +108,7 @@
HorizontalSpacing="6"
Orientation="Horizontal"
VerticalSpacing="2"
Visibility="{x:Bind helpers:ModelDetailsHelper.ShowWhenDownloadedModel(ModelDetails)}">
Visibility="{x:Bind helpers:ModelDetailsHelper.ShowWhenOnnxModel(ModelDetails)}">
<TextBlock VerticalAlignment="Center" Visibility="{x:Bind ModelDetails.ParameterSize, Converter={StaticResource StringVisibilityConverter}}">
<Run Text="{x:Bind ModelDetails.ParameterSize}" />
<Run Text="params" />
@@ -123,13 +124,15 @@
Source="{x:Bind utils:AppUtils.GetModelSourceImageFromUrl(ModelDetails.Url)}">
<ToolTipService.ToolTip>
<TextBlock TextWrapping="Wrap">
<Run Text="This model is from" />
<Run Text="{x:Bind utils:AppUtils.GetModelSourceNameFromUrl(ModelDetails.Url)}" />
</TextBlock>
</ToolTipService.ToolTip>
</Image>
<TextBlock Text="• " />
<TextBlock VerticalAlignment="Center" Text="{x:Bind utils:AppUtils.GetLicenseShortNameFromString(ModelDetails.License)}">
<TextBlock Text="• "
Visibility="{x:Bind helpers:ModelDetailsHelper.ShowWhenDownloadedModel(ModelDetails)}" />
<TextBlock VerticalAlignment="Center"
Text="{x:Bind utils:AppUtils.GetLicenseShortNameFromString(ModelDetails.License)}"
Visibility="{x:Bind helpers:ModelDetailsHelper.ShowWhenDownloadedModel(ModelDetails)}">
<ToolTipService.ToolTip>
<TextBlock TextWrapping="Wrap">
<Run Text="This model is under the" />
@@ -216,20 +219,20 @@
Icon="{ui:FontIcon Glyph=&#xE8C8;}"
Tag="{x:Bind ModelDetails}"
Text="Copy as path"
Visibility="{x:Bind helpers:ModelDetailsHelper.ShowWhenDownloadedModel(ModelDetails)}" />
Visibility="{x:Bind helpers:ModelDetailsHelper.ShowWhenOnnxModel(ModelDetails)}" />
<MenuFlyoutItem
Click="OpenModelFolder_Click"
Icon="{ui:FontIcon Glyph=&#xE838;}"
Tag="{x:Bind ModelDetails}"
Text="Open containing folder"
Visibility="{x:Bind helpers:ModelDetailsHelper.ShowWhenDownloadedModel(ModelDetails)}" />
<MenuFlyoutSeparator Visibility="{x:Bind helpers:ModelDetailsHelper.ShowWhenDownloadedModel(ModelDetails)}" />
Visibility="{x:Bind helpers:ModelDetailsHelper.ShowWhenOnnxModel(ModelDetails)}" />
<MenuFlyoutSeparator Visibility="{x:Bind helpers:ModelDetailsHelper.ShowWhenOnnxModel(ModelDetails)}" />
<MenuFlyoutItem
Click="DeleteModel_Click"
Icon="{ui:FontIcon Glyph=&#xE74D;}"
Tag="{x:Bind ModelDetails}"
Text="Delete"
Visibility="{x:Bind helpers:ModelDetailsHelper.ShowWhenDownloadedModel(ModelDetails)}" />
Visibility="{x:Bind helpers:ModelDetailsHelper.ShowWhenOnnxModel(ModelDetails)}" />
</MenuFlyout>
</Button.Flyout>
</Button>
@@ -333,7 +336,6 @@
Source="{x:Bind utils:AppUtils.GetModelSourceImageFromUrl(ModelDetails.Url)}">
<ToolTipService.ToolTip>
<TextBlock TextWrapping="Wrap">
<Run Text="This model is from" />
<Run Text="{x:Bind utils:AppUtils.GetModelSourceNameFromUrl(ModelDetails.Url)}" />
</TextBlock>
</ToolTipService.ToolTip>
@@ -519,7 +521,6 @@
Source="{x:Bind utils:AppUtils.GetModelSourceImageFromUrl(ModelDetails.Url)}">
<ToolTipService.ToolTip>
<TextBlock TextWrapping="Wrap">
<Run Text="This model is from" />
<Run Text="{x:Bind utils:AppUtils.GetModelSourceNameFromUrl(ModelDetails.Url)}" />
</TextBlock>
</ToolTipService.ToolTip>
45 changes: 32 additions & 13 deletions AIDevGallery/Controls/ModelSelectionControl.xaml.cs
Original file line number Diff line number Diff line change
@@ -103,7 +103,7 @@ private void ResetAndLoadModelList(ModelDetails? selectedModel = null)
if (AvailableModels.Count > 0)
{
var modelIds = AvailableModels.Select(s => s.ModelDetails.Id);
var modelOrApiUsageHistory = App.AppData.UsageHistory.Where(id => modelIds.Contains(id));
var modelOrApiUsageHistory = App.AppData.UsageHistoryV2?.FirstOrDefault(u => modelIds.Contains(u.Id));

ModelDetails? modelToPreselect = null;

@@ -112,20 +112,33 @@ private void ResetAndLoadModelList(ModelDetails? selectedModel = null)
modelToPreselect = AvailableModels.Where(m => m.ModelDetails.Id == selectedModel.Id).FirstOrDefault()?.ModelDetails;
}

if (modelToPreselect != null)
if (modelToPreselect == null && modelOrApiUsageHistory != default)
{
SetSelectedModel(selectedModel);
}
else if (modelOrApiUsageHistory.Any())
{
// select most recently used if there is one
var modelId = modelOrApiUsageHistory.First();
SetSelectedModel(AvailableModels.Where(s => s.ModelDetails.Id == modelId).First().ModelDetails);
var models = AvailableModels.Where(am => am.ModelDetails.Id == modelOrApiUsageHistory.Id).ToList();
if (models.Count > 0)
{
if (modelOrApiUsageHistory.HardwareAccelerator != null)
{
var model = models.FirstOrDefault(m => m.ModelDetails.HardwareAccelerators.Contains(modelOrApiUsageHistory.HardwareAccelerator.Value));
if (model != null)
{
modelToPreselect = model.ModelDetails;
}
}

if (modelToPreselect == null)
{
modelToPreselect = models.FirstOrDefault()?.ModelDetails;
}
}
}
else

if (modelToPreselect == null)
{
SetSelectedModel(AvailableModels[0].ModelDetails);
modelToPreselect = AvailableModels[0].ModelDetails;
}

SetSelectedModel(modelToPreselect);
}
else
{
@@ -134,7 +147,7 @@ private void ResetAndLoadModelList(ModelDetails? selectedModel = null)
}
}

private void SetSelectedModel(ModelDetails? modelDetails)
private void SetSelectedModel(ModelDetails? modelDetails, HardwareAccelerator? accelerator = null)
{
if (modelDetails != null)
{
@@ -167,7 +180,13 @@ private void SetViewSelection(ModelDetails modelDetails)
if (IsSelectionEnabled)
{
ModelSelectionItemsView.DeselectAll();
ModelSelectionItemsView.Select(AvailableModels.IndexOf(AvailableModels.First(a => a.ModelDetails.Id == modelDetails.Id)));

var models = AvailableModels.Where(a => a.ModelDetails == modelDetails).ToList();

if (models.Count != 0)
{
ModelSelectionItemsView.Select(AvailableModels.IndexOf(models.First()));
}
}
}

11 changes: 10 additions & 1 deletion AIDevGallery/Helpers/ModelDetailsHelper.cs
Original file line number Diff line number Diff line change
@@ -155,11 +155,20 @@ public static Visibility ShowWhenOllama(ModelDetails modelDetails)
return modelDetails.HardwareAccelerators.Contains(HardwareAccelerator.OLLAMA) ? Visibility.Visible : Visibility.Collapsed;
}

public static Visibility ShowWhenDownloadedModel(ModelDetails modelDetails)
public static Visibility ShowWhenOnnxModel(ModelDetails modelDetails)
{
return modelDetails.HardwareAccelerators.Contains(HardwareAccelerator.CPU)
|| modelDetails.HardwareAccelerators.Contains(HardwareAccelerator.DML)
|| modelDetails.HardwareAccelerators.Contains(HardwareAccelerator.QNN)
? Visibility.Visible : Visibility.Collapsed;
}

public static Visibility ShowWhenDownloadedModel(ModelDetails modelDetails)
{
return (modelDetails.HardwareAccelerators.Contains(HardwareAccelerator.CPU)
|| modelDetails.HardwareAccelerators.Contains(HardwareAccelerator.DML)
|| modelDetails.HardwareAccelerators.Contains(HardwareAccelerator.QNN))
&& !modelDetails.IsUserAdded
? Visibility.Visible : Visibility.Collapsed;
}
}
1 change: 1 addition & 0 deletions AIDevGallery/Helpers/SamplesHelper.cs
Original file line number Diff line number Diff line change
@@ -73,6 +73,7 @@ public static List<string> GetAllNugetPackageReferences(this Sample sample, Dict
else
{
AddUnique("Microsoft.ML.OnnxRuntimeGenAI.DirectML");
AddUnique("Microsoft.AI.DirectML");
}
}

4 changes: 4 additions & 0 deletions AIDevGallery/MainWindow.xaml.cs
Original file line number Diff line number Diff line change
@@ -57,6 +57,10 @@ public void NavigateToPage(object? obj)
{
NavigateToApiOrModelPage(modelTypes[0]);
}
else if (obj is ModelDetails)
{
Navigate("Models", obj);
}
else
{
Navigate("Home");
7 changes: 6 additions & 1 deletion AIDevGallery/Models/BaseSampleNavigationParameters.cs
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@ internal abstract class BaseSampleNavigationParameters(TaskCompletionSource samp
public TaskCompletionSource SampleLoadedCompletionSource { get; set; } = sampleLoadedCompletionSource;

protected abstract string ChatClientModelPath { get; }
protected abstract HardwareAccelerator ChatClientHardwareAccelerator { get; }
protected abstract LlmPromptTemplate? ChatClientPromptTemplate { get; }

public void NotifyCompletion()
@@ -35,7 +36,11 @@ public void NotifyCompletion()
return new OllamaChatClient(OllamaHelper.GetOllamaUrl(), modelId);
}

return await GenAIModel.CreateAsync(ChatClientModelPath, ChatClientPromptTemplate, CancellationToken).ConfigureAwait(false);
return await GenAIModel.CreateAsync(
ChatClientModelPath,
ChatClientPromptTemplate,
ChatClientHardwareAccelerator == HardwareAccelerator.QNN ? "qnn" : null,
CancellationToken).ConfigureAwait(false);
}

internal abstract void SendSampleInteractionEvent(string? customInfo = null);
10 changes: 8 additions & 2 deletions AIDevGallery/Models/CachedModel.cs
Original file line number Diff line number Diff line change
@@ -20,11 +20,16 @@ internal class CachedModel
public CachedModel(ModelDetails details, string path, bool isFile, long modelSize)
{
Details = details;
if (details.Url.StartsWith("https://github.com", StringComparison.InvariantCulture))
if (details.Url.StartsWith("https://github.com", StringComparison.OrdinalIgnoreCase))
{
Url = details.Url;
Source = CachedModelSource.GitHub;
}
else if (details.Url.StartsWith("local", StringComparison.OrdinalIgnoreCase))
{
Url = details.Url;
Source = CachedModelSource.Local;
}
else
{
Url = new HuggingFaceUrl(details.Url).FullUrl;
@@ -42,5 +47,6 @@ public CachedModel(ModelDetails details, string path, bool isFile, long modelSiz
internal enum CachedModelSource
{
GitHub,
HuggingFace
HuggingFace,
Local
}
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@ internal class MultiModelSampleNavigationParameters(
public HardwareAccelerator[] HardwareAccelerators { get; } = hardwareAccelerators;

protected override string ChatClientModelPath => ModelPaths[0];
protected override HardwareAccelerator ChatClientHardwareAccelerator => HardwareAccelerators[0];
protected override LlmPromptTemplate? ChatClientPromptTemplate => promptTemplates[0];

internal override void SendSampleInteractionEvent(string? customInfo = null)
1 change: 1 addition & 0 deletions AIDevGallery/Models/SampleNavigationParameters.cs
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ internal class SampleNavigationParameters(
public string SampleId => sampleId;

protected override string ChatClientModelPath => ModelPath;
protected override HardwareAccelerator ChatClientHardwareAccelerator => HardwareAccelerator;
protected override LlmPromptTemplate? ChatClientPromptTemplate => promptTemplate;

internal override void SendSampleInteractionEvent(string? customInfo = null)
Loading