Skip to content

Commit 1e5aac4

Browse files
mcastorina0x1
andauthored
Add Scan method to SourceManager to scan a single SourceUnit (#3650)
* renaming to enumeration * update enumeration * comments * remove commented out func * Add Scan method to SourceManager to scan a single SourceUnit * Add tests for each Enumerate and Scan * add source name to log * rename scanWithUnits * updating comments to be more clear --------- Co-authored-by: ahmed <[email protected]> Co-authored-by: 0x1 <[email protected]>
1 parent 1276d26 commit 1e5aac4

File tree

2 files changed

+170
-2
lines changed

2 files changed

+170
-2
lines changed

pkg/sources/source_manager.go

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,50 @@ func (s *SourceManager) Enumerate(ctx context.Context, sourceName string, source
212212
case s.firstErr <- err:
213213
default:
214214
}
215+
progress.ReportError(Fatal{err})
216+
}
217+
}()
218+
return progress.Ref(), nil
219+
}
220+
221+
// Scan blocks until a resource is available to run the source against a single
222+
// SourceUnit, then asynchronously runs it. Error information is stored and
223+
// accessible via the JobProgressRef as it becomes available.
224+
func (s *SourceManager) Scan(ctx context.Context, sourceName string, source Source, unit SourceUnit) (JobProgressRef, error) {
225+
sourceID, jobID := source.SourceID(), source.JobID()
226+
// Do preflight checks before waiting on the pool.
227+
if err := s.preflightChecks(ctx); err != nil {
228+
return JobProgressRef{
229+
SourceName: sourceName,
230+
SourceID: sourceID,
231+
JobID: jobID,
232+
}, err
233+
}
234+
// Create a JobProgress object for tracking progress.
235+
ctx, cancel := context.WithCancelCause(ctx)
236+
progress := NewJobProgress(jobID, sourceID, sourceName, WithHooks(s.hooks...), WithCancel(cancel))
237+
if err := s.sem.Acquire(ctx, 1); err != nil {
238+
// Context cancelled.
239+
progress.ReportError(Fatal{err})
240+
return progress.Ref(), Fatal{err}
241+
}
242+
s.wg.Add(1)
243+
go func() {
244+
// Call Finish after the semaphore has been released.
245+
defer progress.Finish()
246+
defer s.sem.Release(1)
247+
defer s.wg.Done()
248+
ctx := context.WithValues(ctx,
249+
"source_manager_worker_id", common.RandomID(5),
250+
)
251+
defer common.Recover(ctx)
252+
defer cancel(nil)
253+
if err := s.scan(ctx, source, progress, unit); err != nil {
254+
select {
255+
case s.firstErr <- err:
256+
default:
257+
}
258+
progress.ReportError(Fatal{err})
215259
}
216260
}()
217261
return progress.Ref(), nil
@@ -320,7 +364,7 @@ func (s *SourceManager) run(ctx context.Context, source Source, report *JobProgr
320364
ctx = context.WithValue(ctx, "source_type", source.Type().String())
321365
}
322366

323-
// Check for the preferred method of tracking source units.
367+
// Check if source units are supported and configured.
324368
canUseSourceUnits := len(targets) == 0 && s.useSourceUnitsFunc != nil
325369
if enumChunker, ok := source.(SourceUnitEnumChunker); ok && canUseSourceUnits && s.useSourceUnitsFunc() {
326370
ctx.Logger().Info("running source",
@@ -359,7 +403,7 @@ func (s *SourceManager) enumerate(ctx context.Context, source Source, report *Jo
359403
ctx = context.WithValue(ctx, "source_type", source.Type().String())
360404
}
361405

362-
// Check for the preferred method of tracking source units.
406+
// Check if source units are supported and configured.
363407
canUseSourceUnits := s.useSourceUnitsFunc != nil
364408
if enumChunker, ok := source.(SourceUnitEnumerator); ok && canUseSourceUnits && s.useSourceUnitsFunc() {
365409
ctx.Logger().Info("running source",
@@ -369,6 +413,42 @@ func (s *SourceManager) enumerate(ctx context.Context, source Source, report *Jo
369413
return fmt.Errorf("Enumeration not supported or configured for source: %s", source.Type().String())
370414
}
371415

416+
// scan runs a scan against a single SourceUnit as its only job. This method
417+
// manages the lifecycle of the provided report.
418+
func (s *SourceManager) scan(ctx context.Context, source Source, report *JobProgress, unit SourceUnit) error {
419+
report.Start(time.Now())
420+
defer func() { report.End(time.Now()) }()
421+
422+
defer func() {
423+
if err := context.Cause(ctx); err != nil {
424+
report.ReportError(Fatal{err})
425+
}
426+
}()
427+
428+
report.TrackProgress(source.GetProgress())
429+
if ctx.Value("job_id") == "" {
430+
ctx = context.WithValue(ctx, "job_id", report.JobID)
431+
}
432+
if ctx.Value("source_id") == "" {
433+
ctx = context.WithValue(ctx, "source_id", report.SourceID)
434+
}
435+
if ctx.Value("source_name") == "" {
436+
ctx = context.WithValue(ctx, "source_name", report.SourceName)
437+
}
438+
if ctx.Value("source_type") == "" {
439+
ctx = context.WithValue(ctx, "source_type", source.Type().String())
440+
}
441+
442+
// Check if source units are supported and configured.
443+
canUseSourceUnits := s.useSourceUnitsFunc != nil
444+
if unitChunker, ok := source.(SourceUnitChunker); ok && canUseSourceUnits && s.useSourceUnitsFunc() {
445+
ctx.Logger().Info("running source",
446+
"with_units", true)
447+
return s.scanWithUnit(ctx, unitChunker, report, unit)
448+
}
449+
return fmt.Errorf("source units not supported or configured for source: %s (%s)", report.SourceName, source.Type().String())
450+
}
451+
372452
// enumerateWithUnits is a helper method to enumerate a Source that is also a
373453
// SourceUnitEnumerator. This allows better introspection of what is getting
374454
// enumerated and any errors encountered.
@@ -511,6 +591,44 @@ func (s *SourceManager) runWithUnits(ctx context.Context, source SourceUnitEnumC
511591
}
512592
}
513593

594+
// scanWithUnit produces chunks from a single SourceUnit.
595+
func (s *SourceManager) scanWithUnit(ctx context.Context, source SourceUnitChunker, report *JobProgress, unit SourceUnit) error {
596+
// Create a function that will save the first error encountered (if
597+
// any) and discard the rest.
598+
chunkReporter := &mgrChunkReporter{
599+
unit: unit,
600+
chunkCh: make(chan *Chunk, defaultChannelSize),
601+
report: report,
602+
}
603+
// Produce chunks from the given unit.
604+
var chunkErr error
605+
go func() {
606+
report.StartUnitChunking(unit, time.Now())
607+
// TODO: Catch panics and add to report.
608+
defer close(chunkReporter.chunkCh)
609+
id, kind := unit.SourceUnitID()
610+
ctx := context.WithValues(ctx, "unit_kind", kind, "unit", id)
611+
ctx.Logger().V(3).Info("chunking unit")
612+
if err := source.ChunkUnit(ctx, unit, chunkReporter); err != nil {
613+
report.ReportError(Fatal{ChunkError{Unit: unit, Err: err}})
614+
chunkErr = Fatal{err}
615+
}
616+
}()
617+
// Consume chunks and export chunks.
618+
// This anonymous function blocks until the chunkReporter.chunkCh is
619+
// closed in the above goroutine.
620+
func() {
621+
defer func() { report.EndUnitChunking(unit, time.Now()) }()
622+
for chunk := range chunkReporter.chunkCh {
623+
if src, ok := source.(Source); ok {
624+
chunk.JobID = src.JobID()
625+
}
626+
s.outputChunks <- chunk
627+
}
628+
}()
629+
return chunkErr
630+
}
631+
514632
// headlessAPI implements the apiClient interface locally.
515633
type headlessAPI struct {
516634
// Counters for assigning source and job IDs.

pkg/sources/source_manager_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,56 @@ func TestSourceManagerReport(t *testing.T) {
173173
}
174174
}
175175

176+
func TestSourceManagerEnumerate(t *testing.T) {
177+
mgr := NewManager(WithBufferedOutput(8), WithSourceUnits())
178+
source, err := buildDummy(&counterChunker{count: 1})
179+
assert.NoError(t, err)
180+
var enumeratedUnits []SourceUnit
181+
reporter := visitorUnitReporter{
182+
ok: func(_ context.Context, unit SourceUnit) error {
183+
enumeratedUnits = append(enumeratedUnits, unit)
184+
return nil
185+
},
186+
}
187+
for i := 0; i < 3; i++ {
188+
ref, err := mgr.Enumerate(context.Background(), "dummy", source, reporter)
189+
<-ref.Done()
190+
assert.NoError(t, err)
191+
assert.NoError(t, ref.Snapshot().FatalError())
192+
// The Chunks channel should be empty because we only enumerated.
193+
_, err = tryRead(mgr.Chunks())
194+
assert.Error(t, err)
195+
// Each time the loop iterates, we add 1 unit to the slice.
196+
assert.Equal(t, i+1, len(enumeratedUnits), ref.Snapshot())
197+
}
198+
}
199+
200+
func TestSourceManagerScan(t *testing.T) {
201+
mgr := NewManager(WithBufferedOutput(8), WithSourceUnits())
202+
source, err := buildDummy(&counterChunker{count: 1})
203+
assert.NoError(t, err)
204+
for i := 0; i < 3; i++ {
205+
ref, err := mgr.Scan(context.Background(), "dummy", source, countChunk(123))
206+
<-ref.Done()
207+
assert.NoError(t, err)
208+
assert.NoError(t, ref.Snapshot().FatalError())
209+
chunk, err := tryRead(mgr.Chunks())
210+
assert.NoError(t, err)
211+
assert.Equal(t, []byte{123}, chunk.Data)
212+
// The Chunks channel should be empty now.
213+
_, err = tryRead(mgr.Chunks())
214+
assert.Error(t, err)
215+
}
216+
}
217+
218+
type visitorUnitReporter struct {
219+
ok func(context.Context, SourceUnit) error
220+
err func(context.Context, error) error
221+
}
222+
223+
func (v visitorUnitReporter) UnitOk(ctx context.Context, u SourceUnit) error { return v.ok(ctx, u) }
224+
func (v visitorUnitReporter) UnitErr(ctx context.Context, err error) error { return v.err(ctx, err) }
225+
176226
type unitChunk struct {
177227
unit string
178228
output string

0 commit comments

Comments
 (0)