feat: refactor add command

This commit is contained in:
shadcn
2026-02-16 14:15:24 +04:00
parent d8e5d0d4f1
commit 028b1b2d93
8 changed files with 32 additions and 43 deletions

View File

@@ -1,14 +1,12 @@
import path from "path"
import {
getTemplateFromFrameworkName,
runInit,
} from "@/src/commands/init"
import { runInit } from "@/src/commands/init"
import { preFlightAdd } from "@/src/preflights/preflight-add"
import { getRegistryItems, getShadcnRegistryIndex } from "@/src/registry/api"
import { DEPRECATED_COMPONENTS } from "@/src/registry/constants"
import { clearRegistryContext } from "@/src/registry/context"
import { registryItemTypeSchema } from "@/src/registry/schema"
import { isUniversalRegistryItem } from "@/src/registry/utils"
import { getTemplateForFramework } from "@/src/templates/index"
import { addComponents } from "@/src/utils/add-components"
import { createProject } from "@/src/utils/create-project"
import { loadEnvFiles } from "@/src/utils/env-loader"
@@ -16,12 +14,9 @@ import * as ERRORS from "@/src/utils/errors"
import { createConfig, getConfig } from "@/src/utils/get-config"
import { getProjectInfo } from "@/src/utils/get-project-info"
import { handleError } from "@/src/utils/handle-error"
import {
promptForPreset,
resolveRegistryBaseConfig,
} from "@/src/utils/presets"
import { highlighter } from "@/src/utils/highlighter"
import { logger } from "@/src/utils/logger"
import { promptForPreset, resolveRegistryBaseConfig } from "@/src/utils/presets"
import { ensureRegistriesInConfig } from "@/src/utils/registries"
import { updateAppIndex } from "@/src/utils/update-app-index"
import { Command } from "commander"
@@ -167,7 +162,7 @@ export const add = new Command()
}
// Infer template from project framework.
const inferredTemplate = getTemplateFromFrameworkName(
const inferredTemplate = getTemplateForFramework(
projectInfo?.framework.name
)

View File

@@ -1,15 +1,12 @@
import { promises as fs } from "fs"
import path from "path"
import { preFlightInit } from "@/src/preflights/preflight-init"
import {
getRegistryBaseColors,
getRegistryStyles,
} from "@/src/registry/api"
import { getRegistryBaseColors, getRegistryStyles } from "@/src/registry/api"
import { BUILTIN_REGISTRIES } from "@/src/registry/constants"
import { clearRegistryContext } from "@/src/registry/context"
import { isUrl } from "@/src/registry/utils"
import { rawConfigSchema } from "@/src/schema"
import { templates } from "@/src/templates/index"
import { getTemplateForFramework, templates } from "@/src/templates/index"
import { addComponents } from "@/src/utils/add-components"
import { createProject } from "@/src/utils/create-project"
import { loadEnvFiles } from "@/src/utils/env-loader"
@@ -25,7 +22,6 @@ import {
DEFAULT_TAILWIND_CONFIG,
DEFAULT_TAILWIND_CSS,
DEFAULT_UTILS,
createConfig,
getConfig,
resolveConfigPaths,
type Config,
@@ -193,7 +189,7 @@ export const init = new Command()
// Try to infer template for existing projects.
if (!options.template && hasPackageJson) {
const projectInfo = await getProjectInfo(cwd)
const detectedTemplate = getTemplateFromFrameworkName(
const detectedTemplate = getTemplateForFramework(
projectInfo?.framework.name
)
if (detectedTemplate) {
@@ -619,19 +615,3 @@ async function promptForMinimalConfig(
aliases: defaultConfig?.aliases,
})
}
export function getTemplateFromFrameworkName(frameworkName?: string) {
if (frameworkName === "next-app" || frameworkName === "next-pages") {
return "next"
}
if (frameworkName === "vite") {
return "vite"
}
if (frameworkName === "tanstack-start" || frameworkName === "react-router") {
return "start"
}
return undefined
}

View File

@@ -20,6 +20,8 @@ export function createTemplate(config: {
name: string
title: string
defaultProjectName: string
// Framework names that map to this template.
frameworks?: string[]
scaffold: (options: TemplateOptions) => Promise<void>
create: (options: TemplateOptions) => Promise<void>
init?: (options: TemplateInitOptions) => Promise<Config>
@@ -28,6 +30,7 @@ export function createTemplate(config: {
}) {
return {
...config,
frameworks: config.frameworks ?? [],
postInit: config.postInit ?? defaultPostInit,
}
}
@@ -38,7 +41,7 @@ async function defaultPostInit({ projectPath }: { projectPath: string }) {
try {
await execa("git", ["init"], { cwd: projectPath })
await execa("git", ["add", "-A"], { cwd: projectPath })
await execa("git", ["commit", "-m", "Initial commit"], {
await execa("git", ["commit", "-m", "feat: initial commit"], {
cwd: projectPath,
})
} catch {}

View File

@@ -12,3 +12,18 @@ export const templates = {
start,
"next-monorepo": nextMonorepo,
}
// Resolve a template key from a detected framework name.
export function getTemplateForFramework(frameworkName?: string) {
if (!frameworkName) {
return undefined
}
for (const [key, template] of Object.entries(templates)) {
if (template.frameworks.includes(frameworkName)) {
return key
}
}
return undefined
}

View File

@@ -12,6 +12,7 @@ export const next = createTemplate({
name: "next",
title: "Next.js",
defaultProjectName: "next-app",
frameworks: ["next-app", "next-pages"],
scaffold: async ({ projectPath, packageManager }) => {
const createSpinner = spinner(
`Creating a new Next.js project. This may take a few minutes.`

View File

@@ -12,6 +12,7 @@ export const start = createTemplate({
name: "start",
title: "TanStack Start",
defaultProjectName: "start-app",
frameworks: ["tanstack-start"],
scaffold: async ({ projectPath, packageManager }) => {
const createSpinner = spinner(
`Creating a new TanStack Start project. This may take a few minutes.`

View File

@@ -12,6 +12,7 @@ export const vite = createTemplate({
name: "vite",
title: "Vite",
defaultProjectName: "vite-app",
frameworks: ["vite"],
scaffold: async ({ projectPath, packageManager }) => {
const createSpinner = spinner(
`Creating a new Vite project. This may take a few minutes.`

View File

@@ -124,9 +124,7 @@ export async function promptForPreset(options: {
...(options.template && { template: options.template }),
})
logger.break()
logger.log(
` Build your custom preset on ${highlighter.info(createUrl)}`
)
logger.log(` Build your custom preset on ${highlighter.info(createUrl)}`)
logger.log(
` Then ${highlighter.info(
"copy and run the command"
@@ -156,10 +154,7 @@ export async function promptForPreset(options: {
return resolveInitUrl({ ...preset, rtl: options.rtl })
}
export async function resolveRegistryBaseConfig(
initUrl: string,
cwd: string
) {
export async function resolveRegistryBaseConfig(initUrl: string, cwd: string) {
// Use a shadow config to fetch the registry:base item.
let shadowConfig = configWithDefaults(
createConfig({
@@ -189,9 +184,7 @@ export async function resolveRegistryBaseConfig(
return {
registryBaseConfig:
item?.type === "registry:base" && item.config
? item.config
: undefined,
item?.type === "registry:base" && item.config ? item.config : undefined,
installStyleIndex: item?.extends !== "none",
}
}