Get tests working with streaming responses

This commit is contained in:
Zach Gollwitzer
2025-03-27 14:33:15 -04:00
parent 05add6f4b3
commit b31e8990a1
22 changed files with 247 additions and 196 deletions

View File

@@ -136,21 +136,7 @@ A family sync happens once daily via [auto_sync.rb](mdc:app/controllers/concerns
The Maybe app utilizes several 3rd party data services to calculate historical account balances, enrich data, and more. Since the app can be run in both "hosted" and "self hosted" mode, this means that data providers are _optional_ for self hosted users and must be configured.
Because of this optionality, data providers must be configured at _runtime_ through the [providers.rb](mdc:app/models/providers.rb) module, utilizing [setting.rb](mdc:app/models/setting.rb) for runtime parameters like API keys:
```rb
module Providers
module_function
def synth
api_key = ENV.fetch("SYNTH_API_KEY", Setting.synth_api_key)
return nil unless api_key.present?
Provider::Synth.new(api_key)
end
end
```
Because of this optionality, data providers must be configured at _runtime_ through [registry.rb](mdc:app/models/provider/registry.rb) utilizing [setting.rb](mdc:app/models/setting.rb) for runtime parameters like API keys:
There are two types of 3rd party data in the Maybe app:
@@ -161,59 +147,20 @@ There are two types of 3rd party data in the Maybe app:
Since the app is self hostable, users may prefer using different providers for generic data like exchange rates and security prices. When data is generic enough where we can easily swap out different providers, we call it a data "concept".
Each "concept" _must_ have a `Provideable` concern that defines the methods that must be implemented along with the data shapes that are returned. For example, an "exchange rates concept" might look like this:
Each "concept" has an interface defined in the `app/models/provider/concepts` directory.
```
app/models/
exchange_rate.rb # <- ActiveRecord model and "concept"
exchange_rate/
provided.rb # <- Chooses the provider for this concept based on user settings / config
provideable.rb # <- Defines interface for providing exchange rates
provided.rb # <- Responsible for selecting the concept provider from the registry
provider.rb # <- Base provider class
provider/
registry.rb <- Defines available providers by concept
concepts/
exchange_rate.rb <- defines the interface required for the exchange rate concept
synth.rb # <- Concrete provider implementation
```
Where the `Provideable` and concrete provider implementations would be something like:
```rb
# Defines the interface an exchange rate provider must implement
module ExchangeRate::Provideable
extend ActiveSupport::Concern
FetchRateData = Data.define(:rate)
FetchRatesData = Data.define(:rates)
def fetch_exchange_rate(from:, to:, date:)
raise NotImplementedError, "Subclasses must implement #fetch_exchange_rate"
end
def fetch_exchange_rates(from:, to:, start_date:, end_date:)
raise NotImplementedError, "Subclasses must implement #fetch_exchange_rates"
end
end
```
Any provider that is a valid exchange rate provider must implement this interface:
```rb
class ConcreteProvider < Provider
include ExchangeRate::Provideable
def fetch_exchange_rate(from:, to:, date:)
provider_response do
ExchangeRate::Provideable::FetchRateData.new(
rate: ExchangeRate.new # build response
)
end
end
def fetch_exchange_rates(from:, to:, start_date:, end_date:)
# Implementation
end
end
```
### One-off data
For data that does not fit neatly into a "concept", a `Provideable` is not required and the concrete provider may implement ad-hoc methods called directly in code. For example, the [synth.rb](mdc:app/models/provider/synth.rb) provider has a `usage` method that is only applicable to this specific provider. This should be called directly without any abstractions:
@@ -221,14 +168,14 @@ For data that does not fit neatly into a "concept", a `Provideable` is not requi
```rb
class SomeModel < Application
def synth_usage
Providers.synth.usage
Provider::Registry.get_provider(:synth)&.usage
end
end
```
## "Provided" Concerns
In general, domain models should not be calling [providers.rb](mdc:app/models/providers.rb) (`Providers.some_provider`) directly. When 3rd party data is required for a domain model, we use the `Provided` concern within that model's namespace. This concern is primarily responsible for:
In general, domain models should not be calling [registry.rb](mdc:app/models/provider/registry.rb) directly. When 3rd party data is required for a domain model, we use the `Provided` concern within that model's namespace. This concern is primarily responsible for:
- Choosing the provider to use for this "concept"
- Providing convenience methods on the model for accessing data
@@ -241,7 +188,8 @@ module ExchangeRate::Provided
class_methods do
def provider
Providers.synth
registry = Provider::Registry.for_concept(:exchange_rates)
registry.get_provider(:synth)
end
def find_or_fetch_rate(from:, to:, date: Date.current, cache: true)

View File

@@ -113,7 +113,7 @@
}
.prose--ai-chat {
@apply break-words max-w-[300px];
@apply break-words;
p, li {
@apply text-sm text-primary;

View File

@@ -10,7 +10,7 @@ class PagesController < ApplicationController
end
def changelog
@release_notes = Providers.github.fetch_latest_release_notes
@release_notes = github_provider.fetch_latest_release_notes
render layout: "settings"
end
@@ -26,4 +26,9 @@ class PagesController < ApplicationController
@invite_code = InviteCode.order("RANDOM()").limit(1).first
render layout: false
end
private
def github_provider
Provider::Registry.get_provider(:github)
end
end

View File

@@ -6,7 +6,8 @@ class Settings::HostingsController < ApplicationController
before_action :ensure_admin, only: :clear_cache
def show
@synth_usage = Providers.synth&.usage
synth_provider = Provider::Registry.get_provider(:synth)
@synth_usage = synth_provider&.usage
end
def update

View File

@@ -2,9 +2,9 @@ module Account::Transaction::Provided
extend ActiveSupport::Concern
def fetch_enrichment_info
return nil unless Providers.synth # Only Synth can provide this data
return nil unless provider
response = Providers.synth.enrich_transaction(
response = provider.enrich_transaction(
entry.name,
amount: entry.amount,
date: entry.date
@@ -12,4 +12,9 @@ module Account::Transaction::Provided
response.data
end
private
def provider
Provider::Registry.get_provider(:synth)
end
end

View File

@@ -18,20 +18,14 @@ class Assistant
@chat = chat
end
def respond_to(message)
chat.clear_error
sleep artificial_thinking_delay
provider = get_model_provider(message.ai_model)
def streamer(model)
assistant_message = AssistantMessage.new(
chat: chat,
content: "",
ai_model: message.ai_model
ai_model: model
)
streamer = proc do |chunk|
proc do |chunk|
case chunk.type
when "output_text"
stop_thinking
@@ -57,12 +51,19 @@ class Assistant
chat.update!(latest_assistant_response_id: chunk.data.id)
end
end
end
def respond_to(message)
chat.clear_error
sleep artificial_thinking_delay
provider = get_model_provider(message.ai_model)
provider.chat_response(
message,
instructions: instructions,
available_functions: functions,
streamer: streamer
streamer: streamer(message.ai_model)
)
rescue => e
chat.add_error(e)

View File

@@ -2,11 +2,11 @@ module Assistant::Provided
extend ActiveSupport::Concern
def get_model_provider(ai_model)
available_providers.find { |provider| provider.supports_model?(ai_model) }
registry.providers.find { |provider| provider.supports_model?(ai_model) }
end
private
def available_providers
[ Providers.openai ].compact
def registry
@registry ||= Provider::Registry.for_concept(:llm)
end
end

View File

@@ -3,7 +3,8 @@ module ExchangeRate::Provided
class_methods do
def provider
Providers.synth
registry = Provider::Registry.for_concept(:exchange_rates)
registry.get_provider(:synth)
end
def find_or_fetch_rate(from:, to:, date: Date.current, cache: true)

View File

@@ -74,9 +74,9 @@ class Family < ApplicationRecord
def get_link_token(webhooks_url:, redirect_url:, accountable_type: nil, region: :us, access_token: nil)
provider = if region.to_sym == :eu
Providers.plaid_eu
Provider::Registry.get_provider(:plaid_eu)
else
Providers.plaid_us
Provider::Registry.get_provider(:plaid_us)
end
# early return when no provider

View File

@@ -3,11 +3,11 @@ module PlaidItem::Provided
class_methods do
def plaid_us_provider
Providers.plaid_us
Provider::Registry.get_provider(:plaid_us)
end
def plaid_eu_provider
Providers.plaid_eu
Provider::Registry.get_provider(:plaid_eu)
end
def plaid_provider_for_region(region)

View File

@@ -0,0 +1,16 @@
module Provider::Concept::ExchangeRates
extend ActiveSupport::Concern
def fetch_exchange_rate(from:, to:, date:)
raise NotImplementedError, "Subclasses must implement #fetch_exchange_rate"
end
def fetch_exchange_rates(from:, to:, start_date:, end_date:)
raise NotImplementedError, "Subclasses must implement #fetch_exchange_rates"
end
private
ProviderRate = Data.define(:from, :to, :date, :rate)
FetchExchangeRate = Data.define(:rate)
FetchExchangeRates = Data.define(:rates)
end

View File

@@ -0,0 +1,7 @@
module Provider::Concept::LLM
extend ActiveSupport::Concern
def chat_response(message, instructions: nil, available_functions: [], streamer: nil)
raise NotImplementedError, "Subclasses must implement #chat_response"
end
end

View File

@@ -0,0 +1,7 @@
module Provider::Concept::Securities
extend ActiveSupport::Concern
def fetch_security_price(symbol:, date:)
raise NotImplementedError, "Subclasses must implement #fetch_security_price"
end
end

View File

@@ -0,0 +1,91 @@
class Provider::Registry
include ActiveModel::Validations
Error = Class.new(StandardError)
CONCEPTS = %i[exchange_rates securities llm]
validates :concept, inclusion: { in: CONCEPTS }
class << self
def for_concept(concept)
new(concept.to_sym)
end
def get_provider(name)
send(name)
rescue NoMethodError
raise Error.new("Provider '#{name}' not found in registry")
end
private
def synth
api_key = ENV.fetch("SYNTH_API_KEY", Setting.synth_api_key)
return nil unless api_key.present?
Provider::Synth.new(api_key)
end
def plaid_us
config = Rails.application.config.plaid
return nil unless config.present?
Provider::Plaid.new(config, region: :us)
end
def plaid_eu
config = Rails.application.config.plaid_eu
return nil unless config.present?
Provider::Plaid.new(config, region: :eu)
end
def github
Provider::Github.new
end
def openai
access_token = ENV.fetch("OPENAI_ACCESS_TOKEN", Setting.openai_access_token)
return nil unless access_token.present?
Provider::Openai.new(access_token)
end
end
def initialize(concept)
@concept = concept
validate!
end
def providers
available_providers.map { |p| self.class.send(p) }
end
def get_provider(name)
provider_method = available_providers.find { |p| p == name.to_sym }
raise Error.new("Provider '#{name}' not found for concept: #{concept}") unless provider_method.present?
self.class.send(provider_method)
end
private
attr_reader :concept
def available_providers
case concept
when :exchange_rates
%i[synth]
when :securities
%i[synth]
when :llm
%i[openai]
else
%i[synth plaid_us plaid_eu github openai]
end
end
end

View File

@@ -1,39 +0,0 @@
module Providers
module_function
def synth
api_key = ENV.fetch("SYNTH_API_KEY", Setting.synth_api_key)
return nil unless api_key.present?
Provider::Synth.new(api_key)
end
def plaid_us
config = Rails.application.config.plaid
return nil unless config.present?
Provider::Plaid.new(config, region: :us)
end
def plaid_eu
config = Rails.application.config.plaid_eu
return nil unless config.present?
Provider::Plaid.new(config, region: :eu)
end
def github
Provider::Github.new
end
def openai
access_token = ENV.fetch("OPENAI_ACCESS_TOKEN", Setting.openai_access_token)
return nil unless access_token.present?
Provider::Openai.new(access_token)
end
end

View File

@@ -3,7 +3,8 @@ module Security::Provided
class_methods do
def provider
Providers.synth
registry = Provider::Registry.for_concept(:securities)
registry.get_provider(:synth)
end
def search_provider(symbol, country_code: nil, exchange_operating_mic: nil)

View File

@@ -1,6 +1,6 @@
<%# locals: (family:) %>
<% if family.requires_data_provider? && Providers.synth.nil? %>
<% if family.requires_data_provider? && Provider::Registry.get_provider(:synth).nil? %>
<details class="group bg-yellow-tint-10 rounded-lg p-2 text-yellow-600 mb-3 text-xs">
<summary class="flex items-center justify-between gap-2">
<div class="flex items-center gap-2">

View File

@@ -1,5 +1,10 @@
<%= render "layouts/shared/htmldoc" do %>
<div class="flex h-full bg-gray-50" data-controller="sidebar" data-sidebar-user-id-value="<%= Current.user.id %>">
<% sidebar_config = app_sidebar_config(Current.user) %>
<div class="flex h-full bg-gray-50"
data-controller="sidebar"
data-sidebar-user-id-value="<%= Current.user.id %>"
data-sidebar-config-value="<%= sidebar_config.to_json %>">
<nav class="flex flex-col shrink-0 w-[84px] py-4 mr-3">
<div class="pl-2 mb-3">
<%= link_to root_path, class: "block" do %>
@@ -26,7 +31,9 @@
</div>
</nav>
<%= tag.div class: class_names("py-4 shrink-0 h-full overflow-y-auto transition-all duration-300", Current.user.show_sidebar? ? "w-80" : "w-0"), data: { sidebar_target: "leftPanel" } do %>
<%= tag.div class: class_names("py-4 shrink-0 h-full overflow-y-auto transition-all duration-300"),
style: "width: #{sidebar_config.dig(:left_panel, :initial_width)}px",
data: { sidebar_target: "leftPanel" } do %>
<% if content_for?(:sidebar) %>
<%= yield :sidebar %>
<% else %>
@@ -36,18 +43,6 @@
<% end %>
<% end %>
<%
left_sidebar_open = Current.user.show_sidebar?
right_sidebar_open = Current.user.show_ai_sidebar?
content_width_class = if left_sidebar_open && right_sidebar_open
"max-w-3xl"
elsif left_sidebar_open || right_sidebar_open
"max-w-4xl"
else
"max-w-5xl"
end
%>
<%= tag.main class: class_names("px-10 py-4 grow h-full", require_upgrade? ? "relative overflow-hidden" : "overflow-y-auto") do %>
<% if require_upgrade? %>
<div class="absolute inset-0 px-10 h-full w-full z-50">
@@ -55,7 +50,7 @@
</div>
<% end %>
<%= tag.div class: class_names("mx-auto w-full h-full", content_width_class), data: { sidebar_target: "content" } do %>
<%= tag.div style: "max-width: #{sidebar_config.dig(:content_max_width)}px", class: class_names("mx-auto w-full h-full"), data: { sidebar_target: "content" } do %>
<% if content_for?(:breadcrumbs) %>
<%= yield :breadcrumbs %>
<% else %>
@@ -72,7 +67,8 @@
<%# AI chat sidebar %>
<%= tag.div id: "chat-container",
class: class_names("flex flex-col justify-between shrink-0 transition-all duration-300", right_sidebar_open ? "w-[400px]" : "w-0"),
style: "width: #{sidebar_config.dig(:right_panel, :initial_width)}px",
class: class_names("flex flex-col justify-between shrink-0 transition-all duration-300"),
data: { controller: "chat hotkey", sidebar_target: "rightPanel", turbo_permanent: true } do %>
<% if Current.user.ai_enabled? %>

View File

@@ -8,7 +8,7 @@ class Settings::HostingsControllerTest < ActionDispatch::IntegrationTest
sign_in users(:family_admin)
@provider = mock
Providers.stubs(:synth).returns(@provider)
Provider::Registry.stubs(:get_provider).with(:synth).returns(@provider)
@usage_response = provider_success_response(
OpenStruct.new(
used: 10,

View File

@@ -5,70 +5,81 @@ class AssistantTest < ActiveSupport::TestCase
setup do
@chat = chats(:two)
@message = @chat.messages.create!(
type: "UserMessage",
content: "Help me with my finances",
ai_model: "gpt-4o"
)
@assistant = Assistant.for_chat(@chat)
@provider = mock
@assistant.expects(:get_model_provider).with("gpt-4o").returns(@provider)
end
test "responds to basic prompt without tools" do
collected_chunks = []
streamer = proc do |chunk|
collected_chunks << chunk
end
@provider.expects(:chat_response).returns(
provider_success_response(
Assistant::Provideable::ChatResponse.new(
id: "1",
model: "gpt-4o",
messages: [
Assistant::Provideable::ChatResponseMessage.new(
id: "1",
content: "Hello from assistant",
)
],
functions: []
)
test "responds to basic prompt" do
text_chunk = Provider::Openai::ChatResponseProcessor::StreamChunk.new(type: "output_text", data: "Hello from assistant")
response_chunk = Provider::Openai::ChatResponseProcessor::StreamChunk.new(
type: "response",
data: Assistant::Provideable::ChatResponse.new(
id: "1",
model: "gpt-4o",
messages: [
Assistant::Provideable::ChatResponseMessage.new(
id: "1",
content: "Hello from assistant",
)
],
functions: []
)
)
assert_difference "Message.count", 1 do
@assistant.respond_to(messages(:chat2_user))
@provider.expects(:chat_response).with do |message, **options|
options[:streamer].call(text_chunk)
options[:streamer].call(response_chunk)
true
end
assert_difference "AssistantMessage.count", 1 do
@assistant.respond_to(@message)
end
end
test "responds with tool function calls" do
@provider.expects(:chat_response).returns(
provider_success_response(
Assistant::Provideable::ChatResponse.new(
id: "1",
model: "gpt-4o",
messages: [
Assistant::Provideable::ChatResponseMessage.new(
id: "1",
content: "Your net worth is $124,200",
)
],
functions: [
Assistant::Provideable::ChatResponseFunctionExecution.new(
id: "1",
call_id: "1",
name: "get_net_worth",
arguments: "{}",
result: "$124,200"
)
]
)
function_request_chunk = Provider::Openai::ChatResponseProcessor::StreamChunk.new(type: "function_request", data: "get_net_worth")
text_chunk = Provider::Openai::ChatResponseProcessor::StreamChunk.new(type: "output_text", data: "Your net worth is $124,200")
response_chunk = Provider::Openai::ChatResponseProcessor::StreamChunk.new(
type: "response",
data: Assistant::Provideable::ChatResponse.new(
id: "1",
model: "gpt-4o",
messages: [
Assistant::Provideable::ChatResponseMessage.new(
id: "1",
content: "Your net worth is $124,200",
)
],
functions: [
Assistant::Provideable::ChatResponseFunctionExecution.new(
id: "1",
call_id: "1",
name: "get_net_worth",
arguments: "{}",
result: "$124,200"
)
]
)
)
assert_difference "Message.count", 1 do
@assistant.respond_to(messages(:chat2_user))
@provider.expects(:chat_response).with do |message, **options|
options[:streamer].call(function_request_chunk)
options[:streamer].call(text_chunk)
options[:streamer].call(response_chunk)
true
end
message = @chat.messages.ordered.last
assert_equal 1, message.tool_calls.size
assert_difference "AssistantMessage.count", 1 do
@assistant.respond_to(@message)
message = @chat.messages.ordered.where(type: "AssistantMessage").last
assert_equal 1, message.tool_calls.size
end
end
end

View File

@@ -1,11 +1,11 @@
require "test_helper"
class ProvidersTest < ActiveSupport::TestCase
class Provider::RegistryTest < ActiveSupport::TestCase
test "synth configured with ENV" do
Setting.stubs(:synth_api_key).returns(nil)
with_env_overrides SYNTH_API_KEY: "123" do
assert_instance_of Provider::Synth, Providers.synth
assert_instance_of Provider::Synth, Provider::Registry.get_provider(:synth)
end
end
@@ -13,7 +13,7 @@ class ProvidersTest < ActiveSupport::TestCase
Setting.stubs(:synth_api_key).returns("123")
with_env_overrides SYNTH_API_KEY: nil do
assert_instance_of Provider::Synth, Providers.synth
assert_instance_of Provider::Synth, Provider::Registry.get_provider(:synth)
end
end
@@ -21,7 +21,7 @@ class ProvidersTest < ActiveSupport::TestCase
Setting.stubs(:synth_api_key).returns(nil)
with_env_overrides SYNTH_API_KEY: nil do
assert_nil Providers.synth
assert_nil Provider::Registry.get_provider(:synth)
end
end
end

View File

@@ -33,7 +33,7 @@ class SettingsTest < ApplicationSystemTestCase
test "can update self hosting settings" do
Rails.application.config.app_mode.stubs(:self_hosted?).returns(true)
Providers.stubs(:synth).returns(nil)
Provider::Registry.stubs(:get_provider).with(:synth).returns(nil)
open_settings_from_sidebar
assert_selector "li", text: "Self hosting"
click_link "Self hosting"