소스 검색

add tars local auth process

shevqko 4 년 전
부모
커밋
e676634b0d

+ 24 - 0
servant/libservant/AdapterProxy.cpp

@@ -492,6 +492,30 @@ void AdapterProxy::finishInvoke(ResponsePacket & rsp)
     TLOGINFO("[TARS][AdapterProxy::finishInvoke(ResponsePacket) objname:" << _objectProxy->name() << ",desc:" << _endpoint.desc() 
         << ",id:" << rsp.iRequestId << endl);
 
+    if (_trans->getAuthState() != AUTH_SUCC)
+    {
+        std::string ret(rsp.sBuffer.begin(), rsp.sBuffer.end());
+        tars::AUTH_STATE tmp = AUTH_SUCC;
+        tars::stoe(ret, tmp);
+        int newstate = tmp;
+
+        TLOGINFO("[TARS]AdapterProxy::finishInvoke from state " << _trans->getAuthState() << " to " << newstate << endl);
+        _trans->setAuthState(newstate);
+
+        if (newstate == AUTH_SUCC)
+        {
+            // flush old buffered msg when auth is not complete
+            doInvoke();
+        }
+        else
+        {
+            TLOGERROR("newstate is " << newstate << ", error close!\n");
+            _trans->close();
+        }
+
+        return;
+    }
+
     ReqMessage * msg = NULL;
 
     //requestid 为0 是push消息

+ 12 - 1
servant/libservant/Application.cpp

@@ -24,6 +24,7 @@
 #include "servant/BaseF.h"
 #include "servant/AppCache.h"
 #include "servant/NotifyObserver.h"
+#include "servant/AuthLogic.h"
 
 #include <signal.h>
 #include <sys/resource.h>
@@ -1128,7 +1129,17 @@ void Application::bindAdapter(vector<TC_EpollServer::BindAdapterPtr>& adapters)
             ServantHelperManager::getInstance()->setAdapterServant(adapterName[i], servant);
 
             TC_EpollServer::BindAdapterPtr bindAdapter = new TC_EpollServer::BindAdapter(_epollServer.get());
-               
+
+            // 设置该obj的鉴权账号密码,只要一组就够了
+            {
+                std::string accKey = _conf.get("/tars/application/server/" + adapterName[i] + "<accesskey>");
+                std::string secretKey = _conf.get("/tars/application/server/" + adapterName[i] + "<secretkey>");
+
+                if (!accKey.empty())
+                    bindAdapter->setAkSk(accKey, secretKey);
+
+                bindAdapter->setAuthProcessWrapper(&tars::processAuth);
+            }
 
             string sLastPath = "/tars/application/server/" + adapterName[i];
 

+ 255 - 0
servant/libservant/AuthLogic.cpp

@@ -0,0 +1,255 @@
+/**
+ * Tencent is pleased to support the open source community by making Tars available.
+ *
+ * Copyright (C) 2016THL A29 Limited, a Tencent company. All rights reserved.
+ *
+ * Licensed under the BSD 3-Clause License (the "License"); you may not use this file except 
+ * in compliance with the License. You may obtain a copy of the License at
+ *
+ * https://opensource.org/licenses/BSD-3-Clause
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed 
+ * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 
+ * CONDITIONS OF ANY KIND, either express or implied. See the License for the 
+ * specific language governing permissions and limitations under the License.
+ */
+
+#include "util/tc_epoll_server.h"
+#include "util/tc_tea.h"
+#include "util/tc_sha.h"
+#include "util/tc_md5.h"
+#include "servant/Application.h"
+#include "servant/AuthLogic.h"
+#include <iostream>
+#include <cassert>
+
+
+namespace tars
+{
+
+bool processAuth(void* c, const string& data)
+{
+    TC_EpollServer::NetThread::Connection* const conn = (TC_EpollServer::NetThread::Connection*)c;
+    conn->tryInitAuthState(AUTH_INIT);
+
+    if (conn->_authState == AUTH_SUCC)
+        return false; // data to be processed
+
+    TC_EpollServer::BindAdapter* adapter = conn->getBindAdapter();
+
+    const int type = adapter->getEndpoint().getAuthType();
+    if (type == AUTH_TYPENONE)
+    {
+        adapter->getEpollServer()->info("no auth func, so eAuthSucc");
+        conn->_authState = AUTH_SUCC;
+        return false;
+    }
+
+    // got auth request
+    RequestPacket request;
+    if (adapter->isTarsProtocol())
+    {
+        TarsInputStream<BufferReader> is;
+        is.setBuffer(data.data(), data.size());
+        try {
+            request.readFrom(is);
+             ostringstream oos;
+             request.display(oos);
+        }
+        catch(...) {
+            conn->setClose();
+            return true;
+        }
+    }
+    else
+    {
+        request.sBuffer.assign(data.begin(), data.end());
+    }
+
+    const int currentState = conn->_authState;
+    int newstate = tars::defaultProcessAuthReq(request.sBuffer.data(), request.sBuffer.size(), adapter->getName());
+    std::string out = tars::etos((tars::AUTH_STATE)newstate);
+
+    if (newstate < 0)
+    {
+        // 验证错误,断开连接
+        adapter->getEpollServer()->error("authProcess failed with new state [" + out + "]");
+        conn->setClose();
+        return true;
+    }
+
+    adapter->getEpollServer()->info(TC_Common::tostr(conn->getId()) + "'s auth response[" + out + "], change state from " +
+                                    TC_Common::tostr(currentState) + " to " + out);
+    conn->_authState = newstate;
+
+    if (adapter->isTarsProtocol())
+    {
+        TarsOutputStream<BufferWriter> os;
+        ResponsePacket response;
+        response.iVersion = TARSVERSION;
+        response.iRequestId = request.iRequestId;
+        response.iMessageType = request.iMessageType;
+        response.cPacketType = request.cPacketType;
+        response.iRet = 0;
+        response.sBuffer.assign(out.begin(), out.end());
+
+        response.writeTo(os);
+
+        tars::Int32 iHeaderLen = htonl(sizeof(tars::Int32) + os.getLength());
+
+        std::string s;
+        s.append((const char*)&iHeaderLen, sizeof(tars::Int32));
+        s.append(os.getBuffer(), os.getLength());
+
+        adapter->getEpollServer()->send(conn->getId(), s, "", 0, conn->getfd());
+    }
+    else
+    {
+        adapter->getEpollServer()->send(conn->getId(), out, "", 0, conn->getfd());
+    }
+
+    return true; // processed
+}
+
+
+int processAuthReqHelper(const BasicAuthPackage& pkg, const BasicAuthInfo& info)
+{
+    // 明文:objName, accessKey, time, hashMethod
+    // 密文:use TmpKey to enc secret1;
+    // and tmpKey = sha1(secret2 | timestamp);
+    if (pkg.sObjName != info.sObjName)
+        return AUTH_WRONG_OBJ;
+
+    if (pkg.sAccessKey != info.sAccessKey)
+        return AUTH_WRONG_AK;
+
+    time_t now = TNOW;
+    const int range = 60 * 60;
+    if (!(pkg.iTime > (now - range) && pkg.iTime < (now + range)))
+        return AUTH_WRONG_TIME;
+
+    if (pkg.sHashMethod != "sha1")
+        return AUTH_NOT_SUPPORT_ENC;
+
+    // 用secret1 = sha1(password); secret2 = sha1(secret1);
+    // 1.client create TmpKey use timestamp and secret2;
+    // 2.client use TmpKey to enc secret1;
+    // 3.server use TmpKey same as client, to dec secret1;
+    // 4.server got secret1, then sha1(secret1), to compare secret2;
+    // 下面这个是123456的两次sha1值
+    //assert (info.sHashSecretKey2 == "69c5fcebaa65b560eaf06c3fbeb481ae44b8d618");
+
+    string tmpKey;
+    string hash2;
+    {
+        string hash1 = TC_SHA::sha1str(info.sHashSecretKey2.data(), info.sHashSecretKey2.size());
+        hash2 = TC_SHA::sha1str(hash1.data(), hash1.size());
+        string tmp = hash2;
+        const char* pt = (const char* )&pkg.iTime;
+        for (size_t i = 0; i < sizeof pkg.iTime; ++ i)
+        {
+            tmp[i] |= pt[i];
+        }
+
+        tmpKey = TC_MD5::md5bin(tmp);
+    }
+
+    string secret1;
+    {
+        vector<char> dec;
+        try
+        {
+            dec = TC_Tea::decrypt2(tmpKey.data(), pkg.sSignature.data(), pkg.sSignature.size());
+            secret1.assign(dec.begin(), dec.end());
+        }
+        catch (const TC_Tea_Exception& )
+        {
+            return AUTH_DEC_FAIL;
+        }
+    }
+
+    // 4.server got secret1, then sha1(secret1), to compare secret2;
+    string clientSecret2 = TC_SHA::sha1str(secret1.data(), secret1.size());
+    if (clientSecret2.size() != hash2.size() ||
+        !std::equal(clientSecret2.begin(), clientSecret2.end(), hash2.begin()))
+    {
+        return AUTH_ERROR;
+    }
+
+    return AUTH_SUCC;
+}
+
+// 只需要传入 expect 的objname;
+// 内部根据obj查找access账号集
+int defaultProcessAuthReq(const char* request, size_t len, const string& expectObj)
+{
+    if (len <= 20)
+        return AUTH_PROTO_ERR;
+
+    BasicAuthPackage pkg;
+    TarsInputStream<BufferReader> is;
+    is.setBuffer(request, len);
+    try {
+        pkg.readFrom(is);
+    }
+    catch(...) {
+        return AUTH_PROTO_ERR;
+    }
+
+    TC_EpollServer::BindAdapterPtr bap = Application::getEpollServer()->getBindAdapter(expectObj);
+    if (!bap)
+        return AUTH_WRONG_OBJ;
+
+    BasicAuthInfo info;
+    string expectServantName = ServantHelperManager::getInstance()->getAdapterServant(expectObj);
+    info.sObjName = expectServantName;
+    info.sAccessKey = pkg.sAccessKey;
+    info.sHashSecretKey2 = bap->getSk(info.sAccessKey);
+    if (info.sHashSecretKey2.empty())
+        return AUTH_WRONG_AK;
+
+    return processAuthReqHelper(pkg, info);
+}
+
+int defaultProcessAuthReq(const string& request, const string& expectObj)
+{
+    return defaultProcessAuthReq(request.data(), request.size(), expectObj);
+}
+
+string defaultCreateAuthReq(const BasicAuthInfo& info /*, const string& hashMethod*/ )
+{
+    // 明文:objName, accessKey, time, hashMethod
+    // 密文:use TmpKey to enc secret1;
+    TarsOutputStream<BufferWriter> os;
+    BasicAuthPackage pkg;
+    pkg.sObjName = info.sObjName;
+    pkg.sAccessKey = info.sAccessKey;
+    pkg.iTime = TNOW;
+
+    string secret1 = TC_SHA::sha1str(info.sSecretKey.data(), info.sSecretKey.size());
+    string secret2 = TC_SHA::sha1str(secret1.data(), secret1.size());
+
+    // create tmpKey
+    string tmpKey;
+    {
+        string tmp = secret2;
+        const char* pt = (const char* )&pkg.iTime;
+        for (size_t i = 0; i < sizeof pkg.iTime; ++ i)
+        {
+            tmp[i] |= pt[i];
+        }
+        // 保证key是16字节
+        tmpKey = TC_MD5::md5bin(tmp);
+    }
+
+    // then use tmpKey to enc secret1, show server that I know secret1, ie, I know secret.
+    vector<char> secret1Enc = TC_Tea::encrypt2(tmpKey.data(), secret1.data(), secret1.size());
+
+    pkg.sSignature.assign(secret1Enc.begin(), secret1Enc.end());
+    pkg.writeTo(os);
+
+    return string(os.getBuffer(), os.getLength());
+}
+
+} // end namespace tars
+

+ 62 - 7
servant/libservant/Transceiver.cpp

@@ -19,8 +19,8 @@
 #include "servant/AdapterProxy.h"
 #include "servant/Application.h"
 #include "servant/TarsLogger.h"
-//#include "servant/AuthLogic.h"
-#include "servant/Auth.h"
+#include "servant/AuthLogic.h"
+//#include "servant/Auth.h"
 
 #if TARS_SSL
 #include "util/tc_openssl.h"
@@ -196,10 +196,54 @@ void Transceiver::_onConnect()
 void Transceiver::_doAuthReq()
 {
     ObjectProxy* obj = _adapterProxy->getObjProxy();
-        
-    TLOGINFO("[TARS][_onConnect:" << obj->name() << " auth Type is " << _adapterProxy->endpoint().authType() << endl);
-    
-    _adapterProxy->doInvoke();
+
+    TLOGINFO("[TARS][_onConnect:" << obj->name() << " auth type is " << _adapterProxy->endpoint().authType() << endl);
+
+    if (_adapterProxy->endpoint().authType() == AUTH_TYPENONE)
+    {
+        _authState = AUTH_SUCC;
+        _adapterProxy->doInvoke();
+    }
+    else
+    {
+        BasicAuthInfo basic;
+        basic.sObjName = obj->name();
+        basic.sAccessKey = obj->getAccessKey();
+        basic.sSecretKey = obj->getSecretKey();
+
+        this->sendAuthData(basic);
+    }
+}
+
+bool Transceiver::sendAuthData(const BasicAuthInfo& info)
+{
+    assert (_authState != AUTH_SUCC);
+
+    ObjectProxy* objPrx = _adapterProxy->getObjProxy();
+
+    // 走框架的AK/SK认证
+    std::string out = tars::defaultCreateAuthReq(info);
+
+    const int kAuthType = 0x40;
+    RequestPacket request;
+    request.sFuncName = "tarsInnerAuthServer";
+    request.sServantName = "authServant";
+    request.iVersion = TARSVERSION;
+    request.iRequestId = 0;
+    request.cPacketType = TARSNORMAL;
+    request.iMessageType = kAuthType;
+    request.sBuffer.assign(out.begin(), out.end());
+
+    std::string toSend;
+    objPrx->getProxyProtocol().requestFunc(request, toSend);
+    if (sendRequest(toSend.data(), toSend.size(), true) == eRetError)
+    {
+        TLOGERROR("[TARS][Transceiver::setConnected failed sendRequest for Auth\n");
+        close();
+        return false;
+    }
+
+    return true;
 }
 
 void Transceiver::close()
@@ -290,7 +334,18 @@ int Transceiver::sendRequest(const char * pData, size_t iSize, bool forceSend)
     {
         return eRetError;
     }
-        
+
+    if (!forceSend && _authState != AUTH_SUCC)
+    {
+#if TARS_SSL
+        if (isSSL() && !_openssl)
+            return eRetError;
+#endif
+        ObjectProxy* obj = _adapterProxy->getObjProxy();
+        TLOGINFO("[TARS][Transceiver::sendRequest temporary failed because need auth for " << obj->name() << endl);
+        return eRetError; // 需要鉴权但还没通过,不能发送非认证消息
+    }
+
     //buf不为空,直接返回失败
     //等buffer可写了,epoll会通知写时间
     if(!_sendBuffer.IsEmpty())

+ 28 - 0
servant/servant/AuthLogic.h

@@ -0,0 +1,28 @@
+#include "servant/Auth.h"
+
+namespace tars
+{
+
+/**
+ * server :默认鉴权逻辑
+ */
+bool processAuth(void* c, const string& data);
+
+/**
+ * server :默认鉴权逻辑
+ */
+int processAuthReqHelper(const BasicAuthPackage& pkg, const BasicAuthInfo& info);
+
+/**
+ * server :默认鉴权方法
+ */
+int defaultProcessAuthReq(const char* request, size_t len, const string& expectObj);
+int defaultProcessAuthReq(const string& request, const string& expectObj);
+
+/**
+ * client:默认生成鉴权请求方法
+ */
+string defaultCreateAuthReq(const BasicAuthInfo& info /*, const string& hashMethod = "sha1" */ );
+
+} // end namespace tars
+

+ 15 - 0
servant/servant/Transceiver.h

@@ -20,6 +20,7 @@
 #include "servant/EndpointInfo.h"
 #include "servant/NetworkUtil.h"
 #include "servant/CommunicatorEpoll.h"
+#include "servant/AuthLogic.h"
 #include "util/tc_buffer.h"
 #include <list>
 #include <sys/uio.h>
@@ -205,6 +206,20 @@ public:
         _connStatus = eUnconnected; 
     }
 
+    /**
+     * 设置鉴权状态
+     */
+    void setAuthState(int newstate) { _authState = newstate; }
+
+    /*
+     * 获取鉴权状态
+     */
+    int getAuthState() const { return _authState; }
+
+    /*
+     * 发送鉴权数据
+     */
+    bool sendAuthData(const BasicAuthInfo& );
 protected:
     /** 
      ** 物理连接成功回调