Skip to content

Commit 1e736d1

Browse files
committed
feat: extract import source from AST, add scope tree tests
1 parent 1b149a6 commit 1e736d1

File tree

7 files changed

+803
-56
lines changed

7 files changed

+803
-56
lines changed

src/context/index.ts

Lines changed: 5 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -99,61 +99,15 @@ export const getEntitiesInRange = (
9999
}
100100

101101
/**
102-
* Parse import source from an import entity
102+
* Get import source from an import entity
103103
*
104-
* Extracts the source module path from import signatures like:
105-
* - `import { foo } from 'module'`
106-
* - `import foo from 'module'`
107-
* - `import * as foo from 'module'`
104+
* Uses the pre-extracted source from AST parsing (works for all languages).
108105
*
109106
* @param entity - The import entity
110107
* @returns The import source or empty string if not found
111108
*/
112-
const parseImportSource = (entity: ExtractedEntity): string => {
113-
// Try to extract from signature using regex
114-
// Common patterns: from 'source' or from "source"
115-
const fromMatch = entity.signature.match(/from\s+['"]([^'"]+)['"]/)
116-
if (fromMatch?.[1]) {
117-
return fromMatch[1]
118-
}
119-
120-
// For CommonJS style: require('source')
121-
const requireMatch = entity.signature.match(/require\s*\(\s*['"]([^'"]+)['"]/)
122-
if (requireMatch?.[1]) {
123-
return requireMatch[1]
124-
}
125-
126-
return ''
127-
}
128-
129-
/**
130-
* Check if an import is a default import
131-
*
132-
* @param entity - The import entity
133-
* @returns Whether this is a default import
134-
*/
135-
const isDefaultImport = (entity: ExtractedEntity): boolean => {
136-
// Default import patterns:
137-
// import foo from 'module'
138-
// But NOT: import { foo } from 'module'
139-
// And NOT: import * as foo from 'module'
140-
const signature = entity.signature
141-
return (
142-
/^import\s+\w+\s+from/.test(signature) &&
143-
!/^import\s*\{/.test(signature) &&
144-
!/^import\s*\*/.test(signature)
145-
)
146-
}
147-
148-
/**
149-
* Check if an import is a namespace import
150-
*
151-
* @param entity - The import entity
152-
* @returns Whether this is a namespace import
153-
*/
154-
const isNamespaceImport = (entity: ExtractedEntity): boolean => {
155-
// Namespace import pattern: import * as foo from 'module'
156-
return /^import\s*\*\s*as\s+\w+/.test(entity.signature)
109+
const getImportSource = (entity: ExtractedEntity): string => {
110+
return entity.source ?? ''
157111
}
158112

159113
/**
@@ -178,9 +132,7 @@ export const getRelevantImports = (
178132
// Map import entity to ImportInfo
179133
const mapToImportInfo = (entity: ExtractedEntity): ImportInfo => ({
180134
name: entity.name,
181-
source: parseImportSource(entity),
182-
isDefault: isDefaultImport(entity) || undefined,
183-
isNamespace: isNamespaceImport(entity) || undefined,
135+
source: getImportSource(entity),
184136
})
185137

186138
// If not filtering, return all imports

src/extract/fallback.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import type {
66
SyntaxNode,
77
} from '../types'
88
import { extractDocstring } from './docstring'
9-
import { extractName, extractSignature } from './signature'
9+
import { extractImportSource, extractName, extractSignature } from './signature'
1010

1111
/**
1212
* Node types that represent extractable entities by language
@@ -176,6 +176,12 @@ function walkAndExtract(
176176
// Extract docstring
177177
const docstring = yield* extractDocstring(node, language, code)
178178

179+
// Extract import source for import entities
180+
const source =
181+
entityType === 'import'
182+
? (extractImportSource(node, language) ?? undefined)
183+
: undefined
184+
179185
// Create entity
180186
const entity: ExtractedEntity = {
181187
type: entityType,
@@ -192,6 +198,7 @@ function walkAndExtract(
192198
},
193199
parent: parentName,
194200
node,
201+
source,
195202
}
196203

197204
entities.push(entity)

src/extract/index.ts

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import {
1212
getEntityType,
1313
} from './fallback'
1414
import { type CompiledQuery, loadQuery, loadQuerySync } from './queries'
15-
import { extractName, extractSignature } from './signature'
15+
import { extractImportSource, extractName, extractSignature } from './signature'
1616

1717
/**
1818
* Error when entity extraction fails
@@ -168,6 +168,12 @@ function matchesToEntities(
168168
// Find parent entity
169169
const parent = findParentEntityName(itemNode, rootNode, language)
170170

171+
// Extract import source for import entities
172+
const source =
173+
entityType === 'import'
174+
? (extractImportSource(itemNode, language) ?? undefined)
175+
: undefined
176+
171177
const entity: ExtractedEntity = {
172178
type: entityType,
173179
name,
@@ -183,6 +189,7 @@ function matchesToEntities(
183189
},
184190
parent,
185191
node: itemNode,
192+
source,
186193
}
187194

188195
entities.push(entity)
@@ -359,4 +366,4 @@ export {
359366
} from './fallback'
360367
export type { CompiledQuery, QueryLoadError } from './queries'
361368
export { clearQueryCache, loadQuery, loadQuerySync } from './queries'
362-
export { extractName, extractSignature } from './signature'
369+
export { extractImportSource, extractName, extractSignature } from './signature'

src/extract/signature.ts

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,3 +356,187 @@ export const extractSignature = (
356356
export const getBodyDelimiter = (language: Language): string => {
357357
return BODY_DELIMITERS[language]
358358
}
359+
360+
/**
 * Node types accepted by the generic last-resort scan in
 * `extractImportSource`, used when no language-specific rule produced a
 * result.
 */
const IMPORT_SOURCE_NODE_TYPES: readonly string[] = [
  'string',
  'string_literal',
  'interpreted_string_literal', // Go
  // NOTE(review): entries here are compared against `child.type`. The
  // `source` *field* is handled separately via childForFieldName('source');
  // this entry only matters for grammars with a node type named `source`.
  'source',
]

/**
 * Extract the import source path from an import AST node
 *
 * Resolution order:
 * 1. The node's `source` field, if present (quotes stripped).
 * 2. A language-specific rule:
 *    - JS/TS: first direct `string` child (the `from '...'` part)
 *    - Python: `module_name` field (`from X import Y`), else the `name` field
 *      (`import X`), unwrapping `import X as Y` to `X`, else the first
 *      `dotted_name` child
 *    - Rust: the `argument` field, else the first path-like child, reduced
 *      through `extractRustUsePath`
 *    - Go: the `path` of the first `import_spec` (single or grouped import)
 *    - Java: first `scoped_identifier` child
 * 3. The first child whose type is listed in `IMPORT_SOURCE_NODE_TYPES`.
 *
 * Only one source is returned; for declarations that name several modules
 * (e.g. a Go import block) this is the first one encountered.
 *
 * @param node - The import AST node
 * @param language - The programming language
 * @returns The import source path, or null if not found
 */
export const extractImportSource = (
  node: SyntaxNode,
  language: Language,
): string | null => {
  // Try the 'source' field first (common in many grammars)
  const sourceField = node.childForFieldName('source')
  if (sourceField) {
    return stripQuotes(sourceField.text)
  }

  // Language-specific extraction
  switch (language) {
    case 'typescript':
    case 'javascript': {
      // Look for string literal child (the 'from "..."' part)
      for (const child of node.children) {
        if (child.type === 'string') {
          return stripQuotes(child.text)
        }
      }
      break
    }

    case 'python': {
      // 'from X import Y' exposes X as the module_name field
      const moduleNameField = node.childForFieldName('module_name')
      if (moduleNameField) {
        return moduleNameField.text
      }
      // 'import X' / 'import X as Y': the name field is either a dotted_name
      // or an aliased_import wrapping one. Unwrap the alias so the result is
      // the module path, not 'X as Y'.
      const nameField = node.childForFieldName('name')
      if (nameField) {
        if (nameField.type === 'aliased_import') {
          return nameField.childForFieldName('name')?.text ?? nameField.text
        }
        return nameField.text
      }
      // Fallback: look for dotted_name
      for (const child of node.children) {
        if (child.type === 'dotted_name') {
          return child.text
        }
      }
      break
    }

    case 'rust': {
      // 'use path::to::item': the use clause is the 'argument' field
      const argumentField = node.childForFieldName('argument')
      if (argumentField) {
        return extractRustUsePath(argumentField)
      }
      // Fallback: first child that can carry a path
      for (const child of node.children) {
        if (
          child.type === 'scoped_identifier' ||
          child.type === 'identifier' ||
          child.type === 'use_wildcard' ||
          child.type === 'scoped_use_list' ||
          child.type === 'use_as_clause'
        ) {
          return extractRustUsePath(child)
        }
      }
      break
    }

    case 'go': {
      for (const child of node.children) {
        // Single import: import "fmt" -> has import_spec child
        if (child.type === 'import_spec') {
          const pathNode = child.childForFieldName('path')
          if (pathNode) {
            return stripQuotes(pathNode.text)
          }
          // Fallback: look for string literal in import_spec
          for (const specChild of child.children) {
            if (specChild.type === 'interpreted_string_literal') {
              return stripQuotes(specChild.text)
            }
          }
        }
        // Direct string literal (some Go grammars)
        if (child.type === 'interpreted_string_literal') {
          return stripQuotes(child.text)
        }
        // Import block: import ( "fmt" "os" ) -> first spec with a path
        if (child.type === 'import_spec_list') {
          for (const spec of child.children) {
            if (spec.type === 'import_spec') {
              const pathNode = spec.childForFieldName('path')
              if (pathNode) {
                return stripQuotes(pathNode.text)
              }
            }
          }
        }
      }
      break
    }

    case 'java': {
      // For 'import package.Class', look for scoped_identifier
      for (const child of node.children) {
        if (child.type === 'scoped_identifier') {
          return child.text
        }
      }
      break
    }
  }

  // Fallback: look for any string-like child
  for (const child of node.children) {
    if (IMPORT_SOURCE_NODE_TYPES.includes(child.type)) {
      return stripQuotes(child.text)
    }
  }

  return null
}
503+
504+
/**
 * Extract the path from a Rust use clause
 *
 * - `std::collections::HashMap` -> `std::collections::HashMap`
 * - `std::collections::{HashMap, HashSet}` -> `std::collections`
 * - `std::io::Result as IoResult` -> `std::io::Result`
 * - `{a, b}` (bare brace list) -> `` (there is no shared path)
 *
 * Any other node type is returned as its full text.
 *
 * @param node - The use clause node (e.g. the `argument` of a use declaration)
 * @returns The path text; empty string when a brace list has no path prefix
 */
const extractRustUsePath = (node: SyntaxNode): string => {
  // A bare use_list (e.g., {HashMap, HashSet}) carries no path of its own
  if (node.type === 'use_list') {
    return ''
  }

  // `path::{...}` is a scoped_use_list with `path` and `list` fields; the
  // source is the path prefix, without the brace list.
  if (node.type === 'scoped_use_list') {
    return node.childForFieldName('path')?.text ?? ''
  }

  // `path as alias`: the alias is a local binding, not part of the source
  if (node.type === 'use_as_clause') {
    return node.childForFieldName('path')?.text ?? node.text
  }

  // Retained for grammars that nest the brace list inside a scoped_identifier
  if (node.type === 'scoped_identifier') {
    const lastChild = node.children[node.children.length - 1]
    if (lastChild?.type === 'use_list') {
      const pathChild = node.childForFieldName('path')
      if (pathChild) {
        return pathChild.text
      }
    }
  }

  return node.text
}
529+
530+
/**
 * Strip one pair of matching surrounding quotes from a string
 *
 * Recognises double quotes, single quotes and backticks. The opening and
 * closing characters must be the same quote; otherwise the input is returned
 * unchanged. A string shorter than two characters cannot hold a quote pair
 * and is returned as-is (so a lone `"` is not collapsed to an empty string).
 *
 * @param str - Raw node text, possibly quoted
 * @returns The text without its surrounding quote pair
 */
const stripQuotes = (str: string): string => {
  // One character would satisfy both startsWith and endsWith for the same
  // quote, so require room for an opening and a closing character.
  if (str.length < 2) {
    return str
  }

  const opening = str.charAt(0)
  const isQuote = opening === '"' || opening === "'" || opening === '`'

  return isQuote && str.endsWith(opening) ? str.slice(1, -1) : str
}

src/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ export interface ExtractedEntity {
100100
parent: string | null
101101
/** The underlying AST node */
102102
node: SyntaxNode
103+
/** Import source path (only for import entities) */
104+
source?: string
103105
}
104106

105107
/**

0 commit comments

Comments
 (0)