UDF中调用rpcx

161 阅读2分钟

在udf中使用rpcx调用rpc服务

在UDF内部构造rpcx客户端,并通过UDF参数传入服务地址(ip:端口)、服务名称和方法名称,就可以在Hive中调用RPC服务了。(注意:UDF会对每一行数据发起一次同步RPC调用,使用前请确认是否适合你的分析场景。)同理,这个模式也可以用来写一个支持HTTP调用的UDF。

  • 服务双方使用rpcx通讯,避开grpc的沉重语法,但是java客户端较旧。
  • 返回是 Map<String , String>,获取值比较方便

使用

select rpcx_get('ip:端口','服务名称','方法名称', id,100);
{"1":"dac","2":"fjdsoi"}

select rpcx_get('ip:端口','服务名称','方法名称', id,100)['1'];
"dac"

代码

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.Text;
import org.json.JSONObject;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import com.colobu.rpcx.client.*;
import com.colobu.rpcx.protocol.*;

public abstract class BaseRpcxUDF extends GenericUDF {
    private transient Client client;
    private int seq = 0 ;

    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        // 使用MapObjectInspector来返回Map类型
        // return ObjectInspectorFactory.getReflectionObjectInspector(Map.class, ObjectInspectorFactory.ObjectInspectorOptions.JAVA);
        return ObjectInspectorFactory.getStandardMapObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, // Key 是 String
                PrimitiveObjectInspectorFactory.javaStringObjectInspector // Value 是 String
        );
    }

    public abstract String before(DeferredObject[] arguments) throws HiveException;

    public abstract String after(DeferredObject[] arguments, String result) throws HiveException;

    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        if (arguments[0] == null || arguments[1] == null || arguments[2] == null) {
            return null;
        }
        String host = arguments[0].get().toString();
        String srv = arguments[1].get().toString();
        int port = 0;
        // int port = getJavaInt(arguments[2]);
        String method = arguments[2].get().toString();
        String beforeResult = before(arguments);

        if (beforeResult == null) return new Text("err");
        if (client == null ){
            String[] parts = host.split(":", 2);
            if (parts.length == 2) {
                host = parts[0];
                port = Integer.parseInt(parts[1]);
            } else {
                System.out.println("invalid rpc port");
            }
            try {
                client = new Client();
                client.connect(host, port);
                System.out.println("new client");

            }catch (IOException e) {
                throw new HiveException("连接失败: " + e.getMessage());
            }
        }

        // 构建请求消息
        Message req = new Message(srv,method );
        req.setVersion((byte)0);
        req.setMessageType(MessageType.Request);
        req.setHeartbeat(false);
        req.setOneway(false);
        req.setCompressType(CompressType.None);
        req.setSerializeType(SerializeType.JSON); // 设置序列化类型
        req.setSeq(this.seq);
        this.seq  = this.seq + 1;

        String responseString = "";
        try {
            req.payload = beforeResult.getBytes("UTF-8"); // 设置请求的payload
            // 调用远程方法
            Message res = client.call(req);

            // 处理响应
            if (res != null) {
                responseString = new String(res.payload);
                // System.out.println("Response: " + responseString);
            } else {
                //System.out.println("No response received");
                responseString = "No response received";
            }

        } catch (IOException e) {
            throw new HiveException("请求失败: " + e.getMessage());
        } finally {
            String afterResult = after(arguments, responseString);
            JSONObject jsonResponse = new JSONObject(afterResult);
            JSONObject dataObject = jsonResponse.optJSONObject("data");
            Map<String, String> resultMap = convertJsonToMap(dataObject);
            return resultMap;
        }
    }

    @Override
    public String getDisplayString(String[] children) {
        return "rpcx_request(" + String.join(", ", children) + ")";
    }

    private Map<String, String> convertJsonToMap(JSONObject jsonObject) {
        if (jsonObject == null) {
            return null;
        }
        // System.out.println(jsonObject);

        Map<String, String> map = new HashMap<>();
        for (String key : jsonObject.keySet()) {
            Object value = jsonObject.get(key);
            // 只处理String类型的值,如果需要处理其他类型,可以在这里添加逻辑
            if (value instanceof String) {
                map.put(key, (String) value);
            }
        }
        return map;
    }
}