Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public String handle(Resource resource, DispatchInfo dispatchInfo) {
DAGSettings dagSettings = DAGSettings.builder()
.ignoreExist(false)
.dagMaxDepth(bizDConfs.getFlowDAGMaxDepth()).build();

olympicene.submit(executionId, dag, data, dagSettings, notifyInfo);
dagResourceStatistic.updateFlowTypeResourceStatus(parentDAGExecutionId, parentTaskName, resource.getResourceName(), dag);
ProfileActions.recordTinyDAGSubmit(executionId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public String handle(Resource resource, DispatchInfo dispatchInfo) {
int maxInvokeTime = switcherManagerImpl.getSwitcherState("ENABLE_FUNCTION_DISPATCH_RET_CHECK") ? 2 : 1;
HttpMethod method = Optional.ofNullable(requestType).map(String::toUpperCase).map(HttpMethod::resolve).orElse(HttpMethod.POST);
HttpEntity<?> requestEntity = buildHttpEntity(method, header, requestParams);

String ret = httpInvokeHelper.invokeRequest(executionId, taskInfoName, url, requestEntity, method, maxInvokeTime);
dagResourceStatistic.updateUrlTypeResourceStatus(executionId, taskInfoName, resource.getResourceName(), ret);
return ret;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
/*
* Copyright 2021-2023 Weibo, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.weibo.rill.flow.service.dispatcher

import com.alibaba.fastjson.JSON
import com.weibo.rill.flow.interfaces.model.resource.Resource
import com.weibo.rill.flow.interfaces.model.strategy.DispatchInfo
import com.weibo.rill.flow.interfaces.model.task.FunctionPattern
import com.weibo.rill.flow.interfaces.model.task.FunctionTask
import com.weibo.rill.flow.interfaces.model.task.TaskInfo
import com.weibo.rill.flow.olympicene.core.model.dag.DAG
import com.weibo.rill.flow.olympicene.traversal.Olympicene
import com.weibo.rill.flow.service.dconfs.BizDConfs
import com.weibo.rill.flow.service.service.DAGDescriptorService
import com.weibo.rill.flow.service.statistic.DAGResourceStatistic
import spock.lang.Specification

class FlowProtocolDispatcherTest extends Specification {
FlowProtocolDispatcher dispatcher
DAGDescriptorService dagDescriptorService
BizDConfs bizDConfs
Olympicene olympicene
DAGResourceStatistic dagResourceStatistic

def setup() {
dispatcher = new FlowProtocolDispatcher()
dagDescriptorService = Mock(DAGDescriptorService)
bizDConfs = Mock(BizDConfs)
olympicene = Mock(Olympicene)
dagResourceStatistic = Mock(DAGResourceStatistic)

dispatcher.dagDescriptorService = dagDescriptorService
dispatcher.bizDConfs = bizDConfs
dispatcher.olympicene = olympicene
dispatcher.dagResourceStatistic = dagResourceStatistic
}

def "test handle method with valid input"() {
given:
def resource = Mock(Resource) {
getSchemeValue() >> "test-scheme"
getResourceName() >> "test-resource"
}
def taskInfo = Mock(TaskInfo) {
getName() >> "test-task"
getTask() >> Mock(FunctionTask) {
getPattern() >> FunctionPattern.FLOW_ASYNC
}
}
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> "parent-execution-id"
getTaskInfo() >> taskInfo
getInput() >> ["uid": "123", "key": "value"]
}
def dag = Mock(DAG)

when:
def result = dispatcher.handle(resource, dispatchInfo)

then:
1 * bizDConfs.getFlowDAGMaxDepth() >> 10
1 * dagDescriptorService.getDAG(123L, _, "test-scheme") >> dag
1 * olympicene.submit(_, dag, _, _, _)
1 * dagResourceStatistic.updateFlowTypeResourceStatus("parent-execution-id", "test-task", "test-resource", dag)

and:
def jsonResult = JSON.parseObject(result)
jsonResult.containsKey("execution_id")
jsonResult.get("execution_id") != null
}

def "test handle method with null input map"() {
given:
def resource = Mock(Resource) {
getSchemeValue() >> "test-scheme"
getResourceName() >> "test-resource"
}
def taskInfo = Mock(TaskInfo) {
getName() >> "test-task"
getTask() >> Mock(FunctionTask) {
getPattern() >> FunctionPattern.FLOW_ASYNC
}
}
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> "parent-execution-id"
getTaskInfo() >> taskInfo
getInput() >> null
}
def dag = Mock(DAG)

when:
def result = dispatcher.handle(resource, dispatchInfo)

then:
1 * bizDConfs.getFlowDAGMaxDepth() >> 10
1 * dagDescriptorService.getDAG(0L, _, "test-scheme") >> dag
1 * olympicene.submit(_, dag, _, _, _)
1 * dagResourceStatistic.updateFlowTypeResourceStatus("parent-execution-id", "test-task", "test-resource", dag)

and:
def jsonResult = JSON.parseObject(result)
jsonResult.containsKey("execution_id")
jsonResult.get("execution_id") != null
}

def "test handle method with invalid uid"() {
given:
def resource = Mock(Resource) {
getSchemeValue() >> "test-scheme"
getResourceName() >> "test-resource"
}
def taskInfo = Mock(TaskInfo) {
getName() >> "test-task"
getTask() >> Mock(FunctionTask) {
getPattern() >> FunctionPattern.FLOW_ASYNC
}
}
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> "parent-execution-id"
getTaskInfo() >> taskInfo
getInput() >> ["uid": null, "key": "value"]
}
def dag = Mock(DAG)

when:
def result = dispatcher.handle(resource, dispatchInfo)

then:
1 * bizDConfs.getFlowDAGMaxDepth() >> 10
1 * dagDescriptorService.getDAG(0L, _, "test-scheme") >> dag
1 * olympicene.submit(_, dag, _, _, _)
1 * dagResourceStatistic.updateFlowTypeResourceStatus("parent-execution-id", "test-task", "test-resource", dag)

and:
def jsonResult = JSON.parseObject(result)
jsonResult.containsKey("execution_id")
jsonResult.get("execution_id") != null
}

def "test handle method with non-numeric uid"() {
given:
def resource = Mock(Resource) {
getSchemeValue() >> "test-scheme"
getResourceName() >> "test-resource"
}
def taskInfo = Mock(TaskInfo) {
getName() >> "test-task"
getTask() >> Mock(FunctionTask) {
getPattern() >> FunctionPattern.FLOW_ASYNC
}
}
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> "parent-execution-id"
getTaskInfo() >> taskInfo
getInput() >> ["uid": "not-a-number", "key": "value"]
}
def dag = Mock(DAG)

when:
def result = dispatcher.handle(resource, dispatchInfo)

then:
thrown(NumberFormatException)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,169 @@

package com.weibo.rill.flow.service.dispatcher

import com.weibo.rill.flow.common.exception.TaskException
import com.weibo.rill.flow.interfaces.model.http.HttpParameter
import com.weibo.rill.flow.interfaces.model.resource.Resource
import com.weibo.rill.flow.interfaces.model.strategy.DispatchInfo
import com.weibo.rill.flow.interfaces.model.task.FunctionTask
import com.weibo.rill.flow.interfaces.model.task.TaskInfo
import com.weibo.rill.flow.olympicene.core.switcher.SwitcherManager
import com.weibo.rill.flow.service.invoke.HttpInvokeHelper
import com.weibo.rill.flow.service.statistic.DAGResourceStatistic
import org.springframework.http.HttpEntity
import org.springframework.http.HttpHeaders
import org.springframework.http.HttpMethod
import org.springframework.http.MediaType
import org.springframework.util.LinkedMultiValueMap
import org.springframework.util.MultiValueMap
import org.springframework.web.client.RestClientResponseException
import spock.lang.Specification
import spock.lang.Subject

class FunctionProtocolDispatcherTest extends Specification {
FunctionProtocolDispatcher dispatcher = new FunctionProtocolDispatcher();
@Subject
FunctionProtocolDispatcher dispatcher

def "buildHttpEntity test"() {
HttpInvokeHelper httpInvokeHelper
DAGResourceStatistic dagResourceStatistic
SwitcherManager switcherManager

def setup() {
httpInvokeHelper = Mock(HttpInvokeHelper)
dagResourceStatistic = Mock(DAGResourceStatistic)
switcherManager = Mock(SwitcherManager)
dispatcher = new FunctionProtocolDispatcher(
httpInvokeHelper: httpInvokeHelper,
dagResourceStatistic: dagResourceStatistic,
switcherManagerImpl: switcherManager
)
}

def "should handle POST request successfully"() {
given:
def executionId = "exec-123"
def taskName = "testTask"
def resource = Mock(Resource)
def input = [key: "value"]
def taskInfo = new TaskInfo(name: taskName, task: new FunctionTask(taskName, null, null, "function", null, false, null, null, null, null, null, null, null, null, null, null, null, null, "POST", false, null, null, null, null, null, null))
def headers = new LinkedMultiValueMap<String, String>()
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> executionId
getInput() >> input
getTaskInfo() >> taskInfo
getHeaders() >> headers
}
def requestParams = Mock(HttpParameter) {
getHeader() >> [contentType: MediaType.APPLICATION_JSON_VALUE]
}
def url = "http://test.com/api"
def expectedResponse = '{"status": "success"}'

when:
def result = dispatcher.handle(resource, dispatchInfo)

then:
1 * switcherManager.getSwitcherState("ENABLE_FUNCTION_DISPATCH_RET_CHECK") >> false
1 * httpInvokeHelper.functionRequestParams(executionId, taskName, resource, input) >> requestParams
1 * httpInvokeHelper.buildUrl(resource, requestParams.queryParams) >> url
1 * httpInvokeHelper.invokeRequest(executionId, taskName, url, _ as HttpEntity, HttpMethod.POST, 1) >> expectedResponse
1 * dagResourceStatistic.updateUrlTypeResourceStatus(executionId, taskName, _, expectedResponse)
result == expectedResponse
}

def "should handle GET request successfully"() {
given:
def httpParameter = HttpParameter.builder()
.header(inputHeader)
.body(inputBody)
.build()
MultiValueMap<String, String> header = new LinkedMultiValueMap<>()
Optional.ofNullable(httpParameter.getHeader())
.ifPresent { it -> it.forEach { key, value -> header.add(key, value) } }
def executionId = "exec-123"
def taskName = "testTask"
def resource = Mock(Resource)
def input = [key: "value"]
def taskInfo = new TaskInfo(name: taskName, task: new FunctionTask(taskName, null, null, "function", null, false, null, null, null, null, null, null, null, null, null, null, null, null, "GET", false, null, null, null, null, null, null))
def headers = new LinkedMultiValueMap<String, String>()
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> executionId
getInput() >> input
getTaskInfo() >> taskInfo
getHeaders() >> headers
}
def requestParams = Mock(HttpParameter)
def url = "http://test.com/api"
def expectedResponse = '{"status": "success"}'

when:
def httpEntity = dispatcher.buildHttpEntity(method, header, httpParameter)
def result = dispatcher.handle(resource, dispatchInfo)

then:
httpEntity.body == body

where:
method | inputHeader | inputBody | body
null | [:] | [:] | null
HttpMethod.GET | [:] | [:] | null
HttpMethod.POST | [:] | [:] | [:]
HttpMethod.POST | [:] | [k: "v", user: [name: "Bob"]] | [k: "v", user: [name: "Bob"]]
HttpMethod.POST | ["Content-Type": MediaType.APPLICATION_JSON_VALUE] | [k: "v", user: [name: "Bob"]] | [k: "v", user: [name: "Bob"]]
HttpMethod.POST | ["Content-Type": MediaType.APPLICATION_FORM_URLENCODED_VALUE] | [k: "v", name: "Bob"] | [k: ["v"], name: ["Bob"]]
1 * switcherManager.getSwitcherState("ENABLE_FUNCTION_DISPATCH_RET_CHECK") >> false
1 * httpInvokeHelper.functionRequestParams(executionId, taskName, resource, input) >> requestParams
1 * httpInvokeHelper.buildUrl(resource, requestParams.queryParams) >> url
1 * httpInvokeHelper.invokeRequest(executionId, taskName, url, _ as HttpEntity, HttpMethod.GET, 1) >> expectedResponse
1 * dagResourceStatistic.updateUrlTypeResourceStatus(executionId, taskName, _, expectedResponse)
result == expectedResponse
}

def "should handle error response correctly"() {
given:
def executionId = "exec-123"
def taskName = "testTask"
def resource = Mock(Resource)
def input = [key: "value"]
def taskInfo = new TaskInfo(name: taskName, task: new FunctionTask(taskName, null, null, "function", null, false, null, null, null, null, null, null, null, null, null, null, null, null, "POST", false, null, null, null, null, null, null))
def headers = new LinkedMultiValueMap<String, String>()
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> executionId
getInput() >> input
getTaskInfo() >> taskInfo
getHeaders() >> headers
}
def requestParams = Mock(HttpParameter)
def url = "http://test.com/api"
def errorResponse = "Error occurred"
def exception = Mock(RestClientResponseException) {
getResponseBodyAsString() >> errorResponse
getRawStatusCode() >> 500
}

when:
dispatcher.handle(resource, dispatchInfo)

then:
1 * switcherManager.getSwitcherState("ENABLE_FUNCTION_DISPATCH_RET_CHECK") >> false
1 * httpInvokeHelper.functionRequestParams(executionId, taskName, resource, input) >> requestParams
1 * httpInvokeHelper.buildUrl(resource, requestParams.queryParams) >> url
1 * httpInvokeHelper.invokeRequest(executionId, taskName, url, _ as HttpEntity, HttpMethod.POST, 1) >> { throw exception }
1 * dagResourceStatistic.updateUrlTypeResourceStatus(executionId, taskName, _, errorResponse)
thrown(TaskException)
}

def "should handle form-urlencoded POST request"() {
given:
def executionId = "exec-123"
def taskName = "testTask"
def resource = Mock(Resource)
def input = [key: "value"]
def taskInfo = new TaskInfo(name: taskName, task: new FunctionTask(taskName, null, null, "function", null, false, null, null, null, null, null, null, null, null, null, null, null, null, "POST", false, null, null, null, null, null, null))
def headers = new LinkedMultiValueMap<String, String>()
headers.add(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE)
def dispatchInfo = Mock(DispatchInfo) {
getExecutionId() >> executionId
getInput() >> input
getTaskInfo() >> taskInfo
getHeaders() >> headers
}
def requestParams = Mock(HttpParameter) {
getBody() >> [stringParam: "test", mapParam: [key: "value"], listParam: ["item1"]]
}
def url = "http://test.com/api"
def expectedResponse = "success"

when:
def result = dispatcher.handle(resource, dispatchInfo)

then:
1 * switcherManager.getSwitcherState("ENABLE_FUNCTION_DISPATCH_RET_CHECK") >> false
1 * httpInvokeHelper.functionRequestParams(executionId, taskName, resource, input) >> requestParams
1 * httpInvokeHelper.buildUrl(resource, requestParams.queryParams) >> url
1 * httpInvokeHelper.invokeRequest(executionId, taskName, url, _ as HttpEntity, HttpMethod.POST, 1) >> expectedResponse
1 * dagResourceStatistic.updateUrlTypeResourceStatus(executionId, taskName, _, expectedResponse)
result == expectedResponse
}
}
Loading