diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 6088e382bd..7df5f1bf35 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -23,6 +23,7 @@ "@patternfly/react-log-viewer": "^5.2.0", "@patternfly/react-styles": "^5.3.1", "@patternfly/react-table": "^5.3.3", + "@patternfly/react-templates": "^1.0.4", "@patternfly/react-tokens": "^5.3.1", "@patternfly/react-topology": "^5.4.0-prerelease.10", "@patternfly/react-virtualized-extension": "^5.1.0", @@ -4026,9 +4027,9 @@ } }, "node_modules/@patternfly/react-core": { - "version": "5.3.3", - "resolved": "https://registry.npmjs.org/@patternfly/react-core/-/react-core-5.3.3.tgz", - "integrity": "sha512-qq3j0M+Vi+Xmd+a/MhRhGgjdRh9Hnm79iA+L935HwMIVDcIWRYp6Isib/Ha4+Jk+f3Qdl0RT3dBDvr/4m6OpVQ==", + "version": "5.3.4", + "resolved": "https://registry.npmjs.org/@patternfly/react-core/-/react-core-5.3.4.tgz", + "integrity": "sha512-zr2yeilIoFp8MFOo0vNgI8XuM+P2466zHvy4smyRNRH2/but2WObqx7Wu4ftd/eBMYdNqmTeuXe6JeqqRqnPMQ==", "dependencies": { "@patternfly/react-icons": "^5.3.2", "@patternfly/react-styles": "^5.3.1", @@ -4107,6 +4108,22 @@ "react-dom": "^17 || ^18" } }, + "node_modules/@patternfly/react-templates": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/@patternfly/react-templates/-/react-templates-1.0.4.tgz", + "integrity": "sha512-SHhlgaPoY1eoCqr2xtift3pRldhVfs8x8zelTnhAnfrja5yVFJqTq/yediB0qm7OZ84wF2HIgSfuS0iM0/iG5A==", + "dependencies": { + "@patternfly/react-core": "^5.3.4", + "@patternfly/react-icons": "^5.3.2", + "@patternfly/react-styles": "^5.3.1", + "@patternfly/react-tokens": "^5.3.1", + "tslib": "^2.5.0" + }, + "peerDependencies": { + "react": "^17 || ^18", + "react-dom": "^17 || ^18" + } + }, "node_modules/@patternfly/react-tokens": { "version": "5.3.1", "resolved": "https://registry.npmjs.org/@patternfly/react-tokens/-/react-tokens-5.3.1.tgz", diff --git a/frontend/package.json b/frontend/package.json index 2adb953921..c72ad99211 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -66,6 +66,7 @@ "@patternfly/react-log-viewer": "^5.2.0", "@patternfly/react-styles": "^5.3.1", "@patternfly/react-table": "^5.3.3", + "@patternfly/react-templates": "^1.0.4", "@patternfly/react-tokens": "^5.3.1", "@patternfly/react-topology": "^5.4.0-prerelease.10", "@patternfly/react-virtualized-extension": "^5.1.0", @@ -149,13 +150,22 @@ "@babel/preset-react": "^7.18.6", "@babel/preset-typescript": "^7.21.5", "@cypress/code-coverage": "^3.12.34", + "@jsdevtools/coverage-istanbul-loader": "^3.0.5", + "@testing-library/cypress": "^10.0.1", + "@testing-library/dom": "^9.3.4", + "@testing-library/jest-dom": "^6.3.0", + "@testing-library/react": "^14.0.0", + "@testing-library/user-event": "^14.5.2", + "@types/chai-subset": "^1.3.5", + "@types/jest": "^28.1.8", + "@typescript-eslint/eslint-plugin": "^7.1.1", + "@typescript-eslint/parser": "^7.1.1", "chai-subset": "^1.6.0", "cypress": "^13.10.0", "cypress-axe": "^1.5.0", "cypress-high-resolution": "^1.0.0", "cypress-mochawesome-reporter": "^3.8.2", "cypress-multi-reporters": "^1.6.4", - "@types/jest": "^28.1.8", "eslint": "^8.57.0", "eslint-config-prettier": "^8.6.0", "eslint-import-resolver-node": "^0.3.7", @@ -174,19 +184,10 @@ "jest": "^28.1.3", "jest-environment-jsdom": "^29.4.3", "junit-report-merger": "^7.0.0", - "@jsdevtools/coverage-istanbul-loader": "^3.0.5", "mocha-junit-reporter": "^2.2.1", "npm-run-all": "^4.1.5", "nyc": "^15.1.0", "serve": "^14.2.1", - "@testing-library/cypress": "^10.0.1", - "@testing-library/dom": "^9.3.4", - "@testing-library/jest-dom": "^6.3.0", - "@testing-library/react": "^14.0.0", - "@testing-library/user-event": "^14.5.2", - "@types/chai-subset": "^1.3.5", - "@typescript-eslint/eslint-plugin": "^7.1.1", - "@typescript-eslint/parser": "^7.1.1", "ts-jest": "^28.0.8", "wait-on": "^7.2.0" }, diff --git a/frontend/src/__mocks__/mockModelArtifact.ts b/frontend/src/__mocks__/mockModelArtifact.ts index 78c75ae1ef..162a3af85e 100644 --- a/frontend/src/__mocks__/mockModelArtifact.ts +++ b/frontend/src/__mocks__/mockModelArtifact.ts @@ -1,6 +1,6 @@ import { ModelArtifact } from '~/concepts/modelRegistry/types'; -export const mockModelArtifact = (): ModelArtifact => ({ +export const mockModelArtifact = (partial?: Partial): ModelArtifact => ({ createTimeSinceEpoch: '1712234877179', id: '1', lastUpdateTimeSinceEpoch: '1712234877179', @@ -13,4 +13,5 @@ export const mockModelArtifact = (): ModelArtifact => ({ uri: 's3://test-bucket/demo-models/test-path?endpoint=test-endpoint&defaultRegion=test-region', modelFormatName: 'test model format', modelFormatVersion: 'test version 1', + ...partial, }); diff --git a/frontend/src/__mocks__/mockModelArtifactList.ts b/frontend/src/__mocks__/mockModelArtifactList.ts index 6e4e3edcef..da368c1dad 100644 --- a/frontend/src/__mocks__/mockModelArtifactList.ts +++ b/frontend/src/__mocks__/mockModelArtifactList.ts @@ -2,8 +2,10 @@ import { ModelArtifactList } from '~/concepts/modelRegistry/types'; import { mockModelArtifact } from './mockModelArtifact'; -export const mockModelArtifactList = (): ModelArtifactList => ({ - items: [mockModelArtifact()], +export const mockModelArtifactList = ({ + items = [mockModelArtifact()], +}: Partial): ModelArtifactList => ({ + items, nextPageToken: '', pageSize: 0, size: 1, diff --git a/frontend/src/__mocks__/mockModelVersion.ts b/frontend/src/__mocks__/mockModelVersion.ts index 29396dab1f..3593a127e6 100644 --- a/frontend/src/__mocks__/mockModelVersion.ts +++ b/frontend/src/__mocks__/mockModelVersion.ts @@ -9,6 +9,7 @@ type MockModelVersionType = { labels?: string[]; state?: ModelState; description?: string; + createTimeSinceEpoch?: string; }; export const mockModelVersion = ({ @@ -19,9 +20,10 @@ export const mockModelVersion = ({ id = '1', state = ModelState.LIVE, description = 'Description of model version', + createTimeSinceEpoch = '1712234877179', }: MockModelVersionType): ModelVersion => ({ author, - createTimeSinceEpoch: '1712234877179', + createTimeSinceEpoch, customProperties: createModelRegistryLabelsObject(labels), id, lastUpdateTimeSinceEpoch: '1712234877179', diff --git a/frontend/src/__tests__/cypress/cypress/pages/modelRegistry/registerVersionPage.ts b/frontend/src/__tests__/cypress/cypress/pages/modelRegistry/registerVersionPage.ts new file mode 100644 index 0000000000..d08eb3c7ea --- /dev/null +++ b/frontend/src/__tests__/cypress/cypress/pages/modelRegistry/registerVersionPage.ts @@ -0,0 +1,51 @@ +export enum FormFieldSelector { + REGISTERED_MODEL = '#registered-model-container .pf-m-typeahead', + VERSION_NAME = '#version-name', + VERSION_DESCRIPTION = '#version-description', + SOURCE_MODEL_FORMAT = '#source-model-format', + SOURCE_MODEL_FORMAT_VERSION = '#source-model-format-version', + LOCATION_TYPE_OBJECT_STORAGE = '#location-type-object-storage', + LOCATION_ENDPOINT = '#location-endpoint', + LOCATION_BUCKET = '#location-bucket', + LOCATION_REGION = '#location-region', + LOCATION_PATH = '#location-path', + LOCATION_TYPE_URI = '#location-type-uri', + LOCATION_URI = '#location-uri', +} + +class RegisterVersionPage { + visit(registeredModelId?: string) { + const preferredModelRegistry = 'modelregistry-sample'; + cy.visitWithLogin( + registeredModelId + ? `/modelRegistry/${preferredModelRegistry}/registeredModels/${registeredModelId}/registerVersion` + : `/modelRegistry/${preferredModelRegistry}/registerVersion`, + ); + this.wait(); + } + + private wait() { + const preferredModelRegistry = 'modelregistry-sample'; + cy.findByTestId('app-page-title').should('exist'); + cy.findByTestId('app-page-title').contains('Register new version'); + cy.findByText(`Model registry - ${preferredModelRegistry}`).should('exist'); + cy.testA11y(); + } + + findFormField(selector: FormFieldSelector) { + return cy.get(selector); + } + + selectRegisteredModel(name: string) { + this.findFormField(FormFieldSelector.REGISTERED_MODEL) + .findByRole('button', { name: 'Typeahead menu toggle' }) + .findSelectOption(name) + .click(); + } + + findSubmitButton() { + return cy.findByTestId('create-button'); + } +} + +export const registerVersionPage = new RegisterVersionPage(); diff --git a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDeploy.cy.ts b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDeploy.cy.ts index bb7f28d213..aaca0b7582 100644 --- a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDeploy.cy.ts +++ b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDeploy.cy.ts @@ -146,7 +146,7 @@ const initIntercepts = ({ modelVersionId: 1, }, }, - mockModelArtifactList(), + mockModelArtifactList({}), ); cy.interceptK8sList( diff --git a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDetails.cy.ts b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDetails.cy.ts index ff3057cb45..0ecc098a86 100644 --- a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDetails.cy.ts +++ b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/modelVersionDetails.cy.ts @@ -111,7 +111,7 @@ const initIntercepts = () => { modelVersionId: 1, }, }, - mockModelArtifactList(), + mockModelArtifactList({}), ); }; diff --git a/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registerVersion.cy.ts b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registerVersion.cy.ts new file mode 100644 index 0000000000..494abadbd5 --- /dev/null +++ b/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registerVersion.cy.ts @@ -0,0 +1,585 @@ +import { mockDashboardConfig, mockDscStatus, mockK8sResourceList } from '~/__mocks__'; +import { mockDsciStatus } from '~/__mocks__/mockDsciStatus'; +import { StackCapability, StackComponent } from '~/concepts/areas/types'; +import { ServiceModel } from '~/__tests__/cypress/cypress/utils/models'; +import { + FormFieldSelector, + registerVersionPage, +} from '~/__tests__/cypress/cypress/pages/modelRegistry/registerVersionPage'; +import { mockRegisteredModel } from '~/__mocks__/mockRegisteredModel'; +import { mockModelVersion } from '~/__mocks__/mockModelVersion'; +import { mockModelArtifact } from '~/__mocks__/mockModelArtifact'; +import { mockModelRegistryService } from '~/__mocks__/mockModelRegistryService'; +import { mockRegisteredModelList } from '~/__mocks__/mockRegisteredModelsList'; +import { mockModelVersionList } from '~/__mocks__/mockModelVersionList'; +import { mockModelArtifactList } from '~/__mocks__/mockModelArtifactList'; +import { + ModelArtifactState, + ModelState, + type ModelVersion, + type ModelArtifact, +} from '~/concepts/modelRegistry/types'; + +const MODEL_REGISTRY_API_VERSION = 'v1alpha3'; + +const initIntercepts = () => { + cy.interceptOdh( + 'GET /api/config', + mockDashboardConfig({ + disableModelRegistry: false, + }), + ); + cy.interceptOdh( + 'GET /api/dsc/status', + mockDscStatus({ + installedComponents: { + [StackComponent.MODEL_REGISTRY]: true, + [StackComponent.MODEL_MESH]: true, + }, + }), + ); + cy.interceptOdh( + 'GET /api/dsci/status', + mockDsciStatus({ + requiredCapabilities: [StackCapability.SERVICE_MESH, StackCapability.SERVICE_MESH_AUTHZ], + }), + ); + + cy.interceptK8sList( + ServiceModel, + mockK8sResourceList([ + mockModelRegistryService({ name: 'modelregistry-sample' }), + mockModelRegistryService({ name: 'modelregistry-sample-2' }), + ]), + ); + + cy.interceptOdh( + `GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models`, + { + path: { serviceName: 'modelregistry-sample', apiVersion: MODEL_REGISTRY_API_VERSION }, + }, + mockRegisteredModelList({ + items: [ + mockRegisteredModel({ id: '1', name: 'Test model 1' }), + mockRegisteredModel({ id: '2', name: 'Test model 2' }), + mockRegisteredModel({ id: '3', name: 'Test model 3 has version but is missing artifact' }), + mockRegisteredModel({ id: '4', name: 'Test model 4 is missing version and artifact' }), + ], + }), + ); + + cy.interceptOdh( + 'GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId/versions', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 1, + }, + }, + mockModelVersionList({ + items: [ + mockModelVersion({ + id: '1', + registeredModelId: '1', + name: 'Test older version for model 1', + createTimeSinceEpoch: '1712234877179', // Apr 04 2024 + }), + mockModelVersion({ + id: '2', + registeredModelId: '1', + name: 'Test latest version for model 1', + createTimeSinceEpoch: '1723659611927', // Aug 14 2024 + }), + ], + }), + ); + + cy.interceptOdh( + 'GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId/versions', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 2, + }, + }, + mockModelVersionList({ + items: [ + mockModelVersion({ + id: '3', + registeredModelId: '2', + name: 'Test older version for model 2', + createTimeSinceEpoch: '1712234877179', // Apr 04 2024 + }), + mockModelVersion({ + id: '4', + registeredModelId: '2', + name: 'Test latest version for model 2', + createTimeSinceEpoch: '1723659611927', // Aug 14 2024 + }), + ], + }), + ); + + cy.interceptOdh( + 'GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId/versions', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 3, + }, + }, + mockModelVersionList({ + items: [ + mockModelVersion({ + id: '5', + registeredModelId: '3', + name: 'Test version for model 3', + }), + ], + }), + ); + + cy.interceptOdh( + 'GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId/versions', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 4, + }, + }, + mockModelVersionList({ + items: [], // Model 4 has no versions + }), + ); + + // Model id 1's latest version is id 2 + cy.interceptOdh( + 'GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions/:modelVersionId/artifacts', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + modelVersionId: 2, + }, + }, + mockModelArtifactList({ + items: [ + mockModelArtifact({ + modelFormatName: 'test-version-id-2-format-name', + modelFormatVersion: 'test-version-id-2-format-version', + uri: 's3://test-bucket-version-id-2/demo-models/test-path?endpoint=test-endpoint-version-id-2&defaultRegion=test-region-version-id-2', + }), + ], + }), + ); + + // Model id 2's latest version is id 4 + cy.interceptOdh( + 'GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions/:modelVersionId/artifacts', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + modelVersionId: 4, + }, + }, + mockModelArtifactList({ + items: [ + mockModelArtifact({ + modelFormatName: 'test-version-id-4-format-name', + modelFormatVersion: 'test-version-id-4-format-version', + uri: 'oops-malformed-uri', + }), + ], + }), + ); + + // Model id 3's latest version is id 5 + cy.interceptOdh( + 'GET /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions/:modelVersionId/artifacts', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + modelVersionId: 5, + }, + }, + mockModelArtifactList({ + items: [], // Model 3 has no artifacts + }), + ); + + cy.interceptOdh( + 'POST /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/registered_models/:registeredModelId/versions', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + registeredModelId: 1, + }, + }, + mockModelVersion({ id: '6', name: 'Test version name' }), + ).as('createModelVersion'); + + cy.interceptOdh( + 'POST /api/service/modelregistry/:serviceName/api/model_registry/:apiVersion/model_versions/:modelVersionId/artifacts', + { + path: { + serviceName: 'modelregistry-sample', + apiVersion: MODEL_REGISTRY_API_VERSION, + modelVersionId: 6, + }, + }, + mockModelArtifact(), + ).as('createModelArtifact'); +}; + +describe('Register model page with no preselected model', () => { + beforeEach(() => { + initIntercepts(); + registerVersionPage.visit(); + }); + + it('Prefills version/artifact details when a model is selected', () => { + registerVersionPage.selectRegisteredModel('Test model 1'); + cy.findByText('Current version is Test latest version for model 1').should('exist'); + registerVersionPage + .findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT) + .should('have.value', 'test-version-id-2-format-name'); + registerVersionPage + .findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT_VERSION) + .should('have.value', 'test-version-id-2-format-version'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_TYPE_OBJECT_STORAGE) + .should('be.checked'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_ENDPOINT) + .should('have.value', 'test-endpoint-version-id-2'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_BUCKET) + .should('have.value', 'test-bucket-version-id-2'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_REGION) + .should('have.value', 'test-region-version-id-2'); + + // Test model 2 has an invalid artifact URI so its object fields are reset + registerVersionPage.selectRegisteredModel('Test model 2'); + cy.findByText('Current version is Test latest version for model 2').should('exist'); + registerVersionPage + .findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT) + .should('have.value', 'test-version-id-4-format-name'); + registerVersionPage + .findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT_VERSION) + .should('have.value', 'test-version-id-4-format-version'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_TYPE_OBJECT_STORAGE) + .should('be.checked'); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_ENDPOINT).should('have.value', ''); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_BUCKET).should('have.value', ''); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_REGION).should('have.value', ''); + + // Switching back should prefill them again + registerVersionPage.selectRegisteredModel('Test model 1'); + cy.findByText('Current version is Test latest version for model 1').should('exist'); + registerVersionPage + .findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT) + .should('have.value', 'test-version-id-2-format-name'); + registerVersionPage + .findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT_VERSION) + .should('have.value', 'test-version-id-2-format-version'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_TYPE_OBJECT_STORAGE) + .should('be.checked'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_ENDPOINT) + .should('have.value', 'test-endpoint-version-id-2'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_BUCKET) + .should('have.value', 'test-bucket-version-id-2'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_REGION) + .should('have.value', 'test-region-version-id-2'); + }); + + it('Clears prefilled details if switching to a model with missing artifact', () => { + registerVersionPage.selectRegisteredModel('Test model 1'); + registerVersionPage.selectRegisteredModel('Test model 3 has version but is missing artifact'); + registerVersionPage + .findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT) + .should('have.value', ''); + registerVersionPage + .findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT_VERSION) + .should('have.value', ''); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_TYPE_OBJECT_STORAGE) + .should('be.checked'); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_ENDPOINT).should('have.value', ''); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_BUCKET).should('have.value', ''); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_REGION).should('have.value', ''); + }); + + it('Clears prefilled details if switching to a model with missing version', () => { + registerVersionPage.selectRegisteredModel('Test model 1'); + registerVersionPage.selectRegisteredModel('Test model 4 is missing version and artifact'); + registerVersionPage + .findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT) + .should('have.value', ''); + registerVersionPage + .findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT_VERSION) + .should('have.value', ''); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_TYPE_OBJECT_STORAGE) + .should('be.checked'); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_ENDPOINT).should('have.value', ''); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_BUCKET).should('have.value', ''); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_REGION).should('have.value', ''); + }); + + it('Disables submit until required fields are filled in object storage mode', () => { + registerVersionPage.findSubmitButton().should('be.disabled'); + registerVersionPage.selectRegisteredModel('Test model 1'); + registerVersionPage.findFormField(FormFieldSelector.VERSION_NAME).type('Test version name'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_PATH) + .type('demo-models/flan-t5-small-caikit'); + registerVersionPage.findSubmitButton().should('be.enabled'); + }); + + it('Creates expected resources on submit in object storage mode', () => { + registerVersionPage.selectRegisteredModel('Test model 1'); + registerVersionPage.findFormField(FormFieldSelector.VERSION_NAME).type('Test version name'); + registerVersionPage + .findFormField(FormFieldSelector.VERSION_DESCRIPTION) + .type('Test version description'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_PATH) + .type('demo-models/flan-t5-small-caikit'); + + registerVersionPage.findSubmitButton().click(); + + cy.wait('@createModelVersion').then((interception) => { + expect(interception.request.body).to.containSubset({ + name: 'Test version name', + description: 'Test version description', + customProperties: {}, + state: ModelState.LIVE, + author: 'test-user', + registeredModelId: '1', + } satisfies Partial); + }); + cy.wait('@createModelArtifact').then((interception) => { + expect(interception.request.body).to.containSubset({ + name: 'Test model 1-Test version name-artifact', + description: 'Test version description', + customProperties: {}, + state: ModelArtifactState.LIVE, + author: 'test-user', + modelFormatName: 'test-version-id-2-format-name', + modelFormatVersion: 'test-version-id-2-format-version', + uri: 's3://test-bucket-version-id-2/demo-models/flan-t5-small-caikit?endpoint=test-endpoint-version-id-2&defaultRegion=test-region-version-id-2', + artifactType: 'model-artifact', + } satisfies Partial); + }); + + cy.url().should('include', '/modelRegistry/modelregistry-sample/registeredModels/1/versions'); + }); + + it('Disables submit until required fields are filled in URI mode', () => { + registerVersionPage.findSubmitButton().should('be.disabled'); + registerVersionPage.selectRegisteredModel('Test model 1'); + registerVersionPage.findFormField(FormFieldSelector.VERSION_NAME).type('Test version name'); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_TYPE_URI).click(); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_URI) + .type( + 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F&defaultRegion=us-east-1', + ); + registerVersionPage.findSubmitButton().should('be.enabled'); + }); + + it('Creates expected resources on submit in URI mode', () => { + registerVersionPage.selectRegisteredModel('Test model 1'); + registerVersionPage.findFormField(FormFieldSelector.VERSION_NAME).type('Test version name'); + registerVersionPage + .findFormField(FormFieldSelector.VERSION_DESCRIPTION) + .type('Test version description'); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_TYPE_URI).click(); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_URI) + .type( + 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F&defaultRegion=us-east-1', + ); + + registerVersionPage.findSubmitButton().click(); + + cy.wait('@createModelVersion').then((interception) => { + expect(interception.request.body).to.containSubset({ + name: 'Test version name', + description: 'Test version description', + customProperties: {}, + state: ModelState.LIVE, + author: 'test-user', + registeredModelId: '1', + } satisfies Partial); + }); + cy.wait('@createModelArtifact').then((interception) => { + expect(interception.request.body).to.containSubset({ + name: 'Test model 1-Test version name-artifact', + description: 'Test version description', + customProperties: {}, + state: ModelArtifactState.LIVE, + author: 'test-user', + modelFormatName: 'test-version-id-2-format-name', + modelFormatVersion: 'test-version-id-2-format-version', + uri: 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F&defaultRegion=us-east-1', + artifactType: 'model-artifact', + } satisfies Partial); + }); + + cy.url().should('include', '/modelRegistry/modelregistry-sample/registeredModels/1/versions'); + }); +}); + +describe('Register model page with preselected model', () => { + beforeEach(() => { + initIntercepts(); + }); + + it('Prefills version/artifact details for the preselected model', () => { + registerVersionPage.visit('1'); + cy.findByText('Current version is Test latest version for model 1').should('exist'); + registerVersionPage + .findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT) + .should('have.value', 'test-version-id-2-format-name'); + registerVersionPage + .findFormField(FormFieldSelector.SOURCE_MODEL_FORMAT_VERSION) + .should('have.value', 'test-version-id-2-format-version'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_TYPE_OBJECT_STORAGE) + .should('be.checked'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_ENDPOINT) + .should('have.value', 'test-endpoint-version-id-2'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_BUCKET) + .should('have.value', 'test-bucket-version-id-2'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_REGION) + .should('have.value', 'test-region-version-id-2'); + }); + + it('Does not prefill location fields if the URI on the artifact is malformed', () => { + registerVersionPage.visit('2'); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_ENDPOINT).should('have.value', ''); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_BUCKET).should('have.value', ''); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_REGION).should('have.value', ''); + }); + + it('Disables submit until required fields are filled in object storage mode', () => { + registerVersionPage.visit('1'); + registerVersionPage.findSubmitButton().should('be.disabled'); + registerVersionPage.findFormField(FormFieldSelector.VERSION_NAME).type('Test version name'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_PATH) + .type('demo-models/flan-t5-small-caikit'); + registerVersionPage.findSubmitButton().should('be.enabled'); + }); + + it('Creates expected resources in object storage mode', () => { + registerVersionPage.visit('1'); + registerVersionPage.findFormField(FormFieldSelector.VERSION_NAME).type('Test version name'); + registerVersionPage + .findFormField(FormFieldSelector.VERSION_DESCRIPTION) + .type('Test version description'); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_PATH) + .type('demo-models/flan-t5-small-caikit'); + + registerVersionPage.findSubmitButton().click(); + + cy.wait('@createModelVersion').then((interception) => { + expect(interception.request.body).to.containSubset({ + name: 'Test version name', + description: 'Test version description', + customProperties: {}, + state: ModelState.LIVE, + author: 'test-user', + registeredModelId: '1', + } satisfies Partial); + }); + cy.wait('@createModelArtifact').then((interception) => { + expect(interception.request.body).to.containSubset({ + name: 'Test model 1-Test version name-artifact', + description: 'Test version description', + customProperties: {}, + state: ModelArtifactState.LIVE, + author: 'test-user', + modelFormatName: 'test-version-id-2-format-name', + modelFormatVersion: 'test-version-id-2-format-version', + uri: 's3://test-bucket-version-id-2/demo-models/flan-t5-small-caikit?endpoint=test-endpoint-version-id-2&defaultRegion=test-region-version-id-2', + artifactType: 'model-artifact', + } satisfies Partial); + }); + + cy.url().should('include', '/modelRegistry/modelregistry-sample/registeredModels/1/versions'); + }); + + it('Disables submit until required fields are filled in URI mode', () => { + registerVersionPage.visit('1'); + registerVersionPage.findSubmitButton().should('be.disabled'); + registerVersionPage.findFormField(FormFieldSelector.VERSION_NAME).type('Test version name'); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_TYPE_URI).click(); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_URI) + .type( + 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F&defaultRegion=us-east-1', + ); + registerVersionPage.findSubmitButton().should('be.enabled'); + }); + + it('Creates expected resources in URI mode', () => { + registerVersionPage.visit('1'); + registerVersionPage.findFormField(FormFieldSelector.VERSION_NAME).type('Test version name'); + registerVersionPage + .findFormField(FormFieldSelector.VERSION_DESCRIPTION) + .type('Test version description'); + registerVersionPage.findFormField(FormFieldSelector.LOCATION_TYPE_URI).click(); + registerVersionPage + .findFormField(FormFieldSelector.LOCATION_URI) + .type( + 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F&defaultRegion=us-east-1', + ); + + registerVersionPage.findSubmitButton().click(); + + cy.wait('@createModelVersion').then((interception) => { + expect(interception.request.body).to.containSubset({ + name: 'Test version name', + description: 'Test version description', + customProperties: {}, + state: ModelState.LIVE, + author: 'test-user', + registeredModelId: '1', + } satisfies Partial); + }); + cy.wait('@createModelArtifact').then((interception) => { + expect(interception.request.body).to.containSubset({ + name: 'Test model 1-Test version name-artifact', + description: 'Test version description', + customProperties: {}, + state: ModelArtifactState.LIVE, + author: 'test-user', + modelFormatName: 'test-version-id-2-format-name', + modelFormatVersion: 'test-version-id-2-format-version', + uri: 's3://test-bucket/demo-models/flan-t5-small-caikit?endpoint=http%3A%2F%2Fs3.amazonaws.com%2F&defaultRegion=us-east-1', + artifactType: 'model-artifact', + } satisfies Partial); + }); + + cy.url().should('include', '/modelRegistry/modelregistry-sample/registeredModels/1/versions'); + }); +}); diff --git a/frontend/src/concepts/modelRegistry/__tests__/utils.spec.ts b/frontend/src/concepts/modelRegistry/__tests__/utils.spec.ts index 314277b795..95a80e406c 100644 --- a/frontend/src/concepts/modelRegistry/__tests__/utils.spec.ts +++ b/frontend/src/concepts/modelRegistry/__tests__/utils.spec.ts @@ -1,8 +1,16 @@ +import { mockModelVersion } from '~/__mocks__/mockModelVersion'; +import { mockRegisteredModel } from '~/__mocks__/mockRegisteredModel'; import { + filterArchiveModels, + filterArchiveVersions, + filterLiveModels, + filterLiveVersions, + getLastCreatedItem, ObjectStorageFields, objectStorageFieldsToUri, uriToObjectStorageFields, } from '~/concepts/modelRegistry/utils'; +import { RegisteredModel, ModelState, ModelVersion } from '~/concepts/modelRegistry/types'; describe('objectStorageFieldsToUri', () => { it('converts fields to URI with all fields present', () => { @@ -120,3 +128,103 @@ describe('uriToObjectStorageFields', () => { expect(fields).toBeNull(); }); }); + +describe('getLastCreatedItem', () => { + it('returns the latest item correctly', () => { + const items = [ + { + foo: 'a', + createTimeSinceEpoch: '1712234877179', // Apr 04 2024 + }, + { + foo: 'b', + createTimeSinceEpoch: '1723659611927', // Aug 14 2024 + }, + ]; + expect(getLastCreatedItem(items)).toBe(items[1]); + }); + + it('returns first item if items have no createTimeSinceEpoch', () => { + const items = [ + { foo: 'a', createTimeSinceEpoch: undefined }, + { foo: 'b', createTimeSinceEpoch: undefined }, + ]; + expect(getLastCreatedItem(items)).toBe(items[0]); + }); +}); + +describe('Filter model state', () => { + const models: RegisteredModel[] = [ + mockRegisteredModel({ name: 'Test 1', state: ModelState.ARCHIVED }), + mockRegisteredModel({ + name: 'Test 2', + state: ModelState.LIVE, + description: 'Description2', + }), + mockRegisteredModel({ name: 'Test 3', state: ModelState.ARCHIVED }), + mockRegisteredModel({ name: 'Test 4', state: ModelState.ARCHIVED }), + mockRegisteredModel({ name: 'Test 5', state: ModelState.LIVE }), + ]; + + describe('filterArchiveModels', () => { + it('should filter out only the archived versions', () => { + const archivedModels = filterArchiveModels(models); + expect(archivedModels).toEqual([models[0], models[2], models[3]]); + }); + + it('should return an empty array if the input array is empty', () => { + const result = filterArchiveModels([]); + expect(result).toEqual([]); + }); + }); + + describe('filterLiveModels', () => { + it('should filter out only the live models', () => { + const liveModels = filterLiveModels(models); + expect(liveModels).toEqual([models[1], models[4]]); + }); + + it('should return an empty array if the input array is empty', () => { + const result = filterLiveModels([]); + expect(result).toEqual([]); + }); + }); +}); + +describe('Filter model version state', () => { + const modelVersions: ModelVersion[] = [ + mockModelVersion({ name: 'Test 1', state: ModelState.ARCHIVED }), + mockModelVersion({ + name: 'Test 2', + state: ModelState.LIVE, + description: 'Description2', + }), + mockModelVersion({ name: 'Test 3', author: 'Author3', state: ModelState.ARCHIVED }), + mockModelVersion({ name: 'Test 4', state: ModelState.ARCHIVED }), + mockModelVersion({ name: 'Test 5', state: ModelState.LIVE }), + ]; + + describe('filterArchiveVersions', () => { + it('should filter out only the archived versions', () => { + const archivedVersions = filterArchiveVersions(modelVersions); + expect(archivedVersions).toEqual([modelVersions[0], modelVersions[2], modelVersions[3]]); + }); + + it('should return an empty array if the input array is empty', () => { + const result = filterArchiveVersions([]); + expect(result).toEqual([]); + }); + }); + + describe('filterLiveVersions', () => { + it('should filter out only the live versions', () => { + const liveVersions = filterLiveVersions(modelVersions); + expect(liveVersions).toEqual([modelVersions[1], modelVersions[4]]); + }); + + it('should return an empty array if the input array is empty', () => { + const result = filterLiveVersions([]); + expect(result).toEqual([]); + }); + }); +}); diff --git a/frontend/src/concepts/modelRegistry/utils.ts b/frontend/src/concepts/modelRegistry/utils.ts index b9d39754d4..4d203e93f6 100644 --- a/frontend/src/concepts/modelRegistry/utils.ts +++ b/frontend/src/concepts/modelRegistry/utils.ts @@ -1,3 +1,5 @@ +import { ModelVersion, ModelState, RegisteredModel } from './types'; + export type ObjectStorageFields = { endpoint: string; bucket: string; @@ -37,3 +39,27 @@ export const uriToObjectStorageFields = (uri: string): ObjectStorageFields | nul return null; } }; + +export const getLastCreatedItem = ( + items?: T[], +): T | undefined => + items?.toSorted( + ({ createTimeSinceEpoch: createTimeA }, { createTimeSinceEpoch: createTimeB }) => { + if (!createTimeA || !createTimeB) { + return 0; + } + return Number(createTimeB) - Number(createTimeA); + }, + )[0]; + +export const filterArchiveVersions = (modelVersions: ModelVersion[]): ModelVersion[] => + modelVersions.filter((mv) => mv.state === ModelState.ARCHIVED); + +export const filterLiveVersions = (modelVersions: ModelVersion[]): ModelVersion[] => + modelVersions.filter((mv) => mv.state === ModelState.LIVE); + +export const filterArchiveModels = (registeredModels: RegisteredModel[]): RegisteredModel[] => + registeredModels.filter((rm) => rm.state === ModelState.ARCHIVED); + +export const filterLiveModels = (registeredModels: RegisteredModel[]): RegisteredModel[] => + registeredModels.filter((rm) => rm.state === ModelState.LIVE); diff --git a/frontend/src/pages/modelRegistry/ModelRegistryCoreLoader.tsx b/frontend/src/pages/modelRegistry/ModelRegistryCoreLoader.tsx index dcdff3cdb6..4c0a97a6f4 100644 --- a/frontend/src/pages/modelRegistry/ModelRegistryCoreLoader.tsx +++ b/frontend/src/pages/modelRegistry/ModelRegistryCoreLoader.tsx @@ -13,6 +13,7 @@ import { ModelRegistrySelectorContext } from '~/concepts/modelRegistry/context/M import InvalidModelRegistry from './screens/InvalidModelRegistry'; import EmptyModelRegistryState from './screens/components/EmptyModelRegistryState'; import ModelRegistrySelectorNavigator from './screens/ModelRegistrySelectorNavigator'; +import { modelRegistryUrl } from './screens/routeUtils'; type ApplicationPageProps = React.ComponentProps; @@ -138,7 +139,7 @@ const ModelRegistryCoreLoader: React.FC = description="View and manage all of your registered models. Registering models to model registry allows you to manage their content, metadata, versions, and user access settings." headerContent={ `/modelRegistry/${modelRegistryName}`} + getRedirectPath={(modelRegistryName) => modelRegistryUrl(modelRegistryName)} /> } {...renderStateProps} diff --git a/frontend/src/pages/modelRegistry/ModelRegistryRoutes.tsx b/frontend/src/pages/modelRegistry/ModelRegistryRoutes.tsx index 090664cb39..c0665cc311 100644 --- a/frontend/src/pages/modelRegistry/ModelRegistryRoutes.tsx +++ b/frontend/src/pages/modelRegistry/ModelRegistryRoutes.tsx @@ -11,6 +11,8 @@ import ModelVersionsArchiveDetails from './screens/ModelVersionsArchive/ModelVer import RegisteredModelsArchive from './screens/RegisteredModelsArchive/RegisteredModelsArchive'; import RegisteredModelsArchiveDetails from './screens/RegisteredModelsArchive/RegisteredModelArchiveDetails'; import RegisterModel from './screens/RegisterModel/RegisterModel'; +import RegisterVersion from './screens/RegisterModel/RegisterVersion'; +import { modelRegistryUrl } from './screens/routeUtils'; const ModelRegistryRoutes: React.FC = () => ( @@ -18,7 +20,7 @@ const ModelRegistryRoutes: React.FC = () => ( path={'/:modelRegistry?/*'} element={ `/modelRegistry/${modelRegistry}`} + getInvalidRedirectPath={(modelRegistry) => modelRegistryUrl(modelRegistry)} /> } > @@ -33,6 +35,7 @@ const ModelRegistryRoutes: React.FC = () => ( path={ModelVersionsTab.DETAILS} element={} /> + } /> } /> ( } /> } /> + } /> } /> diff --git a/frontend/src/pages/modelRegistry/screens/InvalidModelRegistry.tsx b/frontend/src/pages/modelRegistry/screens/InvalidModelRegistry.tsx index 6042a50a37..8705b330db 100644 --- a/frontend/src/pages/modelRegistry/screens/InvalidModelRegistry.tsx +++ b/frontend/src/pages/modelRegistry/screens/InvalidModelRegistry.tsx @@ -1,6 +1,7 @@ import * as React from 'react'; import EmptyStateErrorMessage from '~/components/EmptyStateErrorMessage'; import ModelRegistrySelectorNavigator from './ModelRegistrySelectorNavigator'; +import { modelRegistryUrl } from './routeUtils'; type InvalidModelRegistryProps = { title?: string; @@ -15,7 +16,7 @@ const InvalidModelRegistry: React.FC = ({ title, mode } was not found.`} > `/modelRegistry/${modelRegistryName}`} + getRedirectPath={(modelRegistryName) => modelRegistryUrl(modelRegistryName)} primary /> diff --git a/frontend/src/pages/modelRegistry/screens/ModelRegistry.tsx b/frontend/src/pages/modelRegistry/screens/ModelRegistry.tsx index 016f834de5..7490594c8f 100644 --- a/frontend/src/pages/modelRegistry/screens/ModelRegistry.tsx +++ b/frontend/src/pages/modelRegistry/screens/ModelRegistry.tsx @@ -3,9 +3,10 @@ import ApplicationsPage from '~/pages/ApplicationsPage'; import useRegisteredModels from '~/concepts/modelRegistry/apiHooks/useRegisteredModels'; import TitleWithIcon from '~/concepts/design/TitleWithIcon'; import { ProjectObjectType } from '~/concepts/design/utils'; +import { filterLiveModels } from '~/concepts/modelRegistry/utils'; import RegisteredModelListView from './RegisteredModels/RegisteredModelListView'; import ModelRegistrySelectorNavigator from './ModelRegistrySelectorNavigator'; -import { filterLiveModels } from './utils'; +import { modelRegistryUrl } from './routeUtils'; type ModelRegistryProps = Omit< React.ComponentProps, @@ -28,7 +29,7 @@ const ModelRegistry: React.FC = ({ ...pageProps }) => { description="View and manage all of your registered models. Registering models to model registry allows you to manage their content, metadata, versions, and user access settings." headerContent={ `/modelRegistry/${modelRegistryName}`} + getRedirectPath={(modelRegistryName) => modelRegistryUrl(modelRegistryName)} /> } loadError={loadError} diff --git a/frontend/src/pages/modelRegistry/screens/ModelVersions/ModelVersionListView.tsx b/frontend/src/pages/modelRegistry/screens/ModelVersions/ModelVersionListView.tsx index 76bface23e..ce9d99c172 100644 --- a/frontend/src/pages/modelRegistry/screens/ModelVersions/ModelVersionListView.tsx +++ b/frontend/src/pages/modelRegistry/screens/ModelVersions/ModelVersionListView.tsx @@ -21,7 +21,10 @@ import SimpleSelect from '~/components/SimpleSelect'; import EmptyModelRegistryState from '~/pages/modelRegistry/screens/components/EmptyModelRegistryState'; import { filterModelVersions } from '~/pages/modelRegistry/screens/utils'; import { ModelRegistrySelectorContext } from '~/concepts/modelRegistry/context/ModelRegistrySelectorContext'; -import { modelVersionArchiveUrl } from '~/pages/modelRegistry/screens/routeUtils'; +import { + modelVersionArchiveUrl, + registerVersionForModelUrl, +} from '~/pages/modelRegistry/screens/routeUtils'; import { asEnumMember } from '~/utilities/utils'; import ModelVersionsTable from './ModelVersionsTable'; @@ -58,7 +61,7 @@ const ModelVersionListView: React.FC = ({ primaryActionText="Register new version" secondaryActionText="View archived versions" primaryActionOnClick={() => { - // TODO: Add primary action + navigate(registerVersionForModelUrl(rm?.id, preferredModelRegistry?.metadata.name)); }} secondaryActionOnClick={() => { navigate(modelVersionArchiveUrl(rm?.id, preferredModelRegistry?.metadata.name)); @@ -113,7 +116,14 @@ const ModelVersionListView: React.FC = ({ - + = ({ mrName }) => ( + + + +); + +export default PrefilledModelRegistryField; diff --git a/frontend/src/pages/modelRegistry/screens/RegisterModel/RegisterModel.tsx b/frontend/src/pages/modelRegistry/screens/RegisterModel/RegisterModel.tsx index 6d385c37aa..ea605d9517 100644 --- a/frontend/src/pages/modelRegistry/screens/RegisterModel/RegisterModel.tsx +++ b/frontend/src/pages/modelRegistry/screens/RegisterModel/RegisterModel.tsx @@ -1,21 +1,10 @@ import React from 'react'; import { - ActionGroup, - Alert, - AlertActionCloseButton, Breadcrumb, BreadcrumbItem, - Button, Form, FormGroup, - HelperText, - HelperTextItem, - InputGroupItem, - InputGroupText, PageSection, - Radio, - Split, - SplitItem, Stack, StackItem, TextArea, @@ -24,78 +13,34 @@ import { import spacing from '@patternfly/react-styles/css/utilities/Spacing/spacing'; import { useParams, useNavigate } from 'react-router'; import { Link } from 'react-router-dom'; -import { OptimizeIcon } from '@patternfly/react-icons'; import FormSection from '~/components/pf-overrides/FormSection'; import ApplicationsPage from '~/pages/ApplicationsPage'; -import { ModelRegistryContext } from '~/concepts/modelRegistry/context/ModelRegistryContext'; -import { useAppSelector } from '~/redux/hooks'; -import { DataConnection } from '~/pages/projects/types'; -import { convertAWSSecretData } from '~/pages/projects/screens/detail/data-connections/utils'; -import { - useRegisterModelData, - ModelLocationType, - RegisterVersionFormData, -} from './useRegisterModelData'; -import { registerModel } from './utils'; -import { ConnectionModal } from './ConnectionModal'; +import { modelRegistryUrl, registeredModelUrl } from '~/pages/modelRegistry/screens/routeUtils'; +import { ValueOf } from '~/typeHelpers'; +import { useRegisterModelData, RegistrationCommonFormData } from './useRegisterModelData'; +import { isRegisterModelSubmitDisabled, registerModel } from './utils'; +import RegistrationCommonFormSections from './RegistrationCommonFormSections'; +import { useRegistrationCommonState } from './useRegistrationCommonState'; +import PrefilledModelRegistryField from './PrefilledModelRegistryField'; +import RegistrationFormFooter from './RegistrationFormFooter'; const RegisterModel: React.FC = () => { const { modelRegistry: mrName } = useParams(); const navigate = useNavigate(); - const [formData, setData] = useRegisterModelData(); - const { - modelName, - modelDescription, - versionName, - versionDescription, - sourceModelFormat, - sourceModelFormatVersion, - modelLocationType, - modelLocationEndpoint, - modelLocationBucket, - modelLocationRegion, - modelLocationPath, - modelLocationURI, - } = formData; - const [loading, setIsLoading] = React.useState(false); - const [formError, setFormError] = React.useState(undefined); - const [isAutofillModalOpen, setAutofillModalOpen] = React.useState(false); - - const { apiState } = React.useContext(ModelRegistryContext); - const author = useAppSelector((state) => state.user || ''); - const isSubmitDisabled = - !modelName || - !versionName || - loading || - (modelLocationType === ModelLocationType.URI && !modelLocationURI) || - (modelLocationType === ModelLocationType.ObjectStorage && - (!modelLocationBucket || !modelLocationEndpoint || !modelLocationPath)); - const handleSubmit = () => { - setIsLoading(true); - setFormError(undefined); + const { isSubmitting, submitError, setSubmitError, handleSubmit, apiState, author } = + useRegistrationCommonState(); - registerModel(apiState, formData, author) - .then(({ registeredModel }) => { - navigate(`/modelRegistry/${mrName}/registeredModels/${registeredModel.id}`); - }) - .catch((e: Error) => { - setIsLoading(false); - setFormError(e); - }); - }; - - const connectionDataMap: Record = { - AWS_S3_ENDPOINT: 'modelLocationEndpoint', - AWS_S3_BUCKET: 'modelLocationBucket', - AWS_DEFAULT_REGION: 'modelLocationRegion', - }; + const [formData, setData] = useRegisterModelData(); + const isSubmitDisabled = isSubmitting || isRegisterModelSubmitDisabled(formData); + const { modelName, modelDescription } = formData; - const fillObjectStorageByConnection = (connection: DataConnection) => { - convertAWSSecretData(connection).forEach((dataItem) => { - setData(connectionDataMap[dataItem.key], dataItem.value); + const onSubmit = () => + handleSubmit(async () => { + const { registeredModel } = await registerModel(apiState, formData, author); + navigate(registeredModelUrl(registeredModel.id, mrName)); }); - }; + const onCancel = () => navigate(modelRegistryUrl(mrName)); return ( { breadcrumb={ Model registry - {mrName}} + render={() => Model registry - {mrName}} /> Register model @@ -115,22 +60,8 @@ const RegisterModel: React.FC = () => {
- - - - + + { /> - - - setData('versionName', value)} - /> - - -