Skip to content

Commit

Permalink
Miguel/observe a11y (#412)
Browse files Browse the repository at this point in the history
* first steps towards accessibility backbone

* cleanup working for accessibility tree

* accessibility backbone eval task updates

* added accessibility tree to evals typing

* migrated extract construct to new stagehand page location

* fixing linting

* first try

* new observe logic for indexing elements (using nodeId-DOM nodeid)

* work in progress

* selector for a11y tree now in xpath format

* testing not returning

* generating xpath for elements not in selectormap

* passing evals locally

* adjusting evals

* merged main

* PR cleanup

* deleted unnecessary evals for now

* fixing linting

* removing useAccessibilityTree from extract evals

* changes for review

* fixing lint errors

* final review fixes

* prettify

* resolved comments

* changeset

* added final comment
  • Loading branch information
miguelg719 authored Jan 20, 2025
1 parent fe3b044 commit 4aa4813
Show file tree
Hide file tree
Showing 18 changed files with 460 additions and 44 deletions.
5 changes: 5 additions & 0 deletions .changeset/empty-peas-smell.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@browserbasehq/stagehand": minor
---

Includes a new format to get website context using accessibility (a11y) trees. The new context is provided optionally with the flag useAccessibilityTree for observe tasks.
2 changes: 2 additions & 0 deletions evals/args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ if (extractMethodArg) {
// Set the extraction method in the process environment so tasks can reference it.
process.env.EXTRACT_METHOD = extractMethod;
const useTextExtract = process.env.EXTRACT_METHOD === "textExtract";
const useAccessibilityTree = process.env.EXTRACT_METHOD === "accessibilityTree";

/**
* Variables for filtering which tasks to run:
Expand Down Expand Up @@ -75,5 +76,6 @@ export {
filterByCategory,
filterByEvalName,
useTextExtract,
useAccessibilityTree,
DEFAULT_EVAL_CATEGORIES,
};
1 change: 0 additions & 1 deletion evals/evals.config.json
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@
"name": "vanta_h",
"categories": ["observe"]
},

{
"name": "extract_area_codes",
"categories": ["text_extract"]
Expand Down
8 changes: 7 additions & 1 deletion evals/index.eval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ import { env } from "./env";
import { generateExperimentName } from "./utils";
import { exactMatch, errorMatch } from "./scoring";
import { tasksByName, MODELS } from "./taskConfig";
import { filterByCategory, filterByEvalName, useTextExtract } from "./args";
import {
filterByCategory,
filterByEvalName,
useTextExtract,
useAccessibilityTree,
} from "./args";
import { Eval } from "braintrust";
import { EvalFunction, SummaryResult, Testcase } from "../types/evals";
import { EvalLogger } from "./logger";
Expand Down Expand Up @@ -221,6 +226,7 @@ const generateFilteredTestcases = (): Testcase[] => {
modelName: input.modelName,
logger,
useTextExtract,
useAccessibilityTree,
});

// Log result to console
Expand Down
8 changes: 6 additions & 2 deletions evals/tasks/ionwave_observe.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import { initStagehand } from "../initStagehand";
import { EvalFunction } from "../../types/evals";

export const ionwave_observe: EvalFunction = async ({ modelName, logger }) => {
export const ionwave_observe: EvalFunction = async ({
modelName,
logger,
useAccessibilityTree,
}) => {
const { stagehand, initResponse } = await initStagehand({
modelName,
logger,
Expand All @@ -11,7 +15,7 @@ export const ionwave_observe: EvalFunction = async ({ modelName, logger }) => {

await stagehand.page.goto("https://elpasotexas.ionwave.net/#.aspx");

const observations = await stagehand.page.observe();
const observations = await stagehand.page.observe({ useAccessibilityTree });

if (observations.length === 0) {
await stagehand.close();
Expand Down
8 changes: 6 additions & 2 deletions evals/tasks/panamcs.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import { initStagehand } from "../initStagehand";
import { EvalFunction } from "../../types/evals";

export const panamcs: EvalFunction = async ({ modelName, logger }) => {
export const panamcs: EvalFunction = async ({
modelName,
logger,
useAccessibilityTree,
}) => {
const { stagehand, initResponse } = await initStagehand({
modelName,
logger,
Expand All @@ -11,7 +15,7 @@ export const panamcs: EvalFunction = async ({ modelName, logger }) => {

await stagehand.page.goto("https://panamcs.org/about/staff/");

const observations = await stagehand.page.observe();
const observations = await stagehand.page.observe({ useAccessibilityTree });

if (observations.length === 0) {
await stagehand.close();
Expand Down
8 changes: 6 additions & 2 deletions evals/tasks/shopify_homepage.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import { initStagehand } from "../initStagehand";
import { EvalFunction } from "../../types/evals";

export const shopify_homepage: EvalFunction = async ({ modelName, logger }) => {
export const shopify_homepage: EvalFunction = async ({
modelName,
logger,
useAccessibilityTree,
}) => {
const { stagehand, initResponse } = await initStagehand({
modelName,
logger,
Expand All @@ -11,7 +15,7 @@ export const shopify_homepage: EvalFunction = async ({ modelName, logger }) => {

await stagehand.page.goto("https://www.shopify.com/");

const observations = await stagehand.page.observe();
const observations = await stagehand.page.observe({ useAccessibilityTree });

if (observations.length === 0) {
await stagehand.close();
Expand Down
8 changes: 6 additions & 2 deletions evals/tasks/vanta.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import { initStagehand } from "../initStagehand";
import { EvalFunction } from "../../types/evals";

export const vanta: EvalFunction = async ({ modelName, logger }) => {
export const vanta: EvalFunction = async ({
modelName,
logger,
useAccessibilityTree,
}) => {
const { stagehand, initResponse } = await initStagehand({
modelName,
logger,
Expand All @@ -12,7 +16,7 @@ export const vanta: EvalFunction = async ({ modelName, logger }) => {
await stagehand.page.goto("https://www.vanta.com/");
await stagehand.page.act({ action: "close the cookies popup" });

const observations = await stagehand.page.observe();
const observations = await stagehand.page.observe({ useAccessibilityTree });

if (observations.length === 0) {
await stagehand.close();
Expand Down
7 changes: 6 additions & 1 deletion evals/tasks/vanta_h.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import { initStagehand } from "../initStagehand";
import { EvalFunction } from "../../types/evals";

export const vanta_h: EvalFunction = async ({ modelName, logger }) => {
export const vanta_h: EvalFunction = async ({
modelName,
logger,
useAccessibilityTree,
}) => {
const { stagehand, initResponse } = await initStagehand({
modelName,
logger,
Expand All @@ -13,6 +17,7 @@ export const vanta_h: EvalFunction = async ({ modelName, logger }) => {

const observations = await stagehand.page.observe({
instruction: "find the buy now button if it is available",
useAccessibilityTree,
});

await stagehand.close();
Expand Down
34 changes: 34 additions & 0 deletions lib/StagehandPage.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type {
Page as PlaywrightPage,
BrowserContext as PlaywrightContext,
CDPSession,
} from "@playwright/test";
import { LLMClient } from "./llm/LLMClient";
import { ActOptions, ActResult, GotoOptions, Stagehand } from "./index";
Expand All @@ -25,6 +26,7 @@ export class StagehandPage {
private extractHandler: StagehandExtractHandler;
private observeHandler: StagehandObserveHandler;
private llmClient: LLMClient;
private cdpClient: CDPSession | null = null;

constructor(
page: PlaywrightPage,
Expand Down Expand Up @@ -460,6 +462,10 @@ export class StagehandPage {
value: llmClient.modelName,
type: "string",
},
useAccessibilityTree: {
value: options?.useAccessibilityTree ? "true" : "false",
type: "boolean",
},
},
});

Expand All @@ -473,6 +479,7 @@ export class StagehandPage {
fullPage: false,
requestId,
domSettleTimeoutMs: options?.domSettleTimeoutMs,
useAccessibilityTree: options?.useAccessibilityTree ?? false,
})
.catch((e) => {
this.stagehand.log({
Expand Down Expand Up @@ -506,4 +513,31 @@ export class StagehandPage {
throw e;
});
}

/**
 * Returns the CDP session associated with this page, creating it on first use.
 *
 * The session is obtained from `this.context.newCDPSession(this.page)` and
 * stored in `this.cdpClient`, so later calls reuse the same session instead
 * of opening a new one.
 *
 * @returns The cached (or newly created) CDP session.
 */
async getCDPClient(): Promise<CDPSession> {
  // Fast path: a session was already opened by an earlier call.
  if (this.cdpClient) {
    return this.cdpClient;
  }

  const session = await this.context.newCDPSession(this.page);
  this.cdpClient = session;
  return session;
}

/**
 * Sends a single CDP command through this page's CDP session.
 *
 * @typeParam T - Shape the caller expects the command result to have. It is
 *   applied with a type assertion only; nothing validates it at runtime.
 * @param command - CDP method name, e.g. `"<Domain>.enable"`.
 * @param args - Optional command parameters; an empty object is sent when
 *   none are given.
 * @returns The promise returned by the session's `send`, asserted to `Promise<T>`.
 */
async sendCDP<T>(
  command: string,
  args?: Record<string, unknown>,
): Promise<T> {
  const session = await this.getCDPClient();

  // CDP command strings are not fully typed, so the free-form string has to
  // be asserted to the method-name type that `send` accepts.
  const method = command as Parameters<CDPSession["send"]>[0];
  const payload: Record<string, unknown> = args ? args : {};

  return session.send(method, payload) as Promise<T>;
}

/**
 * Enables a CDP domain by sending its `<domain>.enable` command with an
 * empty parameter object.
 *
 * @param domain - CDP domain name used as the prefix of the command.
 */
async enableCDP(domain: string): Promise<void> {
  const enableCommand = `${domain}.enable`;
  await this.sendCDP(enableCommand, {});
}

/**
 * Disables a CDP domain by sending its `<domain>.disable` command with an
 * empty parameter object.
 *
 * @param domain - CDP domain name used as the prefix of the command.
 */
async disableCDP(domain: string): Promise<void> {
  const disableCommand = `${domain}.disable`;
  await this.sendCDP(disableCommand, {});
}
}
Loading

0 comments on commit 4aa4813

Please sign in to comment.